
Commit b789cbb

Arm backend: Increase tolerance on Tanh int16 test (#16094)
### Summary
- Updates the tolerance on the Tanh int16 test, which intermittently fails.
- Refactors the test case to use the int16 test pipeline.
1 parent a94cfea · commit b789cbb
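For a sense of what the new tolerance admits: assuming the test pipeline's output comparison follows the usual `torch.allclose` shape, `|actual - expected| <= atol + rtol * |expected|` (an assumption about the pipeline, not something this diff shows), `rtol=2e-03` absorbs errors on the order of an int16 quantization step of tanh's output range. A minimal sketch:

```python
import torch

# Hypothetical illustration of what rtol=2e-03 admits, assuming the
# pipeline's output check has the usual torch.allclose shape:
#   |actual - expected| <= atol + rtol * |expected|
expected = torch.tanh(torch.tensor([1.0]))  # ~0.7616
actual = expected + 1.4e-3                  # an int16-quantization-sized error

# 1.4e-3 <= 2e-3 * 0.7616, so the loosened tolerance accepts it
print(torch.allclose(actual, expected, rtol=2e-3, atol=0.0))  # True
# the same error fails under torch.allclose defaults (rtol=1e-05, atol=1e-08)
print(torch.allclose(actual, expected))                       # False
```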

File tree

1 file changed: +9 -51 lines changed


backends/arm/test/ops/test_tanh.py

Lines changed: 9 additions & 51 deletions
```diff
@@ -8,21 +8,15 @@

 import pytest
 import torch
-from executorch.backends.arm.quantizer.arm_quantizer import (
-    get_symmetric_a16w8_quantization_config,
-    TOSAQuantizer,
-)

-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
 )
-from executorch.backends.arm.tosa.specification import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize

 aten_op = "torch.ops.aten.tanh.default"
 input_t1 = Tuple[torch.Tensor]  # Input x
```
```diff
@@ -114,29 +108,6 @@ def test_tanh_vgf_INT(test_data: Tuple):
     pipeline.run()


-def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False):
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
-    }
-
-    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-
-    # Use a smaller episilon value to not greatly inflate [qmin, qmax]
-    quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization, epsilon=2**-16
-        )
-    )
-
-    return Quantize(
-        quantizer,
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization, epsilon=2**-16
-        ),
-    )
-
-
 @common.parametrize("test_data", test_data_suite)
 def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
     """Test tanh operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
```
```diff
@@ -150,13 +121,8 @@ def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
         use_to_edge_transform_and_lower=True,
         tosa_extensions=["int16"],
-    )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_tanh_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+        epsilon=2**-16,
+        rtol=2e-03,
     )
     pipeline.run()

```
```diff
@@ -177,13 +143,9 @@ def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
         exir_ops=[],
         per_channel_quantization=per_channel_quantization,
         use_to_edge_transform_and_lower=True,
-    )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_tanh_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+        a16w8_quantization=True,
+        epsilon=2**-16,
+        rtol=2e-03,
     )
     pipeline.run()

```
```diff
@@ -201,12 +163,8 @@ def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor):
         exir_ops=[],
         per_channel_quantization=per_channel_quantization,
         use_to_edge_transform_and_lower=True,
-    )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_tanh_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
+        a16w8_quantization=True,
+        epsilon=2**-16,
+        rtol=2e-03,
     )
     pipeline.run()
```
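Pieced together from the hunks above, the U55 test now reads roughly as follows. The pipeline's leading positional arguments sit outside the diff context, so the module, inputs, and the `per_channel_quantization` value are assumptions modeled on the visible lines:

```python
# Reconstruction from the hunks above; positional arguments to the pipeline
# are outside the diff context and are assumptions.
@common.parametrize("test_data", test_data_suite)
def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
    per_channel_quantization = False  # assumption: value not shown in the hunk
    pipeline = EthosU55PipelineINT[input_t1](
        Tanh(),          # assumption: the module under test
        (test_data(),),  # assumption: test inputs
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        a16w8_quantization=True,  # selects the a16w8 (16-bit activation) config
        epsilon=2**-16,           # small floor so [qmin, qmax] is not inflated
        rtol=2e-03,               # the increased tolerance from this commit
    )
    pipeline.run()
```

The same three keyword arguments replace the hand-rolled `get_symmetric_a16w8_tanh_quantizer` helper in all three int16 tests, which is the refactor the summary describes.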
