diff --git a/src/winml/modelkit/export/pytorch.py b/src/winml/modelkit/export/pytorch.py index 3d157ddba..98202e8a7 100644 --- a/src/winml/modelkit/export/pytorch.py +++ b/src/winml/modelkit/export/pytorch.py @@ -47,6 +47,7 @@ def export_pytorch( task: str | None = None, verbose: bool = False, enable_reporting: bool = False, + normalize: bool = True, **kwargs: Any, ) -> dict[str, Any]: """Export a PyTorch nn.Module to ONNX. @@ -63,9 +64,15 @@ def export_pytorch( task: Task for auto-input generation fallback. verbose: Enable verbose logging. enable_reporting: Generate export report file. + normalize: If True (default), run optimize_onnx on the exported model + to apply graph-level optimizations and shape inference. Set False + to keep the raw torch.onnx.export output (useful when debugging + the exporter or running custom downstream optimization). Returns: - Export statistics dict from HTPExporter. + Export statistics dict from HTPExporter, with an extra + `model_normalization_status` entry: one of `"not_run"` (when + `normalize=False`), `"succeeded"`, or `"failed"`. """ from .htp.exporter import HTPExporter @@ -86,7 +93,7 @@ def export_pytorch( ) with warnings.catch_warnings(): warnings.filterwarnings("ignore") - return exporter.export( + stats = exporter.export( model=model, output_path=str(output_path), export_config=export_config, @@ -94,3 +101,58 @@ def export_pytorch( task=task, **kwargs, ) + + if normalize: + stats["model_normalization_status"] = ( + "succeeded" if _normalize_exported_model(output_path) else "failed" + ) + else: + stats["model_normalization_status"] = "not_run" + + return stats + + +def _normalize_exported_model(output_path: Path) -> bool: + """Normalize the exported ONNX in-place via optimize_onnx. + + Writes the normalized model into a temporary directory, then replaces + the original export (and its `.data` sidecar, if any) via + copy_onnx_model. The temp directory is removed either way. + + Failure modes are not symmetric: + - An optimize_onnx failure leaves the original export untouched: the + temp directory is the only write target, and it is cleaned up. + - A copy_onnx_model failure may leave the original `.onnx` and/or + `.data` sidecar partially overwritten: copy_onnx_model writes + directly to the destination (no temp-and-rename), so a process + kill or full disk mid-copy can corrupt the destination. + + Returns: + True if normalization succeeded, False otherwise. On False, the + traceback is included in the warning log to aid debugging. + """ + import shutil + import tempfile + + from ..onnx import copy_onnx_model + from ..optim import optimize_onnx + + logger.info("Normalizing model") + # Place the temp dir next to the output so copy_onnx_model stays on the + # same volume — avoids a cross-volume data transfer for multi-GB models + # and keeps the system drive's %TEMP% free of large sidecars. + tmp_dir = Path(tempfile.mkdtemp(dir=output_path.parent)) + tmp_path = tmp_dir / output_path.name + + try: + optimize_onnx(model=output_path, output=tmp_path) + copy_onnx_model(tmp_path, output_path) + except Exception: + logger.warning( + "Normalization failed; keeping un-normalized export", + exc_info=True, + ) + return False + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + return True diff --git a/tests/unit/export/test_pytorch_export.py b/tests/unit/export/test_pytorch_export.py index 89fd797a3..9b74c142a 100644 --- a/tests/unit/export/test_pytorch_export.py +++ b/tests/unit/export/test_pytorch_export.py @@ -10,6 +10,8 @@ from __future__ import annotations +from unittest.mock import patch + import onnx import pytest import torch @@ -23,6 +25,20 @@ ) +def _all_value_info_have_shape(model: onnx.ModelProto) -> bool: + """Every intermediate (value_info) tensor has a concrete or symbolic shape.""" + if not model.graph.value_info: + return False + for vi in model.graph.value_info: + shape = vi.type.tensor_type.shape + if not shape.dim: + return False + for dim in shape.dim: + if not dim.HasField("dim_value") and not dim.HasField("dim_param"): + return False + return True + + # ============================================================================= # Test Models (pure PyTorch, no HF) # ============================================================================= @@ -268,6 +284,53 @@ def forward(self, ids): # This is expected — the test verifies the flow works up to export pass + def test_normalization_succeeds_and_shape_inferences(self, tmp_path) -> None: + """After export, status reports succeeded and value_info is fully shaped.""" + model = TwoLayerNet() + config = WinMLExportConfig( + input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], + ) + result = export_pytorch(model, tmp_path / "model.onnx", config) + + assert result["model_normalization_status"] == "succeeded" + + onnx_model = onnx.load(str(tmp_path / "model.onnx")) + assert _all_value_info_have_shape(onnx_model) + + def test_failed_normalization_skips_shape_inference(self, tmp_path) -> None: + """When normalization is mocked to return False, status is failed.""" + model = TwoLayerNet() + config = WinMLExportConfig( + input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], + ) + with patch( + "winml.modelkit.export.pytorch._normalize_exported_model", + return_value=False, + ): + result = export_pytorch(model, tmp_path / "model.onnx", config) + + assert result["model_normalization_status"] == "failed" + + onnx_model = onnx.load(str(tmp_path / "model.onnx")) + assert not _all_value_info_have_shape(onnx_model) + + def test_normalize_false_skips_normalization(self, tmp_path) -> None: + """When normalize=False, the helper isn't called and status is not_run.""" + model = TwoLayerNet() + config = WinMLExportConfig( + input_tensors=[InputTensorSpec(name="x", dtype="float32", shape=(1, 10))], + ) + with patch( + "winml.modelkit.export.pytorch._normalize_exported_model", + ) as mock_normalize: + result = export_pytorch(model, tmp_path / "model.onnx", config, normalize=False) + + mock_normalize.assert_not_called() + assert result["model_normalization_status"] == "not_run" + + onnx_model = onnx.load(str(tmp_path / "model.onnx")) + assert not _all_value_info_have_shape(onnx_model) + def test_mismatched_input_order_exports_successfully(self, tmp_path) -> None: """Export succeeds when InputTensorSpec order differs from forward() param order. diff --git a/tests/unit/onnx/test_external_data.py b/tests/unit/onnx/test_external_data.py index 4a77e3daf..0d118f475 100644 --- a/tests/unit/onnx/test_external_data.py +++ b/tests/unit/onnx/test_external_data.py @@ -10,7 +10,7 @@ import numpy as np import onnx -from onnx import TensorProto, helper, numpy_helper +from onnx import TensorProto, external_data_helper, helper, numpy_helper from winml.modelkit.onnx.external_data import ( copy_onnx_model, @@ -27,14 +27,39 @@ def _make_small_model() -> onnx.ModelProto: """Create a minimal ONNX model (no external data).""" x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4]) y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2]) - weight = numpy_helper.from_array( - np.random.randn(4, 2).astype(np.float32), name="W" - ) + weight = numpy_helper.from_array(np.random.randn(4, 2).astype(np.float32), name="W") node = helper.make_node("MatMul", ["X", "W"], ["Y"]) graph = helper.make_graph([node], "test", [x_info], [y_info], [weight]) return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) +def _make_filled_model(value: float, shape: tuple[int, ...]) -> onnx.ModelProto: + """Create a deterministic ONNX model with a constant-filled initializer. + + Used by overwrite tests where two distinguishable models are needed. + """ + weight = numpy_helper.from_array(np.full(shape, value, dtype=np.float32), name="W") + inp = helper.make_tensor_value_info("X", TensorProto.FLOAT, list(shape)) + out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(shape)) + node = helper.make_node("Add", ["X", "W"], ["Y"]) + graph = helper.make_graph([node], "g", [inp], [out], [weight]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +def _serialize_without_external_location(model: onnx.ModelProto) -> bytes: + """Serialize the model with the `location` entry stripped from every + external_data tensor — for comparing two models that point to different + sidecar filenames but are otherwise identical.""" + clone = onnx.ModelProto() + clone.CopyFrom(model) + for tensor in external_data_helper._get_all_tensors(clone): + if tensor.data_location == TensorProto.EXTERNAL: + for entry in list(tensor.external_data): + if entry.key == "location": + tensor.external_data.remove(entry) + return clone.SerializeToString(deterministic=True) + + class TestGetExternalDataFiles: """Tests for get_external_data_files().""" @@ -51,7 +76,8 @@ def test_with_external_data(self, tmp_path: Path) -> None: model = _make_small_model() path = tmp_path / "ext.onnx" onnx.save_model( - model, str(path), + model, + str(path), save_as_external_data=True, all_tensors_to_one_file=True, location="ext.onnx.data", @@ -74,7 +100,8 @@ def test_with_external(self, tmp_path: Path) -> None: model = _make_small_model() path = tmp_path / "ext.onnx" onnx.save_model( - model, str(path), + model, + str(path), save_as_external_data=True, all_tensors_to_one_file=True, location="ext.onnx.data", @@ -106,7 +133,8 @@ def test_copy_with_external_data(self, tmp_path: Path) -> None: src = tmp_path / "src" / "model.onnx" src.parent.mkdir() onnx.save_model( - model, str(src), + model, + str(src), save_as_external_data=True, all_tensors_to_one_file=True, location="model.onnx.data", @@ -145,3 +173,88 @@ def test_copy_invalid_file_falls_back(self, tmp_path: Path) -> None: assert dst.exists() assert dst.read_text() == "not a real onnx file" + + def test_copy_overwrites_existing_dst_no_external_data(self, tmp_path: Path) -> None: + """Pre-existing dst (no external data) is overwritten byte-for-byte by src.""" + src = tmp_path / "src.onnx" + dst = tmp_path / "dst.onnx" + + onnx.save(_make_filled_model(1.0, (4, 4)), str(src)) + onnx.save(_make_filled_model(99.0, (8, 8)), str(dst)) # pre-existing, different + + pre_dst_bytes = dst.read_bytes() + src_bytes = src.read_bytes() + assert pre_dst_bytes != src_bytes + + copy_onnx_model(src, dst) + + post_dst_bytes = dst.read_bytes() + assert post_dst_bytes == src_bytes + assert post_dst_bytes != pre_dst_bytes + assert not (tmp_path / "dst.onnx.data").exists() + + def test_copy_overwrites_existing_dst_with_external_data(self, tmp_path: Path) -> None: + """Pre-existing dst + sidecar (external data) are both overwritten. + + Verifies: + - dst.onnx.data is byte-identical to src.onnx.data + - dst.onnx matches src.onnx except for the external_data.location field + - dst.onnx's location field points at dst.onnx.data + - Loaded initializer arrays are equal + """ + src = tmp_path / "src.onnx" + dst = tmp_path / "dst.onnx" + src_data = tmp_path / "src.onnx.data" + dst_data = tmp_path / "dst.onnx.data" + + onnx.save_model( + _make_filled_model(2.0, (64, 64)), + str(src), + save_as_external_data=True, + all_tensors_to_one_file=True, + location="src.onnx.data", + size_threshold=0, + ) + onnx.save_model( + _make_filled_model(999.0, (32, 32)), + str(dst), + save_as_external_data=True, + all_tensors_to_one_file=True, + location="dst.onnx.data", + size_threshold=0, + ) + + src_data_bytes = src_data.read_bytes() + pre_dst_data_bytes = dst_data.read_bytes() + pre_dst_onnx_bytes = dst.read_bytes() + assert src_data_bytes != pre_dst_data_bytes + + copy_onnx_model(src, dst) + + # .data file byte-identical to src's sidecar + post_dst_data_bytes = dst_data.read_bytes() + assert post_dst_data_bytes == src_data_bytes + assert post_dst_data_bytes != pre_dst_data_bytes + + # .onnx file no longer matches old dst + assert dst.read_bytes() != pre_dst_onnx_bytes + + # .onnx matches src modulo external_data.location field + src_model = onnx.load(str(src), load_external_data=False) + dst_model = onnx.load(str(dst), load_external_data=False) + assert _serialize_without_external_location( + src_model + ) == _serialize_without_external_location(dst_model) + + # dst.onnx's location must point at dst.onnx.data + for tensor in external_data_helper._get_all_tensors(dst_model): + if tensor.data_location == TensorProto.EXTERNAL: + info = external_data_helper.ExternalDataInfo(tensor) + assert info.location == "dst.onnx.data" + + # Semantic check: loaded initializer arrays are equal + src_full = onnx.load(str(src), load_external_data=True) + dst_full = onnx.load(str(dst), load_external_data=True) + src_arr = numpy_helper.to_array(src_full.graph.initializer[0]) + dst_arr = numpy_helper.to_array(dst_full.graph.initializer[0]) + assert np.array_equal(src_arr, dst_arr)