@@ -119,41 +119,37 @@ class QConfigProperties:
119119 Default is torch.qint8.
120120 :param activation_bits: number of bits for activations. Default is 8.
121121 :param weight_bits: number of bits for weights. Default is 8.
122+ :param tensorrt: if True sets quantization configuration for compatibility with
123+ explict quantization as supported by TensorRT 8.2.
122124 """
123125
124- _symmetric_activations : Optional [ bool ] = None
125- _symmetric_weights : Optional [ bool ] = None
126+ _symmetric_activations : bool = False
127+ _symmetric_weights : bool = True
126128 reduce_range : bool = False
127129 activation_dtype : torch .dtype = torch .quint8
128130 weight_dtype : torch .dtype = torch .qint8
129131 activation_bits : int = 8
130132 weight_bits : int = 8
131133 activation_qconfig_kwargs : Dict [str , Any ] = field (default_factory = dict )
132134 weight_qconfig_kwargs : Dict [str , Any ] = field (default_factory = dict )
135+ tensorrt : bool = False
133136
134137 @property
135138 def symmetric_activations (self ) -> bool :
136- if self ._symmetric_activations :
137- return self ._symmetric_activations
138- else :
139- return False
139+ # always use symmetric activations in tensorrt mode
140+ return self .tensorrt or self ._symmetric_activations
140141
141142 @symmetric_activations .setter
142143 def symmetric_activations (self , value : bool ):
143- if self ._symmetric_activations is None :
144- self ._symmetric_activations = value
144+ self ._symmetric_activations = value
145145
146146 @property
147147 def symmetric_weights (self ) -> bool :
148- if self ._symmetric_weights :
149- return self ._symmetric_weights
150- else :
151- return True
148+ return self .tensorrt or self ._symmetric_weights
152149
153150 @symmetric_weights .setter
154151 def symmetric_weights (self , value : bool ):
155- if self ._symmetric_weights is None :
156- self ._symmetric_weights = value
152+ self ._symmetric_weights = value
157153
158154
159155class QATWrapper (Module ):
@@ -365,9 +361,10 @@ def _load_qconfigs(
365361 f"Found string with value { qconfig } in { name } "
366362 )
367363
368- qproperties .symmetric_activations = qconfig == "symmetric"
364+ qproperties_idx = deepcopy (qproperties )
365+ qproperties_idx .symmetric_activations = qconfig == "symmetric"
369366
370- qconfigs [idx ] = get_qat_qconfig (qproperties )
367+ qconfigs [idx ] = get_qat_qconfig (qproperties_idx )
371368
372369 return qconfigs
373370
@@ -578,6 +575,11 @@ def fix_observer_quant_range(module: Module):
578575 fake_quantize .quant_min is None
579576 or fake_quantize .quant_max is None
580577 or (observer .quant_min is not None or observer .quant_max is not None )
578+ or ( # do not propagate default uint8 symmetric range
579+ observer .qscheme == torch .per_tensor_symmetric
580+ and fake_quantize .quant_min == 0
581+ and fake_quantize .quant_max == 255
582+ )
581583 ):
582584 continue
583585 observer .quant_min = fake_quantize .quant_min
0 commit comments