Skip to content
10 changes: 6 additions & 4 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,10 @@ def _build_modules(
help="Skip quantization (overrides config)",
)
@click.option(
"--no-compile",
is_flag=True,
default=False,
help="Skip compilation (overrides config)",
"--no-compile/--compile",
"no_compile",
default=True,
help="Skip compilation (overrides config). Default: skip.",
)
@click.option(
"--ep",
Expand Down Expand Up @@ -961,6 +961,8 @@ def _run_compile_stage(
and Path(compile_result.output_path).resolve() != compiled_path.resolve()
):
copy_onnx_model(compile_result.output_path, compiled_path)
if not compiled_path.exists():
raise RuntimeError(f"Compile reported success but output not found: {compiled_path}")
current_path = compiled_path
_compile_elapsed = time.monotonic() - t0
sl.set_done(_compile_elapsed)
Expand Down
6 changes: 5 additions & 1 deletion src/winml/modelkit/compiler/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ def for_provider(cls, provider: str | None) -> WinMLCompileConfig | None:
}
factory = factories.get(provider)
if factory:
return factory()
config = factory()
# EPs that don't produce EPContext have no offline compile step
if not config.ep_config.enable_ep_context:
return None
return config
# Generic fallback for unknown/custom providers
return cls(ep_config=EPConfig(provider=provider, enable_ep_context=False))

Expand Down
115 changes: 115 additions & 0 deletions tests/unit/commands/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,3 +789,118 @@ def test_no_optimize_default_not_present(

extra = mock_build_api.call_args.kwargs["extra_kwargs"]
assert "skip_optimize" not in extra


# =============================================================================
# _run_compile_stage UNIT TESTS
# =============================================================================


class TestRunCompileStageNoOutput:
"""Test _run_compile_stage output validation."""

@patch("winml.modelkit.compiler.compile_onnx")
def test_none_compile_config_skips_stage(
self,
mock_compile: MagicMock,
tmp_path: Path,
) -> None:
"""compile=None skips compile_onnx entirely and returns current_path unchanged."""
from winml.modelkit.commands.build import _run_compile_stage
from winml.modelkit.config import WinMLBuildConfig

input_path = tmp_path / "quantized.onnx"
input_path.write_bytes(b"dummy")
compiled_path = tmp_path / "compiled.onnx"

config = WinMLBuildConfig(compile=None)
timings: list[tuple[str, float | None]] = []

result = _run_compile_stage(
config=config,
current_path=input_path,
compiled_path=compiled_path,
stage_timings=timings,
)

mock_compile.assert_not_called()
assert result == input_path

@patch("winml.modelkit.utils.console.get_onnx_graph_summary")
@patch("winml.modelkit.utils.console.StageLive")
@patch("winml.modelkit.compiler.compile_onnx")
def test_raises_when_ep_context_expected_but_missing(
self,
mock_compile: MagicMock,
mock_stage_live: MagicMock,
mock_graph_summary: MagicMock,
tmp_path: Path,
) -> None:
"""When enable_ep_context=True and compile succeeds but file is absent, raise."""
from winml.modelkit.commands.build import _run_compile_stage
from winml.modelkit.compiler.configs import WinMLCompileConfig
from winml.modelkit.compiler.result import CompileResult
from winml.modelkit.config import WinMLBuildConfig

mock_compile.return_value = CompileResult(success=True, output_path=None)
mock_stage_live.return_value.__enter__ = MagicMock(return_value=MagicMock())
mock_stage_live.return_value.__exit__ = MagicMock(return_value=False)

input_path = tmp_path / "quantized.onnx"
input_path.write_bytes(b"dummy")
compiled_path = tmp_path / "compiled.onnx" # Does NOT exist

config = WinMLBuildConfig(compile=WinMLCompileConfig.for_qnn())
timings: list[tuple[str, float | None]] = []

with pytest.raises(RuntimeError, match="output not found"):
_run_compile_stage(
config=config,
current_path=input_path,
compiled_path=compiled_path,
stage_timings=timings,
)

@patch("winml.modelkit.utils.console.get_onnx_graph_summary")
@patch("winml.modelkit.utils.console.StageLive")
@patch("winml.modelkit.compiler.compile_onnx")
@patch("winml.modelkit.onnx.external_data.copy_onnx_model")
def test_returns_compiled_path_when_file_exists(
self,
mock_copy: MagicMock,
mock_compile: MagicMock,
mock_stage_live: MagicMock,
mock_graph_summary: MagicMock,
tmp_path: Path,
) -> None:
"""When compile produces an output file, current_path should update."""
from winml.modelkit.commands.build import _run_compile_stage
from winml.modelkit.compiler.configs import WinMLCompileConfig
from winml.modelkit.compiler.result import CompileResult
from winml.modelkit.config import WinMLBuildConfig

input_path = tmp_path / "quantized.onnx"
input_path.write_bytes(b"dummy")
compiled_path = tmp_path / "compiled.onnx"
compiled_path.write_bytes(b"compiled_model") # File EXISTS

mock_compile.return_value = CompileResult(
success=True,
output_path=str(compiled_path),
)
mock_stage_live.return_value.__enter__ = MagicMock(return_value=MagicMock())
mock_stage_live.return_value.__exit__ = MagicMock(return_value=False)
mock_graph_summary.return_value = {"op_counts": {"EPContext": 1}}

config = WinMLBuildConfig(compile=WinMLCompileConfig.for_qnn())
timings: list[tuple[str, float | None]] = []

result = _run_compile_stage(
config=config,
current_path=input_path,
compiled_path=compiled_path,
stage_timings=timings,
)

# current_path should be updated to compiled_path
assert result == compiled_path
29 changes: 22 additions & 7 deletions tests/unit/compiler/test_compiler_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,19 @@ class TestForProvider:
"provider,expect_provider",
[
(None, None),
# EPs that produce EPContext → compile config returned
("qnn", "qnn"),
("dml", "dml"),
("cuda", "cuda"),
("nv_tensorrt_rtx", "nv_tensorrt_rtx"),
("openvino", "openvino"),
("vitisai", "vitisai"),
("migraphx", "migraphx"),
("cpu", "cpu"),
("custom_ep", "custom_ep"), # generic fallback
# EPs with enable_ep_context=False → no offline compile step → None
("dml", None),
("cpu", None),
("cuda", None),
("nv_tensorrt_rtx", None),
("vitisai", None),
("migraphx", None),
# Unknown/custom EPs use the generic fallback (enable_ep_context=False
# in the fallback does NOT apply the None rule — only known factories do)
("custom_ep", "custom_ep"),
],
)
def test_for_provider(
Expand All @@ -323,6 +327,17 @@ def test_for_provider(
assert result is not None
assert result.ep_config.provider == expect_provider

@pytest.mark.parametrize(
"factory_name",
["for_dml", "for_cpu", "for_cuda", "for_vitisai", "for_migraphx", "for_nv_tensorrt_rtx"],
)
def test_direct_factory_still_works(self, factory_name: str) -> None:
"""Low-level for_* factories are still callable directly even though
for_provider() returns None for these EPs."""
config = getattr(WinMLCompileConfig, factory_name)()
assert config is not None
assert config.ep_config.enable_ep_context is False

def test_for_provider_custom_ep_no_context(self):
"""Custom EP fallback disables EP context."""
result = WinMLCompileConfig.for_provider("custom_ep")
Expand Down
33 changes: 16 additions & 17 deletions tests/unit/config/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,9 +1785,9 @@ def _mock_deps(
("npu", "auto", True, "uint8", "uint16", "qnn"),
("npu", "fp16", False, None, None, "qnn"),
("npu", "int8", True, "uint8", "uint8", "qnn"),
("gpu", "auto", False, None, None, "dml"),
("gpu", "int8", True, "uint8", "uint8", "dml"),
("gpu", "fp16", False, None, None, "dml"),
("gpu", "auto", False, None, None, None),
("gpu", "int8", True, "uint8", "uint8", None),
("gpu", "fp16", False, None, None, None),
("cpu", "auto", False, None, None, None),
("cpu", "int8", True, "uint8", "uint8", None),
("cpu", "int16", True, "int16", "uint16", None),
Expand Down Expand Up @@ -2035,8 +2035,7 @@ def test_device_gpu_precision_fp16(self, tmp_path) -> None:
assert result.exit_code == 0, f"CLI failed: {result.output}"
data = json.loads(output_file.read_text())
assert data["quant"] is None
assert data["compile"] is not None
assert data["compile"]["execution_provider"] == "dml"
assert data["compile"] is None

def test_device_cpu_precision_fp32(self, tmp_path) -> None:
"""--device cpu --precision fp32 → no quant, no compile."""
Expand Down Expand Up @@ -2607,10 +2606,9 @@ def test_raw_onnx_with_gpu(self, tmp_path) -> None:
):
config = generate_onnx_build_config(str(onnx_file), device="gpu")

# GPU auto-precision is fp16 -> no quantization, compile=dml
# GPU auto-precision is fp16 -> no quantization, no compile (DML has no offline step)
assert config.quant is None
assert config.compile is not None
assert config.compile.ep_config.provider == "dml"
assert config.compile is None

def test_ep_override_forwarded(self, tmp_path) -> None:
"""Explicit ep parameter is forwarded to resolve_quant_compile_config."""
Expand All @@ -2631,8 +2629,8 @@ def test_ep_override_forwarded(self, tmp_path) -> None:
ep="migraphx",
)

assert config.compile is not None
assert config.compile.ep_config.provider == "migraphx"
# migraphx has enable_ep_context=False → no offline compile step
assert config.compile is None


# =============================================================================
Expand Down Expand Up @@ -2672,17 +2670,16 @@ def test_npu_returns_quant_and_compile(self) -> None:
assert isinstance(compile_cfg, WinMLCompileConfig)
assert compile_cfg.ep_config.provider == "qnn"

def test_gpu_returns_none_quant_and_dml_compile(self) -> None:
"""device=gpu returns (None, WinMLCompileConfig(dml))."""
def test_gpu_returns_none_quant_and_none_compile(self) -> None:
"""device=gpu returns (None, None) — DML has no offline compile step."""
with patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("gpu", ["gpu", "cpu"]),
):
quant, compile_cfg = resolve_quant_compile_config(device="gpu")

assert quant is None
assert isinstance(compile_cfg, WinMLCompileConfig)
assert compile_cfg.ep_config.provider == "dml"
assert compile_cfg is None

def test_cpu_returns_none_none(self) -> None:
"""device=cpu returns (None, None) since CPU has no compile provider."""
Expand All @@ -2696,7 +2693,10 @@ def test_cpu_returns_none_none(self) -> None:
assert compile_cfg is None

def test_ep_override_changes_provider(self) -> None:
"""Explicit ep overrides the default device-to-provider mapping."""
"""Explicit ep overrides the default device-to-provider mapping.

nv_tensorrt_rtx has enable_ep_context=False so for_provider returns None.
"""
with patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("gpu", ["gpu", "cpu"]),
Expand All @@ -2706,8 +2706,7 @@ def test_ep_override_changes_provider(self) -> None:
ep="nv_tensorrt_rtx",
)

assert compile_cfg is not None
assert compile_cfg.ep_config.provider == "nv_tensorrt_rtx"
assert compile_cfg is None

def test_task_forwarded_to_resolve_precision(self) -> None:
"""task parameter is forwarded to resolve_precision.
Expand Down
34 changes: 20 additions & 14 deletions tests/unit/config/test_build_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,11 @@ def test_compiled_does_not_call_resolve_quant_compile(self, tmp_path) -> None:
mock_resolve.assert_not_called()

def test_raw_onnx_with_gpu(self, tmp_path) -> None:
"""Raw ONNX + device=gpu resolves quant=None, compile=dml."""
"""Raw ONNX + device=gpu resolves quant=None, compile=None.

DML has enable_ep_context=False so for_provider("dml") returns None —
no offline compile step is needed.
"""
onnx_file = tmp_path / "model.onnx"
onnx_file.write_bytes(b"fake")

Expand All @@ -611,13 +615,15 @@ def test_raw_onnx_with_gpu(self, tmp_path) -> None:
):
config = generate_onnx_build_config(str(onnx_file), device="gpu")

# GPU auto-precision is fp16 -> no quantization, compile=dml
# GPU auto-precision is fp16 -> no quantization; DML has no EPContext step
assert config.quant is None
assert config.compile is not None
assert config.compile.ep_config.provider == "dml"
assert config.compile is None

def test_ep_override_forwarded(self, tmp_path) -> None:
"""Explicit ep parameter is forwarded to resolve_quant_compile_config."""
"""Explicit ep parameter is forwarded to resolve_quant_compile_config.

migraphx has enable_ep_context=False so for_provider("migraphx") returns None.
"""
onnx_file = tmp_path / "model.onnx"
onnx_file.write_bytes(b"fake")

Expand All @@ -635,8 +641,7 @@ def test_ep_override_forwarded(self, tmp_path) -> None:
ep="migraphx",
)

assert config.compile is not None
assert config.compile.ep_config.provider == "migraphx"
assert config.compile is None


# =============================================================================
Expand Down Expand Up @@ -676,17 +681,16 @@ def test_npu_returns_quant_and_compile(self) -> None:
assert isinstance(compile_cfg, WinMLCompileConfig)
assert compile_cfg.ep_config.provider == "qnn"

def test_gpu_returns_none_quant_and_dml_compile(self) -> None:
"""device=gpu returns (None, WinMLCompileConfig(dml))."""
def test_gpu_returns_none_quant_and_none_compile(self) -> None:
"""device=gpu returns (None, None) — DML has no EPContext step."""
with patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("gpu", ["gpu", "cpu"]),
):
quant, compile_cfg = resolve_quant_compile_config(device="gpu")

assert quant is None
assert isinstance(compile_cfg, WinMLCompileConfig)
assert compile_cfg.ep_config.provider == "dml"
assert compile_cfg is None

def test_cpu_returns_none_none(self) -> None:
"""device=cpu returns (None, None) since CPU has no compile provider."""
Expand All @@ -700,7 +704,10 @@ def test_cpu_returns_none_none(self) -> None:
assert compile_cfg is None

def test_ep_override_changes_provider(self) -> None:
"""Explicit ep overrides the default device-to-provider mapping."""
"""Explicit ep overrides the default device-to-provider mapping.

nv_tensorrt_rtx has enable_ep_context=False so for_provider returns None.
"""
with patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("gpu", ["gpu", "cpu"]),
Expand All @@ -710,8 +717,7 @@ def test_ep_override_changes_provider(self) -> None:
ep="nv_tensorrt_rtx",
)

assert compile_cfg is not None
assert compile_cfg.ep_config.provider == "nv_tensorrt_rtx"
assert compile_cfg is None

def test_task_forwarded_to_resolve_precision(self) -> None:
"""task parameter is forwarded to resolve_precision.
Expand Down
Loading