Skip to content

Commit 80c9040

Browse files
authored
Arm backend: Update for missing operators for int16x8 (#15521)
### Summary Updates to operators for int16: - avg_pool2d, clamp, constant_pad_nd, eq, ge, gt, le, lt, max_pool2d, upsample_bilinear, upsample_nearest2d ### Test plan Unit tests added for affected operators cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai --------- Signed-off-by: Saoirse Stewart <[email protected]>
1 parent 5a55382 commit 80c9040

27 files changed

+650
-22
lines changed

backends/arm/_passes/rewrite_upsample.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_node,
1212
get_first_fake_tensor,
1313
)
14+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1415
from executorch.backends.arm.tosa.utils import get_resize_parameters
1516
from executorch.exir.dialects._ops import ops as exir_ops
1617
from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,7 +53,9 @@ def call(self, graph_module):
5253
node.replace_all_uses_with(tosa_resize_node)
5354
graph_module.graph.erase_node(node)
5455
input_dtype = get_first_fake_tensor(x).dtype
55-
if input_dtype == torch.int8 and resize_mode == "bilinear":
56+
if (
57+
input_dtype == torch.int8 or input_dtype == torch.int16
58+
) and resize_mode == "bilinear":
5659
input_size = get_first_fake_tensor(x).shape
5760
input_size_xy = input_size[2:]
5861
output_size = get_first_fake_tensor(node).shape
@@ -71,6 +74,11 @@ def call(self, graph_module):
7174
exir_ops.backend.tosa.RESCALE.default,
7275
)
7376
tosa_resize_node.replace_all_uses_with(rescale_node)
77+
if input_dtype == torch.int16:
78+
tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = (
79+
TosaSpecialDtype.INT48
80+
)
81+
7482
rescale_node.args = (
7583
tosa_resize_node,
7684
output_dtype,

backends/arm/operators/op_avg_pool2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ def define_node(
118118
validate_valid_dtype(
119119
self.target,
120120
[inputs[0], output],
121-
[ts.DType.INT8, ts.DType.FP32],
121+
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
122122
output.tosa_spec,
123123
)
124124

125-
if inputs[0].dtype == ts.DType.INT8:
125+
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
126126
accumulator_type = ts.DType.INT32
127127
input_qargs = get_input_qparams(node)
128128
input_zp = input_qargs[0].get_zp_per_tensor()

backends/arm/operators/op_clamp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes:
7474
return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist()
7575
elif dtype == torch.int8:
7676
return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist()
77+
elif dtype == torch.int16:
78+
return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist()
7779
else:
7880
raise ValueError(f"Unsupported dtype for to_bytes: {dtype}")
7981

@@ -89,7 +91,7 @@ def define_node(
8991
validate_valid_dtype(
9092
self.target,
9193
[inputs[0], output],
92-
[ts.DType.INT8, ts.DType.FP16, ts.DType.FP32],
94+
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
9395
output.tosa_spec,
9496
)
9597

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def define_node(
5050
[inputs[0], output],
5151
[
5252
ts.DType.INT8,
53+
ts.DType.INT16,
5354
ts.DType.INT32,
5455
ts.DType.FP32,
5556
ts.DType.BOOL,
@@ -62,6 +63,11 @@ def define_node(
6263
qargs = input_qparams[0]
6364
pad_const_val = qargs.quantize_value(inputs[2].number).item()
6465
pad_const_dtype = ts.DType.INT8
66+
elif inputs[0].dtype == ts.DType.INT16:
67+
input_qparams = get_input_qparams(node)
68+
qargs = input_qparams[0]
69+
pad_const_val = qargs.quantize_value(inputs[2].number).item()
70+
pad_const_dtype = ts.DType.INT16
6571
else:
6672
pad_const_val = inputs[2].number
6773
pad_const_dtype = inputs[0].dtype

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_le.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_lt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_max_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
[inputs[0], output],
50-
[ts.DType.INT8, ts.DType.FP32],
50+
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353

0 commit comments

Comments
 (0)