@@ -59,29 +59,30 @@ def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_
         # data reader
         self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
 
-    def check_opset(self, convert=True):
+    def check_opset(self):
         model = onnx.load(self.model_path)
         if model.opset_import[0].version != 13:
             print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version))
             # convert opset version to 13
             model_opset13 = version_converter.convert_version(model, 13)
             # save converted model
-            output_name = '{}-opset.onnx'.format(self.model_path[:-5])
+            output_name = '{}-opset13.onnx'.format(self.model_path[:-5])
             onnx.save_model(model_opset13, output_name)
             # update model_path for quantization
-            self.model_path = output_name
+            return output_name
+        return self.model_path
 
     def run(self):
         print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
-        self.check_opset()
-        output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
-        quantize_static(self.model_path, output_name, self.dr,
+        new_model_path = self.check_opset()
+        output_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type)
+        quantize_static(new_model_path, output_name, self.dr,
                        quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
                        per_channel=self.per_channel,
                        weight_type=self.type_dict[self.wt_type],
                        activation_type=self.type_dict[self.act_type])
-        os.remove('augmented_model.onnx')
-        os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
+        if new_model_path != self.model_path:
+            os.remove(new_model_path)
 
         print('\tQuantized model saved to {}'.format(output_name))
 
 models = dict(
@@ -132,4 +133,3 @@ def run(self):
     for selected_model_name in selected_models:
         q = models[selected_model_name]
         q.run()
-
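For reference, a minimal, self-contained sketch of the onnxruntime static quantization call that the updated run() makes with the QOperator format. The model file names, the input tensor name, and the RandomDataReader stub are illustrative placeholders only, not part of this patch; the real script uses its own DataReader driven by calibration images.

import numpy as np
from onnxruntime.quantization import (CalibrationDataReader, QuantFormat,
                                      QuantType, quantize_static)

class RandomDataReader(CalibrationDataReader):
    '''Toy calibration reader feeding a few random NCHW blobs (illustrative only).'''
    def __init__(self, input_name='input', n=8):
        self.data = iter([{input_name: np.random.rand(1, 3, 224, 224).astype(np.float32)}
                          for _ in range(n)])

    def get_next(self):
        # return the next calibration feed, or None when exhausted
        return next(self.data, None)

# paths and input name are placeholders for illustration
quantize_static('model-opset13.onnx',            # float32 input model (opset 13)
                'model_int8.onnx',               # quantized output model
                RandomDataReader(),              # calibration data reader
                quant_format=QuantFormat.QOperator,
                per_channel=False,
                weight_type=QuantType.QInt8,
                activation_type=QuantType.QInt8)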