4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- from typing import Any , List , Optional , Tuple
7
+ import functools
8
+ from typing import Any , List , Optional , Sequence , Tuple
8
9
10
+ import coremltools as ct
9
11
import executorch
10
12
import executorch .backends .test .harness .stages as BaseStages
11
-
12
13
import torch
14
+
15
+ from executorch .backends .apple .coreml .compiler import CoreMLBackend
13
16
from executorch .backends .apple .coreml .partition import CoreMLPartitioner
17
+ from executorch .backends .apple .coreml .quantizer import CoreMLQuantizer
14
18
from executorch .backends .test .harness import Tester as TesterBase
15
19
from executorch .backends .test .harness .stages import StageType
16
20
from executorch .exir import EdgeCompileConfig
17
21
from executorch .exir .backend .partitioner import Partitioner
18
22
19
23
24
def _create_default_partitioner(
    minimum_deployment_target: Any = ct.target.iOS15,
) -> CoreMLPartitioner:
    """Build a ``CoreMLPartitioner`` for the given deployment target.

    Args:
        minimum_deployment_target: Value forwarded to
            ``CoreMLBackend.generate_compile_specs``. Defaults to
            ``ct.target.iOS15``.

    Returns:
        A ``CoreMLPartitioner`` constructed from the generated compile specs.
    """
    specs = CoreMLBackend.generate_compile_specs(
        minimum_deployment_target=minimum_deployment_target
    )
    return CoreMLPartitioner(compile_specs=specs)
32
+
33
+
34
def _get_static_int8_linear_qconfig():
    """Return the default ``LinearQuantizerConfig`` used by this tester.

    The global module config uses the "symmetric" quantization scheme with
    ``torch.quint8`` activations, ``torch.qint8`` weights, and per-channel
    weight quantization enabled.
    """
    quantization = ct.optimize.torch.quantization
    module_config = quantization.ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        weight_per_channel=True,
    )
    return quantization.LinearQuantizerConfig(global_config=module_config)
43
+
44
+
45
class Quantize(BaseStages.Quantize):
    """Quantize stage that falls back to a ``CoreMLQuantizer``."""

    def __init__(
        self,
        quantizer: Optional[CoreMLQuantizer] = None,
        quantization_config: Optional[Any] = None,
        calibrate: bool = True,
        calibration_samples: Optional[Sequence[Any]] = None,
        is_qat: Optional[bool] = False,
    ):
        """Configure the stage, building a default quantizer when none is given.

        Args:
            quantizer: Quantizer to use. When falsy, a ``CoreMLQuantizer`` is
                created from ``quantization_config``.
            quantization_config: Config for the fallback quantizer. When
                falsy, ``_get_static_int8_linear_qconfig()`` supplies it.
                Ignored if ``quantizer`` is provided.
            calibrate: Forwarded unchanged to the base stage.
            calibration_samples: Forwarded unchanged to the base stage.
            is_qat: Forwarded unchanged to the base stage.
        """
        if not quantizer:
            # Only resolve the config when the fallback quantizer is needed.
            if not quantization_config:
                quantization_config = _get_static_int8_linear_qconfig()
            quantizer = CoreMLQuantizer(quantization_config)

        super().__init__(
            quantizer=quantizer,
            calibrate=calibrate,
            calibration_samples=calibration_samples,
            is_qat=is_qat,
        )
63
+
64
+
20
65
class Partition(BaseStages.Partition):
    """Partition stage that falls back to a default Core ML partitioner."""

    def __init__(
        self,
        partitioner: Optional[Partitioner] = None,
        minimum_deployment_target: Optional[Any] = ct.target.iOS15,
    ):
        """Configure the stage, creating a default partitioner when needed.

        Args:
            partitioner: Partitioner to use. When falsy, one is built with
                ``_create_default_partitioner``.
            minimum_deployment_target: Passed to
                ``_create_default_partitioner`` for the fallback partitioner.
                Ignored if ``partitioner`` is provided.
        """
        if not partitioner:
            partitioner = _create_default_partitioner(minimum_deployment_target)
        super().__init__(partitioner=partitioner)
25
75
26
76
@@ -29,9 +79,12 @@ def __init__(
29
79
self ,
30
80
partitioners : Optional [List [Partitioner ]] = None ,
31
81
edge_compile_config : Optional [EdgeCompileConfig ] = None ,
82
+ minimum_deployment_target : Optional [Any ] = ct .target .iOS15 ,
32
83
):
33
84
super ().__init__ (
34
- default_partitioner_cls = CoreMLPartitioner ,
85
+ default_partitioner_cls = lambda : _create_default_partitioner (
86
+ minimum_deployment_target
87
+ ),
35
88
partitioners = partitioners ,
36
89
edge_compile_config = edge_compile_config ,
37
90
)
@@ -43,13 +96,20 @@ def __init__(
43
96
module : torch .nn .Module ,
44
97
example_inputs : Tuple [torch .Tensor ],
45
98
dynamic_shapes : Optional [Tuple [Any ]] = None ,
99
+ minimum_deployment_target : Optional [Any ] = ct .target .iOS15 ,
46
100
):
47
101
# Specialize for Core ML
48
102
stage_classes = (
49
103
executorch .backends .test .harness .Tester .default_stage_classes ()
50
104
| {
51
- StageType .PARTITION : Partition ,
52
- StageType .TO_EDGE_TRANSFORM_AND_LOWER : ToEdgeTransformAndLower ,
105
+ StageType .QUANTIZE : Quantize ,
106
+ StageType .PARTITION : functools .partial (
107
+ Partition , minimum_deployment_target = minimum_deployment_target
108
+ ),
109
+ StageType .TO_EDGE_TRANSFORM_AND_LOWER : functools .partial (
110
+ ToEdgeTransformAndLower ,
111
+ minimum_deployment_target = minimum_deployment_target ,
112
+ ),
53
113
}
54
114
)
55
115
0 commit comments