
Commit 7ea96a7

peterbell10 authored and pytorchmergebot committed
[quant][fx] Don't assume bias is a keyword-argument (pytorch#71426)
Summary:
Pull Request resolved: pytorch#71426

dbr quantization makes faulty assumptions about which arguments are passed as keyword arguments and which are passed as positional arguments. This happens to work currently due to a quirk of how `__torch_function__` is implemented in Python functions, but it will break when the operators are moved to C++.

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D33754262

Pulled By: albanD

fbshipit-source-id: 63515d7a166449726e1beaba6659443b6261742d
(cherry picked from commit f7b1884)
1 parent a5e27c4 commit 7ea96a7
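
A concrete illustration of the assumption the summary describes: both spellings of F.linear below are legal, but only the second one places `bias` into the `**kwargs` seen by a `__torch_function__` handler, which is why code like `del kwargs['bias']` is fragile. (A minimal sketch; the tensors are illustrative, not from the patch.)

    import torch
    import torch.nn.functional as F

    x, w, b = torch.randn(4, 3), torch.randn(2, 3), torch.randn(2)

    # Equivalent calls, but only the second spells bias as a keyword;
    # a handler that does del kwargs['bias'] breaks on the first one
    # once __torch_function__ stops normalizing arguments.
    out1 = F.linear(x, w, b)        # bias positional
    out2 = F.linear(x, w, bias=b)   # bias keyword
    assert torch.allclose(out1, out2)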

File tree

4 files changed (+43, -22 lines)


torch/ao/quantization/_dbr/auto_trace_rewriter.py

+6 -5

@@ -147,11 +147,12 @@ def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):

         # TODO move op-specific logic out of here
         if target is torch.ops.quantized.linear:
-            new_args = [*args]
-            new_args.append(additional_kwargs['scale'])
-            new_args.append(additional_kwargs['zero_point'])
-            args = tuple(new_args)
-            del kwargs['bias']
+            def linear_rewrite_args(input, weight, bias=None):
+                return (input, weight,
+                        additional_kwargs['scale'],
+                        additional_kwargs['zero_point'])
+            args = linear_rewrite_args(*args, **kwargs)
+            kwargs = {}
         elif old_target != F.conv2d or target is F.conv2d:
             kwargs.update(**additional_kwargs)
         else:
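
The fix binds whatever mix of positional and keyword arguments the caller used against a local function whose signature mirrors F.linear, so `bias` is consumed whether it arrived positionally or by keyword. A minimal standalone sketch of the idiom (the wrapper name and tensors are illustrative, not from the patch):

    import torch

    def normalize_linear_args(*args, **kwargs):
        # Binding through a function with F.linear's signature accepts
        # bias both positionally and as a keyword argument.
        def _bind(input, weight, bias=None):
            return input, weight, bias
        return _bind(*args, **kwargs)

    x, w, b = torch.randn(4, 3), torch.randn(2, 3), torch.randn(2)
    a = normalize_linear_args(x, w, b)        # bias passed positionally
    k = normalize_linear_args(x, w, bias=b)   # bias passed as a keyword
    assert a[2] is k[2] is b                  # same bias either way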

torch/ao/quantization/_dbr/model_utils.py

+13 -4

@@ -11,6 +11,7 @@
     ObserverBase,
     FakeQuantizeBase,
 )
+from typing import Optional

 def pack_weights_for_functionals(
     module: torch.nn.Module,
@@ -66,10 +67,18 @@ def pack_weights_for_functionals(

         elif seen_op_info.type == F.linear:
             # fetch all the info needed for packed params
-            assert seen_op_info.packable_tensor_idx_to_name[1] is not None
-            weight = getattr(module, seen_op_info.packable_tensor_idx_to_name[1])
-            bias_name = seen_op_info.packable_tensor_kwarg_name_to_name['bias']
-            bias = getattr(module, bias_name) if bias_name else None
+            def get_tensor_param_name(idx: int, name: str) -> Optional[str]:
+                param_name = seen_op_info.packable_tensor_idx_to_name.get(idx, None)
+                if param_name is not None:
+                    return param_name
+                return seen_op_info.packable_tensor_kwarg_name_to_name.get(name, None)
+
+            weight_name = get_tensor_param_name(1, 'weight')
+            assert weight_name is not None
+            weight = getattr(module, weight_name)
+
+            bias_name = get_tensor_param_name(2, 'bias')
+            bias = getattr(module, bias_name) if bias_name is not None else None

             # quantize the weight
             # TODO: create weight observers from qconfig.weight
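
The new helper resolves a packable parameter by its positional index first and falls back to its keyword name, so packing works however the traced call spelled its arguments. A sketch of just the lookup order, with hypothetical stand-in dicts for the two seen_op_info mappings:

    from typing import Optional

    idx_to_name = {1: 'w0'}          # stand-in for packable_tensor_idx_to_name
    kwarg_to_name = {'bias': 'b0'}   # stand-in for packable_tensor_kwarg_name_to_name

    def lookup(idx: int, name: str) -> Optional[str]:
        # A positional record wins; the keyword record is the fallback.
        found = idx_to_name.get(idx)
        if found is not None:
            return found
        return kwarg_to_name.get(name)

    assert lookup(1, 'weight') == 'w0'   # recorded positionally
    assert lookup(2, 'bias') == 'b0'     # recorded as a keyword
    assert lookup(3, 'stride') is None   # not recorded at all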

torch/ao/quantization/_dbr/quantization_state.py

+6 -2

@@ -441,7 +441,7 @@ def op_convert_before_hook(

         # TODO move op-specific logic out of here
         if op is torch.ops.quantized.linear:
-            del kwargs['bias']
+            kwargs.pop('bias', None)

         return op, tuple(new_args), kwargs

@@ -666,7 +666,7 @@ def _first_call_op_prepare_before_hook_create_subgraphs(
         of this op in `self`.
         """
         op_packing_only_uses_module_attributes = \
-            get_op_packing_only_uses_module_attributes(op, args, root_module)
+            get_op_packing_only_uses_module_attributes(op, args, kwargs, root_module)
         arg_tensor_infos: List[Optional[QTensorInfo]] = []
         for arg in args:
             if isinstance(arg, (list, tuple)):
@@ -684,6 +684,8 @@ def _first_call_op_prepare_before_hook_create_subgraphs(
         packable_tensor_arg_idxs = get_packable_tensor_arg_idxs(op)
         if packable_tensor_arg_idxs is not None:
             for arg_idx in packable_tensor_arg_idxs:
+                if arg_idx >= len(args):
+                    continue
                 arg = args[arg_idx]
                 param_name = get_param_name(root_module, arg)
                 packable_tensor_idx_to_name[arg_idx] = param_name
@@ -697,6 +699,8 @@ def _first_call_op_prepare_before_hook_create_subgraphs(
             get_packable_tensor_kwarg_names(op)
         if packable_tensor_kwarg_names is not None:
             for kwarg_name in packable_tensor_kwarg_names:
+                if kwarg_name not in kwargs:
+                    continue
                 kwarg = kwargs[kwarg_name]
                 kwarg_name_on_module = get_param_name(root_module, kwarg)
                 packable_tensor_kwarg_name_to_name[kwarg_name] = \
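
The replaced `del kwargs['bias']` assumed the key was always present; `dict.pop` with a default is a no-op when it is not, which is exactly the case when bias was passed positionally. A two-line illustration with a made-up kwargs dict:

    kwargs = {'weight': 'w'}   # bias arrived positionally, so there is no key
    kwargs.pop('bias', None)   # safe: returns the default, mutates nothing
    # del kwargs['bias']       # would raise KeyError here

The two loops above get the same defensive shape: indices beyond len(args) and names missing from kwargs are simply skipped instead of raising.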

torch/ao/quantization/_dbr/utils.py

+18 -11

@@ -301,9 +301,15 @@ def get_func_output_dtype_type(

     return FuncOutputDTypeType.DTYPE_DEPENDS_ON_QCONFIG

+def get_weight_argument_info(op: Callable) -> Optional[Tuple[int, str]]:
+    if op in (F.linear, F.conv2d):
+        return (1, 'weight')
+    return None
+
 def get_op_packing_only_uses_module_attributes(
     op: Callable,
     args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
     module: torch.nn.Module,
 ) -> bool:
     """
@@ -316,12 +322,13 @@ def get_op_packing_only_uses_module_attributes(
     """
     # check for ops which need packed weights but the weights are
     # coming from another function
-    packable_tensor_arg_idxs = get_packable_tensor_arg_idxs(op)
-    if packable_tensor_arg_idxs is not None:
-        for arg_idx in packable_tensor_arg_idxs:
-            arg_name_in_root = get_param_name(module, args[arg_idx])
-            if arg_name_in_root is None:
-                return False
+    info = get_weight_argument_info(op)
+    if info is not None:
+        idx, name = info
+        param_name = args[idx] if idx < len(args) else kwargs[name]
+        arg_name_in_root = get_param_name(module, param_name)
+        if arg_name_in_root is None:
+            return False
     return True

 def get_quantized_op(
@@ -372,16 +379,16 @@ def get_packable_tensor_arg_idxs(op: Callable) -> Optional[List[int]]:
     if op == F.conv2d:
         return [1, 2]
     elif op == F.linear:
-        return [1]
+        return [1, 2]
     return None

 def get_packable_tensor_kwarg_names(op: Callable) -> Optional[List[str]]:
     """
     Returns tensor kwarg names which correspond to parameters which will
     need to be packed.
     """
-    if op == F.linear:
-        return ['bias']
+    if op in (F.conv2d, F.linear):
+        return ['weight', 'bias']
     return None

 def get_param_name(module: torch.nn.Module, arg: Any) -> Optional[str]:
@@ -409,8 +416,8 @@ def get_packable_arg_idxs(op: Callable) -> Optional[List[int]]:
         # weight, bias, stride, padding, dilation, groups
         return [1, 2, 3, 4, 5, 6]
     elif op == F.linear:
-        # weight
-        return [1]
+        # weight, bias
+        return [1, 2]
     return None

 def get_weight_arg_idx(op: Callable) -> Optional[int]:
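
The recurring pattern in this file is to treat every packable parameter as positional-or-keyword: consult args by index, then kwargs by name. A standalone sketch of that lookup (the get_arg helper and its call sites are hypothetical, not part of the patch):

    from typing import Any, Dict, Tuple

    def get_arg(args: Tuple[Any, ...], kwargs: Dict[str, Any],
                idx: int, name: str, default: Any = None) -> Any:
        # A positional value at idx takes precedence; otherwise try the
        # keyword spelling; otherwise fall back to the default.
        if idx < len(args):
            return args[idx]
        return kwargs.get(name, default)

    # For F.linear(input, weight, bias=None), bias is index 2 / name 'bias'.
    assert get_arg(('x', 'w', 'b'), {}, 2, 'bias') == 'b'
    assert get_arg(('x', 'w'), {'bias': 'b'}, 2, 'bias') == 'b'
    assert get_arg(('x', 'w'), {}, 2, 'bias') is None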
