Skip to content

Commit 088c357

Browse files
authored
shorten int8-quantized naming (#149)
1 parent c891253 commit 088c357

File tree

8 files changed

+9
-9
lines changed

8 files changed

+9
-9
lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

tools/quantize/quantize-ort.py

Lines changed: 9 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -59,29 +59,30 @@ def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_
5959
# data reader
6060
self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
6161

62-
def check_opset(self, convert=True):
62+
def check_opset(self):
6363
model = onnx.load(self.model_path)
6464
if model.opset_import[0].version != 13:
6565
print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version))
6666
# convert opset version to 13
6767
model_opset13 = version_converter.convert_version(model, 13)
6868
# save converted model
69-
output_name = '{}-opset.onnx'.format(self.model_path[:-5])
69+
output_name = '{}-opset13.onnx'.format(self.model_path[:-5])
7070
onnx.save_model(model_opset13, output_name)
7171
# update model_path for quantization
72-
self.model_path = output_name
72+
return output_name
73+
return self.model_path
7374

7475
def run(self):
7576
print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
76-
self.check_opset()
77-
output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
78-
quantize_static(self.model_path, output_name, self.dr,
77+
new_model_path = self.check_opset()
78+
output_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type)
79+
quantize_static(new_model_path, output_name, self.dr,
7980
quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
8081
per_channel=self.per_channel,
8182
weight_type=self.type_dict[self.wt_type],
8283
activation_type=self.type_dict[self.act_type])
83-
os.remove('augmented_model.onnx')
84-
os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
84+
if new_model_path != self.model_path:
85+
os.remove(new_model_path)
8586
print('\tQuantized model saved to {}'.format(output_name))
8687

8788
models=dict(
@@ -132,4 +133,3 @@ def run(self):
132133
for selected_model_name in selected_models:
133134
q = models[selected_model_name]
134135
q.run()
135-

0 commit comments

Comments (0)