Skip to content

Commit 14db274

Browse files
committed
feat: add onnxslim support
1 parent 14fa1e5 commit 14db274

File tree

3 files changed: +8 −7 lines changed

3 files changed: +8 −7 lines changed

modelopt/onnx/quantization/quantize.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -133,16 +133,16 @@ def _preprocess_onnx(
     if simplify:
         logger.info("Attempting to simplify model")
         try:
-            import onnxsim
+            import onnxslim
         except ModuleNotFoundError as e:
             logger.warning(
-                "onnxsim is not installed. Please install it with 'pip install onnxsim'."
+                "onnxslim is not installed. Please install it with 'pip install onnxslim'."
             )
             raise e
 
         try:
-            model_simp, check = onnxsim.simplify(onnx_model)
-            if check:
+            model_simp = onnxslim.slim(onnx_model)
+            if model_simp:
                 onnx_model = model_simp
                 onnx_path = os.path.join(output_dir, f"{model_name}_simp.onnx")
                 save_onnx(onnx_model, onnx_path, use_external_data_format)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@
         "onnxscript", # For test_onnx_dynamo_export unit test
         "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'",
         "polygraphy>=0.49.22",
+        "onnxslim",
     ],
     "hf": [
         "accelerate>=1.0.0",

tests/gpu/onnx/test_simplify.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -61,10 +61,10 @@ def test_onnx_simplification(tmp_path):
     graph = gs.import_onnx(onnx.load(simplified_onnx_path))
     identity_nodes = [n for n in graph.nodes if n.op == "Identity"]
     assert not identity_nodes, "Simplified ONNX model contains Identity nodes but it shouldn't."
-    assert len(graph.nodes) == 3, (
-        f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 3."
+    assert len(graph.nodes) == 2, (
+        f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 2."
     )
-    assert all(n.op in ["Conv", "BatchNormalization", "Relu"] for n in graph.nodes), (
+    assert all(n.op in ["Conv", "Relu"] for n in graph.nodes), (
         "Graph contains more ops than expected."
     )
 

0 commit comments

Comments
 (0)