From efad630bd37080eb81f202531d601fb283c13a6e Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 20 May 2026 14:34:51 +0800 Subject: [PATCH 1/4] feat(export): normalize exported ONNX in-place via optimize_onnx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a post-export normalization step inside `export_pytorch()` so that every export path (the `winml export` CLI, `build_hf_model`, the `winml build` command, and direct Python API callers) benefits from graph normalization automatically. The new helper `_normalize_exported_model()`: - logs `Normalizing model` - runs `optimize_onnx()` into a `tempfile.mkdtemp()` working directory - on success, replaces the original `.onnx` + sidecar in-place via `copy_onnx_model()` (which overwrites both files correctly — confirmed by new overwrite tests in `tests/unit/onnx/test_external_data.py`) - on an `optimize_onnx()` failure, logs a warning and leaves the original export untouched (a failure inside `copy_onnx_model()` is also logged, but may leave the destination partially overwritten because the copy writes directly to it) - removes the temp directory in a `finally` clause either way - records the outcome in `stats["model_normalization_succeeded"]` Tests: - `tests/unit/export/test_pytorch_export.py`: success path populates `value_info`; mocked-failure path leaves the model without inferred shapes. - `tests/unit/onnx/test_external_data.py`: `copy_onnx_model` overwrites a pre-existing dst correctly with and without external data sidecars. 
--- src/winml/modelkit/export/pytorch.py | 40 ++++++- tests/unit/export/test_pytorch_export.py | 46 ++++++++ tests/unit/onnx/test_external_data.py | 127 +++++++++++++++++++++-- 3 files changed, 205 insertions(+), 8 deletions(-) diff --git a/src/winml/modelkit/export/pytorch.py b/src/winml/modelkit/export/pytorch.py index 3d157ddba..a0d19286b 100644 --- a/src/winml/modelkit/export/pytorch.py +++ b/src/winml/modelkit/export/pytorch.py @@ -86,7 +86,7 @@ def export_pytorch( ) with warnings.catch_warnings(): warnings.filterwarnings("ignore") - return exporter.export( + stats = exporter.export( model=model, output_path=str(output_path), export_config=export_config, @@ -94,3 +94,41 @@ def export_pytorch( task=task, **kwargs, ) + + stats["model_normalization_succeeded"] = _normalize_exported_model(output_path) + + return stats + + +def _normalize_exported_model(output_path: Path) -> bool: + """Normalize the exported ONNX in-place via optimize_onnx. + + Writes the normalized model into a temporary directory, then replaces + the original export (and its `.data` sidecar, if any) via + copy_onnx_model — which overwrites both the `.onnx` file and its + external-data sidecar by design. On any failure during optimization + or copy, logs a warning and leaves the original export untouched. + The temp directory is removed in either case. + + Returns: + True if normalization succeeded, False otherwise. 
+ """ + import shutil + import tempfile + + from ..onnx import copy_onnx_model + from ..optim import optimize_onnx + + logger.info("Normalizing model") + tmp_dir = Path(tempfile.mkdtemp()) + tmp_path = tmp_dir / output_path.name + + try: + optimize_onnx(model=output_path, output=tmp_path) + copy_onnx_model(tmp_path, output_path) + except Exception as e: + logger.warning("Normalization failed; keeping un-normalized export: %s", e) + return False + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + return True diff --git a/tests/unit/export/test_pytorch_export.py b/tests/unit/export/test_pytorch_export.py index 89fd797a3..ff656ca31 100644 --- a/tests/unit/export/test_pytorch_export.py +++ b/tests/unit/export/test_pytorch_export.py @@ -10,6 +10,8 @@ from __future__ import annotations +from unittest.mock import patch + import onnx import pytest import torch @@ -23,6 +25,20 @@ ) +def _all_value_info_have_shape(model: onnx.ModelProto) -> bool: + """Every intermediate (value_info) tensor has a concrete or symbolic shape.""" + if not model.graph.value_info: + return False + for vi in model.graph.value_info: + shape = vi.type.tensor_type.shape + if not shape.dim: + return False + for dim in shape.dim: + if not dim.HasField("dim_value") and not dim.HasField("dim_param"): + return False + return True + + # ============================================================================= # Test Models (pure PyTorch, no HF) # ============================================================================= @@ -268,6 +284,36 @@ def forward(self, ids): # This is expected — the test verifies the flow works up to export pass + def test_normalization_succeeds_and_shape_inferences(self, tmp_path) -> None: + """After export, stats reports normalization succeeded and value_info is fully shaped.""" + model = TwoLayerNet() + config = WinMLExportConfig( + input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], + ) + result = export_pytorch(model, tmp_path / "model.onnx", 
config) + + assert result["model_normalization_succeeded"] is True + + onnx_model = onnx.load(str(tmp_path / "model.onnx")) + assert _all_value_info_have_shape(onnx_model) + + def test_failed_normalization_skips_shape_inference(self, tmp_path) -> None: + """When normalization is mocked to return False, no shape inference runs.""" + model = TwoLayerNet() + config = WinMLExportConfig( + input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], + ) + with patch( + "winml.modelkit.export.pytorch._normalize_exported_model", + return_value=False, + ): + result = export_pytorch(model, tmp_path / "model.onnx", config) + + assert result["model_normalization_succeeded"] is False + + onnx_model = onnx.load(str(tmp_path / "model.onnx")) + assert not _all_value_info_have_shape(onnx_model) + def test_mismatched_input_order_exports_successfully(self, tmp_path) -> None: """Export succeeds when InputTensorSpec order differs from forward() param order. diff --git a/tests/unit/onnx/test_external_data.py b/tests/unit/onnx/test_external_data.py index 4a77e3daf..0d118f475 100644 --- a/tests/unit/onnx/test_external_data.py +++ b/tests/unit/onnx/test_external_data.py @@ -10,7 +10,7 @@ import numpy as np import onnx -from onnx import TensorProto, helper, numpy_helper +from onnx import TensorProto, external_data_helper, helper, numpy_helper from winml.modelkit.onnx.external_data import ( copy_onnx_model, @@ -27,14 +27,39 @@ def _make_small_model() -> onnx.ModelProto: """Create a minimal ONNX model (no external data).""" x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4]) y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2]) - weight = numpy_helper.from_array( - np.random.randn(4, 2).astype(np.float32), name="W" - ) + weight = numpy_helper.from_array(np.random.randn(4, 2).astype(np.float32), name="W") node = helper.make_node("MatMul", ["X", "W"], ["Y"]) graph = helper.make_graph([node], "test", [x_info], [y_info], [weight]) return 
helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) +def _make_filled_model(value: float, shape: tuple[int, ...]) -> onnx.ModelProto: + """Create a deterministic ONNX model with a constant-filled initializer. + + Used by overwrite tests where two distinguishable models are needed. + """ + weight = numpy_helper.from_array(np.full(shape, value, dtype=np.float32), name="W") + inp = helper.make_tensor_value_info("X", TensorProto.FLOAT, list(shape)) + out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(shape)) + node = helper.make_node("Add", ["X", "W"], ["Y"]) + graph = helper.make_graph([node], "g", [inp], [out], [weight]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +def _serialize_without_external_location(model: onnx.ModelProto) -> bytes: + """Serialize the model with the `location` entry stripped from every + external_data tensor — for comparing two models that point to different + sidecar filenames but are otherwise identical.""" + clone = onnx.ModelProto() + clone.CopyFrom(model) + for tensor in external_data_helper._get_all_tensors(clone): + if tensor.data_location == TensorProto.EXTERNAL: + for entry in list(tensor.external_data): + if entry.key == "location": + tensor.external_data.remove(entry) + return clone.SerializeToString(deterministic=True) + + class TestGetExternalDataFiles: """Tests for get_external_data_files().""" @@ -51,7 +76,8 @@ def test_with_external_data(self, tmp_path: Path) -> None: model = _make_small_model() path = tmp_path / "ext.onnx" onnx.save_model( - model, str(path), + model, + str(path), save_as_external_data=True, all_tensors_to_one_file=True, location="ext.onnx.data", @@ -74,7 +100,8 @@ def test_with_external(self, tmp_path: Path) -> None: model = _make_small_model() path = tmp_path / "ext.onnx" onnx.save_model( - model, str(path), + model, + str(path), save_as_external_data=True, all_tensors_to_one_file=True, location="ext.onnx.data", @@ -106,7 +133,8 @@ def 
test_copy_with_external_data(self, tmp_path: Path) -> None: src = tmp_path / "src" / "model.onnx" src.parent.mkdir() onnx.save_model( - model, str(src), + model, + str(src), save_as_external_data=True, all_tensors_to_one_file=True, location="model.onnx.data", @@ -145,3 +173,88 @@ def test_copy_invalid_file_falls_back(self, tmp_path: Path) -> None: assert dst.exists() assert dst.read_text() == "not a real onnx file" + + def test_copy_overwrites_existing_dst_no_external_data(self, tmp_path: Path) -> None: + """Pre-existing dst (no external data) is overwritten byte-for-byte by src.""" + src = tmp_path / "src.onnx" + dst = tmp_path / "dst.onnx" + + onnx.save(_make_filled_model(1.0, (4, 4)), str(src)) + onnx.save(_make_filled_model(99.0, (8, 8)), str(dst)) # pre-existing, different + + pre_dst_bytes = dst.read_bytes() + src_bytes = src.read_bytes() + assert pre_dst_bytes != src_bytes + + copy_onnx_model(src, dst) + + post_dst_bytes = dst.read_bytes() + assert post_dst_bytes == src_bytes + assert post_dst_bytes != pre_dst_bytes + assert not (tmp_path / "dst.onnx.data").exists() + + def test_copy_overwrites_existing_dst_with_external_data(self, tmp_path: Path) -> None: + """Pre-existing dst + sidecar (external data) are both overwritten. 
+ + Verifies: + - dst.onnx.data is byte-identical to src.onnx.data + - dst.onnx matches src.onnx except for the external_data.location field + - dst.onnx's location field points at dst.onnx.data + - Loaded initializer arrays are equal + """ + src = tmp_path / "src.onnx" + dst = tmp_path / "dst.onnx" + src_data = tmp_path / "src.onnx.data" + dst_data = tmp_path / "dst.onnx.data" + + onnx.save_model( + _make_filled_model(2.0, (64, 64)), + str(src), + save_as_external_data=True, + all_tensors_to_one_file=True, + location="src.onnx.data", + size_threshold=0, + ) + onnx.save_model( + _make_filled_model(999.0, (32, 32)), + str(dst), + save_as_external_data=True, + all_tensors_to_one_file=True, + location="dst.onnx.data", + size_threshold=0, + ) + + src_data_bytes = src_data.read_bytes() + pre_dst_data_bytes = dst_data.read_bytes() + pre_dst_onnx_bytes = dst.read_bytes() + assert src_data_bytes != pre_dst_data_bytes + + copy_onnx_model(src, dst) + + # .data file byte-identical to src's sidecar + post_dst_data_bytes = dst_data.read_bytes() + assert post_dst_data_bytes == src_data_bytes + assert post_dst_data_bytes != pre_dst_data_bytes + + # .onnx file no longer matches old dst + assert dst.read_bytes() != pre_dst_onnx_bytes + + # .onnx matches src modulo external_data.location field + src_model = onnx.load(str(src), load_external_data=False) + dst_model = onnx.load(str(dst), load_external_data=False) + assert _serialize_without_external_location( + src_model + ) == _serialize_without_external_location(dst_model) + + # dst.onnx's location must point at dst.onnx.data + for tensor in external_data_helper._get_all_tensors(dst_model): + if tensor.data_location == TensorProto.EXTERNAL: + info = external_data_helper.ExternalDataInfo(tensor) + assert info.location == "dst.onnx.data" + + # Semantic check: loaded initializer arrays are equal + src_full = onnx.load(str(src), load_external_data=True) + dst_full = onnx.load(str(dst), load_external_data=True) + src_arr = 
numpy_helper.to_array(src_full.graph.initializer[0]) + dst_arr = numpy_helper.to_array(dst_full.graph.initializer[0]) + assert np.array_equal(src_arr, dst_arr) From 52d7afcc18df3231b942e42343a0042ee3a4b101 Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 20 May 2026 15:25:29 +0800 Subject: [PATCH 2/4] address review: add normalize kwarg, 3-state status, exc_info, soften docstring - export_pytorch: new `normalize: bool = True` kwarg. Callers who need the raw torch.onnx.export output can opt out without dropping the normalize step for everyone else. - Stats key renamed `model_normalization_succeeded` (bool) -> `model_normalization_status` (str), values `"not_run"` / `"succeeded"` / `"failed"`. - `_normalize_exported_model`: keep the broad `except`, but pass `exc_info=True` so the traceback lands in the warning log -- a future helper-refactor bug now shows up instead of silently flipping a flag. - Soften the docstring: optimize_onnx failures leave the original untouched, but copy_onnx_model writes directly to the destination, so a copy failure can corrupt the dest. The previous docstring overclaimed. Tests: - Both existing tests updated to assert on the new status key. - New `test_normalize_false_skips_normalization` verifies the kwarg short-circuits the helper and yields `"not_run"`. --- src/winml/modelkit/export/pytorch.py | 39 ++++++++++++++++++------ tests/unit/export/test_pytorch_export.py | 25 ++++++++++++--- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/winml/modelkit/export/pytorch.py b/src/winml/modelkit/export/pytorch.py index a0d19286b..80fff109a 100644 --- a/src/winml/modelkit/export/pytorch.py +++ b/src/winml/modelkit/export/pytorch.py @@ -47,6 +47,7 @@ def export_pytorch( task: str | None = None, verbose: bool = False, enable_reporting: bool = False, + normalize: bool = True, **kwargs: Any, ) -> dict[str, Any]: """Export a PyTorch nn.Module to ONNX. 
@@ -63,9 +64,15 @@ def export_pytorch( task: Task for auto-input generation fallback. verbose: Enable verbose logging. enable_reporting: Generate export report file. + normalize: If True (default), run optimize_onnx on the exported model + to apply graph-level optimizations and shape inference. Set False + to keep the raw torch.onnx.export output (useful when debugging + the exporter or running custom downstream optimization). Returns: - Export statistics dict from HTPExporter. + Export statistics dict from HTPExporter, with an extra + `model_normalization_status` entry: one of `"not_run"` (when + `normalize=False`), `"succeeded"`, or `"failed"`. """ from .htp.exporter import HTPExporter @@ -95,7 +102,12 @@ def export_pytorch( **kwargs, ) - stats["model_normalization_succeeded"] = _normalize_exported_model(output_path) + if normalize: + stats["model_normalization_status"] = ( + "succeeded" if _normalize_exported_model(output_path) else "failed" + ) + else: + stats["model_normalization_status"] = "not_run" return stats @@ -105,13 +117,19 @@ def _normalize_exported_model(output_path: Path) -> bool: Writes the normalized model into a temporary directory, then replaces the original export (and its `.data` sidecar, if any) via - copy_onnx_model — which overwrites both the `.onnx` file and its - external-data sidecar by design. On any failure during optimization - or copy, logs a warning and leaves the original export untouched. - The temp directory is removed in either case. + copy_onnx_model. The temp directory is removed either way. + + Failure modes are not symmetric: + - An optimize_onnx failure leaves the original export untouched: the + temp directory is the only write target, and it is cleaned up. + - A copy_onnx_model failure may leave the original `.onnx` and/or + `.data` sidecar partially overwritten: copy_onnx_model writes + directly to the destination (no temp-and-rename), so a process + kill or full disk mid-copy can corrupt the destination. 
Returns: - True if normalization succeeded, False otherwise. + True if normalization succeeded, False otherwise. On False, the + traceback is included in the warning log to aid debugging. """ import shutil import tempfile @@ -126,8 +144,11 @@ def _normalize_exported_model(output_path: Path) -> bool: try: optimize_onnx(model=output_path, output=tmp_path) copy_onnx_model(tmp_path, output_path) - except Exception as e: - logger.warning("Normalization failed; keeping un-normalized export: %s", e) + except Exception: + logger.warning( + "Normalization failed; keeping un-normalized export", + exc_info=True, + ) return False finally: shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/tests/unit/export/test_pytorch_export.py b/tests/unit/export/test_pytorch_export.py index ff656ca31..9b74c142a 100644 --- a/tests/unit/export/test_pytorch_export.py +++ b/tests/unit/export/test_pytorch_export.py @@ -285,20 +285,20 @@ def forward(self, ids): pass def test_normalization_succeeds_and_shape_inferences(self, tmp_path) -> None: - """After export, stats reports normalization succeeded and value_info is fully shaped.""" + """After export, status reports succeeded and value_info is fully shaped.""" model = TwoLayerNet() config = WinMLExportConfig( input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], ) result = export_pytorch(model, tmp_path / "model.onnx", config) - assert result["model_normalization_succeeded"] is True + assert result["model_normalization_status"] == "succeeded" onnx_model = onnx.load(str(tmp_path / "model.onnx")) assert _all_value_info_have_shape(onnx_model) def test_failed_normalization_skips_shape_inference(self, tmp_path) -> None: - """When normalization is mocked to return False, no shape inference runs.""" + """When normalization is mocked to return False, status is failed.""" model = TwoLayerNet() config = WinMLExportConfig( input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], @@ -309,7 +309,24 @@ def 
test_failed_normalization_skips_shape_inference(self, tmp_path) -> None: ): result = export_pytorch(model, tmp_path / "model.onnx", config) - assert result["model_normalization_succeeded"] is False + assert result["model_normalization_status"] == "failed" + + onnx_model = onnx.load(str(tmp_path / "model.onnx")) + assert not _all_value_info_have_shape(onnx_model) + + def test_normalize_false_skips_normalization(self, tmp_path) -> None: + """When normalize=False, the helper isn't called and status is not_run.""" + model = TwoLayerNet() + config = WinMLExportConfig( + input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], + ) + with patch( + "winml.modelkit.export.pytorch._normalize_exported_model", + ) as mock_normalize: + result = export_pytorch(model, tmp_path / "model.onnx", config, normalize=False) + + mock_normalize.assert_not_called() + assert result["model_normalization_status"] == "not_run" onnx_model = onnx.load(str(tmp_path / "model.onnx")) assert not _all_value_info_have_shape(onnx_model) From 0894460b5e81ac7900f55b64d5e6ed305f69186a Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 20 May 2026 15:29:22 +0800 Subject: [PATCH 3/4] perf(export): place normalization temp dir next to output Pass `dir=output_path.parent` to `tempfile.mkdtemp()` so the staging directory lives on the same volume as the export. For multi-GB models (the primary WinML target) this avoids a cross-volume data transfer in `copy_onnx_model` and keeps large sidecars off the system drive's `%TEMP%`. Addresses the perf review comment on #681. 
--- src/winml/modelkit/export/pytorch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/winml/modelkit/export/pytorch.py b/src/winml/modelkit/export/pytorch.py index 80fff109a..3d4602a2c 100644 --- a/src/winml/modelkit/export/pytorch.py +++ b/src/winml/modelkit/export/pytorch.py @@ -138,7 +138,10 @@ def _normalize_exported_model(output_path: Path) -> bool: from ..optim import optimize_onnx logger.info("Normalizing model") - tmp_dir = Path(tempfile.mkdtemp()) + # Place the temp dir next to the output so copy_onnx_model stays on the + # same volume — avoids a cross-volume data transfer for multi-GB models + # and keeps the system drive's %TEMP% free of large sidecars. + tmp_dir = Path(tempfile.mkdtemp(dir=output_path.parent)) tmp_path = tmp_dir / output_path.name try: From 973943eb32ba7b9693b2b799d55f2af166eed2dd Mon Sep 17 00:00:00 2001 From: Yi Ren Date: Wed, 20 May 2026 15:37:51 +0800 Subject: [PATCH 4/4] fix(export): keep normalization inside warnings.catch_warnings() Move the `_normalize_exported_model` call inside the `warnings.catch_warnings()` block so the optimizer's shape-inference and ORT warnings stay suppressed for callers. PR #460 centralized warning suppression here; the new normalization call had escaped it and was leaking warnings to stderr. Addresses DingmaomaoBJTU's first review finding on #681. 
--- src/winml/modelkit/export/pytorch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/winml/modelkit/export/pytorch.py b/src/winml/modelkit/export/pytorch.py index 3d4602a2c..98202e8a7 100644 --- a/src/winml/modelkit/export/pytorch.py +++ b/src/winml/modelkit/export/pytorch.py @@ -102,12 +102,12 @@ def export_pytorch( **kwargs, ) - if normalize: - stats["model_normalization_status"] = ( - "succeeded" if _normalize_exported_model(output_path) else "failed" - ) - else: - stats["model_normalization_status"] = "not_run" + if normalize: + stats["model_normalization_status"] = ( + "succeeded" if _normalize_exported_model(output_path) else "failed" + ) + else: + stats["model_normalization_status"] = "not_run" return stats