Skip to content

Commit 24c28e1

Browse files
3l1 authored and meta-codesync[bot] committed
Enable int16 rsqrt on Ethos-U55/U85 (#14770)
Summary: Pull Request resolved: #14770
Fix Rsqrt op for int16.
Add unit tests.
bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks
Reviewed By: Ninja91, digantdesai
Differential Revision: D83802158
1 parent 5467a4d commit 24c28e1

File tree

3 files changed

+120
-11
lines changed

3 files changed

+120
-11
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,18 +185,27 @@ def f(x: torch.Tensor) -> torch.Tensor:
185185
)
186186
# Don't use the 7 LSBs.
187187
x = in_quantargs.dequantize_value((x & ~0x7F))
188+
# x = in_quantargs.dequantize_value(x) // (1 << 7)
188189
x = torch_op(x)
190+
# x = x * (1 << 7)
189191
return out_quantargs.quantize_value(x)
190192

191-
lut_values = f(
192-
torch.linspace(
193-
start=in_quantargs.qmin,
194-
end=in_quantargs.qmax + 1,
195-
steps=513,
196-
# use torch.int32 to avoid overflow for end=in_quantargs.qmax + 1.
197-
dtype=torch.int32,
198-
)
193+
# Create the 9.7 fixed-point value
194+
r = torch.linspace(
195+
start=in_quantargs.qmin,
196+
end=in_quantargs.qmax + 1,
197+
steps=513,
198+
# use torch.int32 to avoid overflow for end=in_quantargs.qmax + 1.
199+
dtype=torch.int32,
199200
)
201+
# # Cast input to a wider type (int32)
202+
# r_int32 = r.to(torch.int32)
203+
# # Extract most significant 9 bits
204+
# index = (r_int32 >> 7) & 0x1FF
205+
# # Extract the fractional 7 bits
206+
# fraction = r_int32 & 0x7F
207+
208+
lut_values = f(r)
200209
# Calculate how much we need to shift table values to fit in 16 signed bits
201210
# ceil(log2(max absolute table value)) + 1 bit for signedness - 16
202211
# Example:

backends/arm/test/ops/test_rsqrt.py

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1217

13-
from executorch.backends.arm.test import common
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
1620
EthosU85PipelineINT,
1721
TosaPipelineFP,
1822
TosaPipelineINT,
1923
VgfPipeline,
2024
)
21-
25+
from executorch.backends.arm.tosa import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2227

2328
aten_op = "torch.ops.aten.rsqrt.default"
2429
input_t1 = Tuple[torch.Tensor] # Input x
@@ -29,7 +34,7 @@ class Rsqrt(torch.nn.Module):
2934
"ones_4d": lambda: (torch.ones(1, 10, 10, 10),),
3035
"rand_4d_1": lambda: (torch.rand(1, 10, 10, 10),),
3136
"rand_4d_2": lambda: (torch.rand(1, 5, 10, 20),),
32-
"rand_3d": lambda: (torch.rand(5, 10, 20),),
37+
"rand_3d": lambda: (torch.rand(5, 10, 20) + 1.0,),
3338
}
3439

3540
def forward(self, x: torch.Tensor):
@@ -104,3 +109,97 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109
tosa_version="TOSA-1.0+INT",
105110
)
106111
pipeline.run()
112+
113+
114+
def get_symmetric_a16w8_rsqrt_quantizer(
115+
u55_config=False, per_channel_quantization=False
116+
):
117+
tosa_version = conftest.get_option("tosa_version")
118+
tosa_profiles = {
119+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
120+
}
121+
122+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
123+
quantizer.set_global(
124+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
125+
)
126+
127+
return Quantize(
128+
quantizer,
129+
get_symmetric_a16w8_quantization_config(
130+
is_per_channel=per_channel_quantization
131+
),
132+
)
133+
134+
135+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
136+
def test_rsqrt_int16_tosa_INT(test_tensor: torch.Tensor):
137+
"""Test rsqrt operation with int16 quantization"""
138+
# Create pipeline with custom 16A8W quantization config
139+
pipeline = TosaPipelineINT[input_t1](
140+
Rsqrt(),
141+
test_tensor(),
142+
aten_op,
143+
exir_op=[],
144+
per_channel_quantization=False,
145+
use_to_edge_transform_and_lower=True,
146+
tosa_extensions=["int16"],
147+
)
148+
149+
pipeline.change_args(
150+
"quantize",
151+
get_symmetric_a16w8_rsqrt_quantizer(
152+
per_channel_quantization=False
153+
),
154+
)
155+
# Run the pipeline
156+
pipeline.run()
157+
158+
159+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
160+
@common.XfailIfNoCorstone300
161+
def test_rsqrt_int16_u55_INT16(test_tensor: torch.Tensor):
162+
"""Test rsqrt operation with int16 quantization on U55"""
163+
pipeline = EthosU55PipelineINT[input_t1](
164+
Rsqrt(),
165+
test_tensor(),
166+
aten_op,
167+
exir_ops=[],
168+
per_channel_quantization=True,
169+
use_to_edge_transform_and_lower=True,
170+
atol=1e-02,
171+
rtol=1e-02,
172+
run_on_fvp=True,
173+
)
174+
175+
pipeline.change_args(
176+
"quantize",
177+
get_symmetric_a16w8_rsqrt_quantizer(
178+
per_channel_quantization=True
179+
),
180+
)
181+
pipeline.run()
182+
183+
184+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
185+
@common.XfailIfNoCorstone320
186+
def test_rsqrt_int16_u85_INT16(test_tensor: torch.Tensor):
187+
"""Test rsqrt operation with int16 quantization on U85"""
188+
pipeline = EthosU85PipelineINT[input_t1](
189+
Rsqrt(),
190+
test_tensor(),
191+
aten_op,
192+
exir_ops=[],
193+
use_to_edge_transform_and_lower=True,
194+
atol=1e-02,
195+
rtol=1e-02,
196+
run_on_fvp=True,
197+
)
198+
199+
pipeline.change_args(
200+
"quantize",
201+
get_symmetric_a16w8_rsqrt_quantizer(
202+
per_channel_quantization=False
203+
),
204+
)
205+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def define_arm_tests():
2424
"ops/test_linear.py",
2525
"ops/test_mul.py",
2626
"ops/test_permute.py",
27+
"ops/test_rsqrt.py",
2728
"ops/test_slice.py",
2829
"ops/test_sigmoid.py",
2930
"ops/test_sub.py",

0 commit comments

Comments
 (0)