
Commit 8d31706

BowenBao authored and pytorchmergebot committed
[ONNX] Support restricted quantized range for activation.
PyTorch restricts activations to the range (0, 127). In ONNX, the supported ranges are (0, 255) and (-128, 127), for uint8 and int8 respectively. This PR adds support for the range (0, 127) by emitting an additional Clip when that range is detected.

Pull Request resolved: pytorch#76055
Approved by: https://github.com/garymm
1 parent cada2cd commit 8d31706
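
For context, here is a minimal numeric sketch (not part of the commit; the scale, zero point, and input values are illustrative) of why the extra Clip is needed when the quantized range is (0, 127):

import torch

scale, zero_point = 1.0, 0
x = torch.tensor([150., 127., -5.])

# PyTorch's activation fake-quantize clamps to the restricted range (0, 127).
pytorch_ref = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 127)

# A plain uint8 QuantizeLinear/DequantizeLinear round trip clamps to (0, 255),
# so 150 survives and disagrees with the PyTorch result above ...
uint8_q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255)
uint8_roundtrip = (uint8_q - zero_point) * scale

# ... unless the quantized values are additionally clipped at 127, which is what
# the exported graph now does with an ONNX Clip node.
clipped_q = torch.clamp(torch.round(x / scale) + zero_point, 0, 127)
clipped_roundtrip = (clipped_q - zero_point) * scale

print(pytorch_ref)        # tensor([127., 127., 0.])
print(uint8_roundtrip)    # tensor([150., 127., 0.])
print(clipped_roundtrip)  # tensor([127., 127., 0.])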

File tree

3 files changed: +49 -10 lines changed


test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 26 additions & 0 deletions
@@ -8731,6 +8731,32 @@ def forward(self, input):
         x = torch.randn(6, 4, 3, 3)
         self.run_test(FakeQuantizePerChannelModel(), (x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    @disableScriptTest()  # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear
+    def test_fake_quantize_activation(self):
+        from torch import quantization
+        m = torch.nn.Linear(1, 1)
+        m.qconfig = quantization.QConfig(
+            activation=quantization.default_fake_quant,
+            weight=quantization.default_per_channel_weight_fake_quant)
+        quantization.prepare_qat(m.train(), inplace=True)
+        m.apply(quantization.enable_observer)
+        m.apply(quantization.enable_fake_quant)
+        for module in m.modules():
+            if isinstance(module, quantization.FakeQuantize):
+                module.calculate_qparams()
+
+        m.apply(quantization.disable_observer)
+        m.eval()
+
+        # Fake quantize activation is a special case, as it restricts quantized range to be (0, 127),
+        # while standard 8bit quantization range is (-128, 127) or (0, 255).
+        # Set fixed weight, bias and inputs to test if ONNX handles the overflow correctly.
+        m.weight = torch.nn.Parameter(torch.tensor([[1.], [1.], [1.]]))
+        m.bias = torch.nn.Parameter(torch.tensor([0.]))
+        x = torch.tensor([[150.], [127.], [-5.]])
+        self.run_test(m, x)
+
     def test_batchnorm_training(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
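
As a usage sketch (not part of the commit; the file name and dummy inputs are illustrative), a QAT-prepared module like the one in the new test can be exported at opset 13, which provides the Clip needed for the (0, 127) activation range:

import torch
from torch import quantization

m = torch.nn.Linear(1, 1)
m.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant,
    weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(m.train(), inplace=True)
m(torch.randn(4, 1))  # run a forward pass so the observers record qparams
m.apply(quantization.disable_observer)
m.eval()

torch.onnx.export(m, torch.randn(3, 1), "fake_quant_activation.onnx", opset_version=13)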

torch/onnx/symbolic_opset10.py

Lines changed: 6 additions & 0 deletions
@@ -300,6 +300,12 @@ def embedding_bag(g,
 
 @parse_args("v", "v", "v", "i", "i")
 def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
+    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) == (0, 127):
+        sym_help._onnx_opset_unsupported_detailed(
+            "fake_quantize_per_tensor_affine", 10, 13,
+            "Quantize range (0, 127) not supported, requires opset 13 Clip")
     if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
         raise RuntimeError(
             "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "

torch/onnx/symbolic_opset13.py

Lines changed: 17 additions & 10 deletions
@@ -7,7 +7,7 @@
 from torch.onnx.symbolic_helper import parse_args, _unimplemented
 from torch.onnx.symbolic_opset9 import (overload_by_arg_count, _maybe_cast_reduce_op_input,
                                         nonzero, expand, zeros, ones, size, linear, conv2d,
-                                        relu)
+                                        relu, unused)
 from torch.onnx.symbolic_opset11 import unsqueeze
 from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
 
@@ -132,33 +132,40 @@ def where(g, condition, self=None, other=None, _outputs=None):
 
 @parse_args("v", "v", "v", "i", "i", "i")
 def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
-    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
         raise RuntimeError(
-            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
             "Got ({}, {})".format(quant_min, quant_max))
     # ONNX defines zero_point to be int8 or uint8
     if quant_min == 0:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
     else:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
-    return g.op(
-        "DequantizeLinear",
-        g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
-        scale, zero_point, axis_i=axis)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
+    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)
 
 @parse_args("v", "v", "v", "i", "i")
 def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
-    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
         raise RuntimeError(
-            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
             "Got ({}, {})".format(quant_min, quant_max))
     if quant_min == 0:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
     else:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
     if scale.type().scalarType() != "Float":
         scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-    return g.op("DequantizeLinear", g.op("QuantizeLinear", inputs, scale, zero_point), scale, zero_point)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
+    return g.op("DequantizeLinear", quantized, scale, zero_point)
 
 def _reduce_op_symbolic(onnx_op_name):
     def symbolic(g, self, dim=None, keepdim=None):
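
To see the effect of these symbolics, the exported graph can be inspected for the new Clip node (a sketch, assuming the onnx package is installed and a model like the one above was exported to fake_quant_activation.onnx):

import onnx

model = onnx.load("fake_quant_activation.onnx")
ops = [node.op_type for node in model.graph.node]
# With the (0, 127) activation range, a Clip should appear between
# QuantizeLinear and DequantizeLinear.
print(ops)
assert "Clip" in ops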
