99from typing import Tuple
1010
1111import torch
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1217
13- from executorch .backends .arm .test import common
1418from executorch .backends .arm .test .tester .test_pipeline import (
1519 EthosU55PipelineINT ,
1620 EthosU85PipelineINT ,
1721 TosaPipelineFP ,
1822 TosaPipelineINT ,
1923 VgfPipeline ,
2024)
21-
25+ from executorch .backends .arm .tosa import TosaSpecification
26+ from executorch .backends .xnnpack .test .tester import Quantize
2227
2328aten_op = "torch.ops.aten.rsqrt.default"
2429input_t1 = Tuple [torch .Tensor ] # Input x
@@ -29,7 +34,7 @@ class Rsqrt(torch.nn.Module):
2934 "ones_4d" : lambda : (torch .ones (1 , 10 , 10 , 10 ),),
3035 "rand_4d_1" : lambda : (torch .rand (1 , 10 , 10 , 10 ),),
3136 "rand_4d_2" : lambda : (torch .rand (1 , 5 , 10 , 20 ),),
32- "rand_3d" : lambda : (torch .rand (5 , 10 , 20 ),),
37+ "rand_3d" : lambda : (torch .rand (5 , 10 , 20 ) + 1.0 ,),
3338 }
3439
3540 def forward (self , x : torch .Tensor ):
@@ -104,3 +109,97 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104109 tosa_version = "TOSA-1.0+INT" ,
105110 )
106111 pipeline .run ()
112+
113+
114+ def get_symmetric_a16w8_rsqrt_quantizer (
115+ u55_config = False , per_channel_quantization = False
116+ ):
117+ tosa_version = conftest .get_option ("tosa_version" )
118+ tosa_profiles = {
119+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
120+ }
121+
122+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
123+ quantizer .set_global (
124+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
125+ )
126+
127+ return Quantize (
128+ quantizer ,
129+ get_symmetric_a16w8_quantization_config (
130+ is_per_channel = per_channel_quantization
131+ ),
132+ )
133+
134+
135+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
136+ def test_rsqrt_int16_tosa_INT (test_tensor : torch .Tensor ):
137+ """Test rsqrt operation with int16 quantization"""
138+ # Create pipeline with custom 16A8W quantization config
139+ pipeline = TosaPipelineINT [input_t1 ](
140+ Rsqrt (),
141+ test_tensor (),
142+ aten_op ,
143+ exir_op = [],
144+ per_channel_quantization = False ,
145+ use_to_edge_transform_and_lower = True ,
146+ tosa_extensions = ["int16" ],
147+ )
148+
149+ pipeline .change_args (
150+ "quantize" ,
151+ get_symmetric_a16w8_rsqrt_quantizer (
152+ per_channel_quantization = False
153+ ),
154+ )
155+ # Run the pipeline
156+ pipeline .run ()
157+
158+
159+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
160+ @common .XfailIfNoCorstone300
161+ def test_rsqrt_int16_u55_INT16 (test_tensor : torch .Tensor ):
162+ """Test rsqrt operation with int16 quantization on U55"""
163+ pipeline = EthosU55PipelineINT [input_t1 ](
164+ Rsqrt (),
165+ test_tensor (),
166+ aten_op ,
167+ exir_ops = [],
168+ per_channel_quantization = True ,
169+ use_to_edge_transform_and_lower = True ,
170+ atol = 1e-02 ,
171+ rtol = 1e-02 ,
172+ run_on_fvp = True ,
173+ )
174+
175+ pipeline .change_args (
176+ "quantize" ,
177+ get_symmetric_a16w8_rsqrt_quantizer (
178+ per_channel_quantization = True
179+ ),
180+ )
181+ pipeline .run ()
182+
183+
184+ @common .parametrize ("test_tensor" , Rsqrt .test_parameters )
185+ @common .XfailIfNoCorstone320
186+ def test_rsqrt_int16_u85_INT16 (test_tensor : torch .Tensor ):
187+ """Test rsqrt operation with int16 quantization on U85"""
188+ pipeline = EthosU85PipelineINT [input_t1 ](
189+ Rsqrt (),
190+ test_tensor (),
191+ aten_op ,
192+ exir_ops = [],
193+ use_to_edge_transform_and_lower = True ,
194+ atol = 1e-02 ,
195+ rtol = 1e-02 ,
196+ run_on_fvp = True ,
197+ )
198+
199+ pipeline .change_args (
200+ "quantize" ,
201+ get_symmetric_a16w8_rsqrt_quantizer (
202+ per_channel_quantization = False
203+ ),
204+ )
205+ pipeline .run ()
0 commit comments