Skip to content

Commit 47ef487

Browse files
Ninja91facebook-github-bot
authored andcommitted
Add tests for int16 rsqrt on Ethos-U55/U85 (#15631)
Summary: Fix Rsqrt op for int16 Relands D83802158 Some TOSA INT16 tests are successful internally but failing in OSS. Marking them with xfail. https://hud.pytorch.org/pr/pytorch/executorch/15631#54772221446 Differential Revision: D86402524
1 parent 556142c commit 47ef487

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

backends/arm/test/ops/test_rsqrt.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,23 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1218

13-
from executorch.backends.arm.test import common
1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
1621
EthosU85PipelineINT,
1722
TosaPipelineFP,
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
21-
26+
from executorch.backends.arm.tosa import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2228

2329
aten_op = "torch.ops.aten.rsqrt.default"
2430
input_t1 = Tuple[torch.Tensor] # Input x
@@ -104,3 +110,99 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104110
tosa_version="TOSA-1.0+INT",
105111
)
106112
pipeline.run()
113+
114+
115+
def get_symmetric_a16w8_rsqrt_quantizer(
116+
u55_config=False, per_channel_quantization=False
117+
):
118+
tosa_version = conftest.get_option("tosa_version")
119+
tosa_profiles = {
120+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
121+
}
122+
123+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
124+
quantizer.set_global(
125+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
126+
)
127+
128+
return Quantize(
129+
quantizer,
130+
get_symmetric_a16w8_quantization_config(
131+
is_per_channel=per_channel_quantization
132+
),
133+
)
134+
135+
136+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
137+
@pytest.mark.xfail(
138+
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
139+
)
140+
def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor):
141+
"""Test rsqrt operation with int16 quantization"""
142+
pipeline = TosaPipelineINT[input_t1](
143+
Rsqrt(),
144+
test_tensor(),
145+
aten_op,
146+
exir_op=[],
147+
per_channel_quantization=False,
148+
use_to_edge_transform_and_lower=True,
149+
tosa_extensions=["int16"],
150+
)
151+
152+
pipeline.change_args(
153+
"quantize",
154+
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False),
155+
)
156+
# Run the pipeline
157+
pipeline.run()
158+
159+
160+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
161+
@common.XfailIfNoCorstone300
162+
@pytest.mark.xfail(
163+
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
164+
)
165+
def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor):
166+
"""Test rsqrt operation with int16 quantization on U55"""
167+
pipeline = EthosU55PipelineINT[input_t1](
168+
Rsqrt(),
169+
test_tensor(),
170+
aten_op,
171+
exir_ops=[],
172+
per_channel_quantization=True,
173+
use_to_edge_transform_and_lower=True,
174+
atol=1e-03,
175+
rtol=1e-03,
176+
run_on_fvp=True,
177+
)
178+
179+
pipeline.change_args(
180+
"quantize",
181+
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=True),
182+
)
183+
pipeline.run()
184+
185+
186+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
187+
@common.XfailIfNoCorstone320
188+
@pytest.mark.xfail(
189+
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
190+
)
191+
def test_rsqrt_16a8w_u85_INT16(test_tensor: torch.Tensor):
192+
"""Test rsqrt operation with int16 quantization on U85"""
193+
pipeline = EthosU85PipelineINT[input_t1](
194+
Rsqrt(),
195+
test_tensor(),
196+
aten_op,
197+
exir_ops=[],
198+
use_to_edge_transform_and_lower=True,
199+
atol=1e-03,
200+
rtol=1e-03,
201+
run_on_fvp=True,
202+
)
203+
204+
pipeline.change_args(
205+
"quantize",
206+
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False),
207+
)
208+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def define_arm_tests():
2222
"ops/test_linear.py",
2323
"ops/test_mul.py",
2424
"ops/test_permute.py",
25+
"ops/test_rsqrt.py",
2526
"ops/test_slice.py",
2627
"ops/test_sigmoid.py",
2728
"ops/test_sub.py",

0 commit comments

Comments
 (0)