This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 81f5f33

Cherry pick 012 distilbert qat fixes (#726)
* fix QATWrapper not properly overwriting qconfig properties for symmetric activations (#724)
* re-add fix for symmetric zero points for uint8 quantization (#604) (#725)
1 parent fe598cb commit 81f5f33

2 files changed: +19, -18 lines


src/sparseml/pytorch/sparsification/quantization/helpers.py

Lines changed: 18 additions & 16 deletions
@@ -119,41 +119,37 @@ class QConfigProperties:
         Default is torch.qint8.
     :param activation_bits: number of bits for activations. Default is 8.
     :param weight_bits: number of bits for weights. Default is 8.
+    :param tensorrt: if True sets quantization configuration for compatibility with
+        explict quantization as supported by TensorRT 8.2.
     """

-    _symmetric_activations: Optional[bool] = None
-    _symmetric_weights: Optional[bool] = None
+    _symmetric_activations: bool = False
+    _symmetric_weights: bool = True
     reduce_range: bool = False
     activation_dtype: torch.dtype = torch.quint8
     weight_dtype: torch.dtype = torch.qint8
     activation_bits: int = 8
     weight_bits: int = 8
     activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict)
     weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict)
+    tensorrt: bool = False

     @property
     def symmetric_activations(self) -> bool:
-        if self._symmetric_activations:
-            return self._symmetric_activations
-        else:
-            return False
+        # always use symmetric activations in tensorrt mode
+        return self.tensorrt or self._symmetric_activations

     @symmetric_activations.setter
     def symmetric_activations(self, value: bool):
-        if self._symmetric_activations is None:
-            self._symmetric_activations = value
+        self._symmetric_activations = value

     @property
     def symmetric_weights(self) -> bool:
-        if self._symmetric_weights:
-            return self._symmetric_weights
-        else:
-            return True
+        return self.tensorrt or self._symmetric_weights

     @symmetric_weights.setter
     def symmetric_weights(self, value: bool):
-        if self._symmetric_weights is None:
-            self._symmetric_weights = value
+        self._symmetric_weights = value


 class QATWrapper(Module):
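
The net behavior of the reworked properties is easiest to see in isolation. Below is a minimal, runnable sketch that mirrors only the fields touched by this hunk; the class name _QConfigPropertiesSketch is a stand-in, and the torch dtype/kwargs fields are left out so the snippet has no torch dependency. It illustrates the diff above, not sparseml's actual class.

from dataclasses import dataclass


@dataclass
class _QConfigPropertiesSketch:
    # new defaults from the diff: plain booleans instead of Optional[bool]
    _symmetric_activations: bool = False
    _symmetric_weights: bool = True
    tensorrt: bool = False

    @property
    def symmetric_activations(self) -> bool:
        # tensorrt mode always forces symmetric activations
        return self.tensorrt or self._symmetric_activations

    @symmetric_activations.setter
    def symmetric_activations(self, value: bool):
        # setter now always records the value (no longer gated on None)
        self._symmetric_activations = value

    @property
    def symmetric_weights(self) -> bool:
        return self.tensorrt or self._symmetric_weights

    @symmetric_weights.setter
    def symmetric_weights(self, value: bool):
        self._symmetric_weights = value


props = _QConfigPropertiesSketch()
print(props.symmetric_activations, props.symmetric_weights)  # False True  (new defaults)
props.tensorrt = True
print(props.symmetric_activations, props.symmetric_weights)  # True True   (tensorrt forces symmetric)
props.symmetric_activations = False
print(props.symmetric_activations)  # still True while tensorrt is set
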
@@ -365,9 +361,10 @@ def _load_qconfigs(
                     f"Found string with value {qconfig} in {name}"
                 )

-            qproperties.symmetric_activations = qconfig == "symmetric"
+            qproperties_idx = deepcopy(qproperties)
+            qproperties_idx.symmetric_activations = qconfig == "symmetric"

-            qconfigs[idx] = get_qat_qconfig(qproperties)
+            qconfigs[idx] = get_qat_qconfig(qproperties_idx)

         return qconfigs

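The deepcopy is the key part of the #724 fix here: previously the shared qproperties object was mutated in place, so an override requested for one entry could also affect qconfigs built later from the same object. A small stand-alone illustration of that aliasing pattern follows; the _Props dataclass and the build helper are hypothetical stand-ins, not sparseml code.

from copy import deepcopy
from dataclasses import dataclass


@dataclass
class _Props:
    symmetric_activations: bool = False


def build(schemes, use_copy):
    base = _Props()
    results = []
    for scheme in schemes:
        # pre-fix: every entry mutates the shared object; post-fix: copy per entry
        props = deepcopy(base) if use_copy else base
        if scheme is not None:
            props.symmetric_activations = scheme == "symmetric"
        results.append(props.symmetric_activations)
    return results


# the second entry passes no override and should keep the default (False)
print(build(["symmetric", None], use_copy=False))  # [True, True]  -- the override leaks forward
print(build(["symmetric", None], use_copy=True))   # [True, False] -- each entry keeps its own setting
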
@@ -578,6 +575,11 @@ def fix_observer_quant_range(module: Module):
             fake_quantize.quant_min is None
             or fake_quantize.quant_max is None
             or (observer.quant_min is not None or observer.quant_max is not None)
+            or (  # do not propagate default uint8 symmetric range
+                observer.qscheme == torch.per_tensor_symmetric
+                and fake_quantize.quant_min == 0
+                and fake_quantize.quant_max == 255
+            )
         ):
             continue
         observer.quant_min = fake_quantize.quant_min
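
The added clause keeps fix_observer_quant_range from copying the default uint8 range of (0, 255) onto an observer that uses a symmetric qscheme, which relates to the symmetric zero point fix for uint8 quantization re-added here (#604, #725). A rough, torch-free restatement of the updated guard is sketched below; should_skip and the string stand-in for torch.per_tensor_symmetric are illustrative names only.

from types import SimpleNamespace

PER_TENSOR_SYMMETRIC = "per_tensor_symmetric"  # stand-in for torch.per_tensor_symmetric


def should_skip(fake_quantize, observer):
    # mirrors the updated guard: skip when there is nothing to propagate,
    # when the observer already has a range, or when the fake-quantize range
    # is just the default uint8 (0, 255) paired with a symmetric observer
    return (
        fake_quantize.quant_min is None
        or fake_quantize.quant_max is None
        or (observer.quant_min is not None or observer.quant_max is not None)
        or (
            observer.qscheme == PER_TENSOR_SYMMETRIC
            and fake_quantize.quant_min == 0
            and fake_quantize.quant_max == 255
        )
    )


fq = SimpleNamespace(quant_min=0, quant_max=255)
obs = SimpleNamespace(quant_min=None, quant_max=None, qscheme=PER_TENSOR_SYMMETRIC)
print(should_skip(fq, obs))  # True: the default uint8 range is not propagated to the symmetric observer
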

src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py

Lines changed: 1 addition & 2 deletions
@@ -615,9 +615,8 @@ def _enable_module_qat(self, module: Module):
                 "Overriding quantization scheme to symmetric int8 "
                 "for both weights and activations because tensorrt flag is True."
             )
-            qproperties.symmetric_activations = True
+            qproperties.tensorrt = True
             qproperties.activation_dtype = torch.qint8
-            qproperties.symmetric_weights = True
             qproperties.weight_dtype = torch.qint8

         qconfig = get_qat_qconfig(qproperties)
