From ad1f940ee390efc049ee78c81b72a12ada7e05f4 Mon Sep 17 00:00:00 2001 From: inisis Date: Tue, 28 Oct 2025 19:50:28 +0800 Subject: [PATCH 1/2] feat: add onnxslim support Signed-off-by: inisis --- modelopt/onnx/quantization/quantize.py | 8 ++++---- setup.py | 1 + tests/gpu/onnx/test_simplify.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 9bc025e33..948bc28d1 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -133,16 +133,16 @@ def _preprocess_onnx( if simplify: logger.info("Attempting to simplify model") try: - import onnxsim + import onnxslim except ModuleNotFoundError as e: logger.warning( - "onnxsim is not installed. Please install it with 'pip install onnxsim'." + "onnxslim is not installed. Please install it with 'pip install onnxslim'." ) raise e try: - model_simp, check = onnxsim.simplify(onnx_model) - if check: + model_simp = onnxslim.slim(onnx_model) + if model_simp: onnx_model = model_simp onnx_path = os.path.join(output_dir, f"{model_name}_simp.onnx") save_onnx(onnx_model, onnx_path, use_external_data_format) diff --git a/setup.py b/setup.py index 67bf114ae..3ded3c87c 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ "onnxscript", # For test_onnx_dynamo_export unit test "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'", "polygraphy>=0.49.22", + "onnxslim", ], "hf": [ "accelerate>=1.0.0", diff --git a/tests/gpu/onnx/test_simplify.py b/tests/gpu/onnx/test_simplify.py index 538380126..003e52200 100644 --- a/tests/gpu/onnx/test_simplify.py +++ b/tests/gpu/onnx/test_simplify.py @@ -61,10 +61,10 @@ def test_onnx_simplification(tmp_path): graph = gs.import_onnx(onnx.load(simplified_onnx_path)) identity_nodes = [n for n in graph.nodes if n.op == "Identity"] assert not identity_nodes, "Simplified ONNX model contains Identity nodes but it shouldn't." - assert len(graph.nodes) == 3, ( - f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 3." + assert len(graph.nodes) == 2, ( + f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 2." ) - assert all(n.op in ["Conv", "BatchNormalization", "Relu"] for n in graph.nodes), ( + assert all(n.op in ["Conv", "Relu"] for n in graph.nodes), ( "Graph contains more ops than expected." ) From abe6de4aa08f7882a00f24b8d85e4bc85944e452 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 29 Oct 2025 01:17:04 +0800 Subject: [PATCH 2/2] refactor: remove unused package Signed-off-by: inisis --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 3ded3c87c..ee3488374 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,6 @@ "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 "onnxruntime-directml==1.20.0; platform_system == 'Windows'", "onnxscript", # For test_onnx_dynamo_export unit test - "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'", "polygraphy>=0.49.22", "onnxslim", ],