|
7 | 7 | from torch.onnx.symbolic_helper import parse_args, _unimplemented
|
8 | 8 | from torch.onnx.symbolic_opset9 import (overload_by_arg_count, _maybe_cast_reduce_op_input,
|
9 | 9 | nonzero, expand, zeros, ones, size, linear, conv2d,
|
10 |
| - relu) |
| 10 | + relu, unused) |
11 | 11 | from torch.onnx.symbolic_opset11 import unsqueeze
|
12 | 12 | from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
|
13 | 13 |
|
@@ -132,33 +132,40 @@ def where(g, condition, self=None, other=None, _outputs=None):
|
132 | 132 |
|
@parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
    """Export aten::fake_quantize_per_channel_affine as QuantizeLinear + DequantizeLinear.

    Args:
        g: graph context used to emit ONNX ops.
        inputs: tensor to fake-quantize.
        scale: per-channel scale tensor.
        zero_point: per-channel zero-point tensor (cast to uint8/int8 below).
        axis: channel axis the per-channel parameters apply to.
        quant_min, quant_max: quantization range; must be one of the ranges
            representable by ONNX (u)int8 quantization.

    Raises:
        RuntimeError: if (quant_min, quant_max) is not an ONNX-representable range.
    """
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise RuntimeError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            "Got ({}, {})".format(quant_min, quant_max))
    # ONNX defines zero_point to be int8 or uint8
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
    # ONNX QuantizeLinear expects a float32 scale; cast if needed, for
    # consistency with fake_quantize_per_tensor_affine.
    if scale.type().scalarType() != "Float":
        scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    if (quant_min, quant_max) == (0, 127):
        # Clamp the quantized values to the PyTorch-restricted activation range.
        # The min input of Clip is left unused (no lower bound beyond uint8's 0).
        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)
148 | 150 |
|
@parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
    """Export aten::fake_quantize_per_tensor_affine as QuantizeLinear + DequantizeLinear.

    Args:
        g: graph context used to emit ONNX ops.
        inputs: tensor to fake-quantize.
        scale: scalar scale (cast to float32 if needed).
        zero_point: scalar zero point (cast to uint8/int8 below).
        quant_min, quant_max: quantization range; must be one of the ranges
            representable by ONNX (u)int8 quantization.

    Raises:
        RuntimeError: if (quant_min, quant_max) is not an ONNX-representable range.
    """
    qrange = (quant_min, quant_max)
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if qrange not in ((0, 255), (-128, 127), (0, 127)):
        raise RuntimeError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            "Got ({}, {})".format(quant_min, quant_max))
    # ONNX defines zero_point to be uint8 (ranges starting at 0) or int8.
    zp_dtype = (torch.onnx.TensorProtoDataType.UINT8
                if quant_min == 0
                else torch.onnx.TensorProtoDataType.INT8)
    zero_point = g.op("Cast", zero_point, to_i=zp_dtype)
    if scale.type().scalarType() != "Float":
        scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    if qrange == (0, 127):
        # Clamp to the PyTorch-restricted activation range; Clip's min input
        # stays unused (uint8 already bounds below at 0).
        upper_bound = g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8))
        quantized = g.op("Clip", quantized, unused(g), upper_bound)
    return g.op("DequantizeLinear", quantized, scale, zero_point)
162 | 169 |
|
163 | 170 | def _reduce_op_symbolic(onnx_op_name):
|
164 | 171 | def symbolic(g, self, dim=None, keepdim=None):
|
|
0 commit comments