diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index 841080fe1..600403524 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -504,6 +504,9 @@ def _perf_modules( verbose: bool, console: Console, monitor: bool = False, + device: str = "auto", + ep: str | None = None, + precision: str = "auto", ) -> None: """Run per-module build and benchmark for matching submodules. @@ -523,14 +526,21 @@ def _perf_modules( verbose: If True, log exceptions at DEBUG level. console: Rich console for output. monitor: If True, wrap each per-module benchmark with HWMonitor. + device: Target device policy ("auto", "cpu", "gpu", "npu"). + ep: Explicit execution provider (e.g., "qnn", "dml"). Overrides + device-to-provider mapping when set. + precision: Precision mode passed through to the build stage. """ import json as json_mod import tempfile from ..build import build_hf_model from ..config import generate_hf_build_config + from ..sysinfo import resolve_device from .build import _instantiate_parent_model + resolved_device, _ = resolve_device(device=device) + console.print(f"[dim]Generating module configs for {module_class}...[/dim]") try: @@ -538,6 +548,9 @@ def _perf_modules( model_id=hf_model, task=task, module=module_class, + device=resolved_device, + precision=precision, + ep=ep, ) except Exception as e: console.print(f"[red]Error generating module configs: {e}[/red]") @@ -590,12 +603,18 @@ def _perf_modules( config=cfg, output_dir=Path(tmpdir), pytorch_model=submodule, + ep=ep, + device=resolved_device, ) # Benchmark using WinMLSession from ..session import WinMLSession - session = WinMLSession(str(build_result.final_onnx_path)) + session = WinMLSession( + str(build_result.final_onnx_path), + device=resolved_device, + ep=ep, + ) io_cfg = session.io_config inputs = generate_random_inputs(io_cfg, batch_size=batch_size) @@ -1250,6 +1269,9 @@ def perf( verbose=verbose, console=console, monitor=monitor, + device=device.lower(), + ep=ep.lower() if ep else None, + precision=precision.lower(), ) return diff --git a/src/winml/modelkit/session/session.py b/src/winml/modelkit/session/session.py index 6349cac62..9c12b7c96 100644 --- a/src/winml/modelkit/session/session.py +++ b/src/winml/modelkit/session/session.py @@ -301,7 +301,7 @@ def compile(self) -> None: # Log which providers were selected by ORT (based on policy) actual_providers = session.get_providers() logger.info( - "Session created with policy %s, providers: %s", + "Session created with device %s, providers: %s", target_device, actual_providers, ) @@ -432,27 +432,40 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: Note: Returns a **fresh** SessionOptions when using explicit EP to avoid "already registered" errors from repeated calls. """ - # Explicit EP targeting: create fresh opts to avoid double-registration - # Don't filter by device type — trust the user's --ep choice - # (e.g., QNN reports as NPU in get_ep_devices but can target GPU) + # Explicit EP targeting: create fresh opts to avoid double-registration. + # When device is also specified (non-"auto"), narrow by both EP name + # and device type so e.g. `--ep qnn --device cpu` finds QNN-on-CPU + # instead of the first QNN ep_device (which may report as NPU). if self._ep and self._ep != "cpu": target_name = self._EP_NAME_MAP.get(self._ep) if target_name: - matched = self._find_ep_device(target_name) + matched = self._find_ep_device(ep_name=target_name, device=device) if matched: + from ..utils.constants import DEVICE_TYPE_TO_DEVICE + opts = ort.SessionOptions() opts.add_provider_for_devices([matched], self._provider_options) - logger.info("Explicit EP: %s (%s)", self._ep, target_name) + resolved = DEVICE_TYPE_TO_DEVICE.get( + matched.device.type, str(matched.device.type) + ) + logger.info( + "Explicit EP: %s (%s) device=%s -> %s", + self._ep, + target_name, + device, + resolved, + ) return opts logger.warning( - "EP '%s' (%s) not found in available devices", + "EP '%s' (%s) not found for device '%s'", self._ep, target_name, + device, ) # No explicit EP — discover available EP for this device type if not self._ep and device.lower() != "cpu": - matched = self._find_ep_for_device(device) + matched = self._find_ep_device(device=device) if matched: opts = ort.SessionOptions() opts.add_provider_for_devices([matched], self._provider_options) @@ -465,36 +478,34 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: device.lower(), ort.OrtExecutionProviderDevicePolicy.PREFER_NPU ) opts.set_provider_selection_policy(policy) + logger.info("Using provider selection policy %s for device %s", policy, device) return opts @staticmethod - def _find_ep_device(ep_name: str) -> Any: - """Find the first OrtEpDevice matching the given EP name. - - Args: - ep_name: Full EP name (e.g., "DmlExecutionProvider"). - - Returns: - The matching OrtEpDevice, or None if not found. - """ - for ep_dev in ort.get_ep_devices(): - if ep_dev.ep_name == ep_name: - return ep_dev - return None - - @staticmethod - def _find_ep_for_device(device: str) -> Any: - """Find the first available OrtEpDevice for the given device type. - - Queries ``ort.get_ep_devices()`` and returns the first EP whose - hardware device type matches (e.g., device="gpu" matches GPU EPs). + def _find_ep_device(device: str, ep_name: str | None = None) -> Any: + """Find the first OrtEpDevice matching the given filters. + + Behavior: + - ``ep_name`` set, ``device == "auto"`` → first ep_device + matching ``ep_name`` (or None). + - ``ep_name`` unset, ``device == "auto"`` → ``None`` (no + effective filter — refuse to pick an arbitrary ep_device). + - ``ep_name`` unset, ``device`` is a concrete type → first + ep_device matching that device type (or None). + - Both set → ep_device must satisfy both (or None). Note: Selection order is determined by the ORT EP registry, which is not part of any documented contract. On systems where multiple EPs match the same device type (e.g., QNN and DML both appear as GPU), - the result is registry-order dependent. When a specific EP is - required, use ``self._ep`` to bypass this discovery path entirely. + a device-only query returns the first one in registry order. Pass + ``ep_name`` to disambiguate. + + Args: + device: Device policy ("cpu", "gpu", "npu", "auto"). ``"auto"`` + and unknown strings act as no-op device filters. + ep_name: Full EP name (e.g., "DmlExecutionProvider"), or None + to skip EP-name filtering. Returns: The matching OrtEpDevice, or None if not found. @@ -502,11 +513,17 @@ def _find_ep_for_device(device: str) -> Any: from ..utils.constants import DEVICE_TO_DEVICE_TYPE device_type = DEVICE_TO_DEVICE_TYPE.get(device.upper()) - if device_type is None: + + # No effective filter — refuse to pick an arbitrary ep_device. + if not ep_name and device_type is None: return None + for ep_dev in ort.get_ep_devices(): - if ep_dev.device.type == device_type: - return ep_dev + if ep_name and ep_dev.ep_name != ep_name: + continue + if device_type is not None and ep_dev.device.type != device_type: + continue + return ep_dev return None def _validate_inputs(self, inputs: dict[str, Any]) -> None: diff --git a/src/winml/modelkit/sysinfo/device.py b/src/winml/modelkit/sysinfo/device.py index 35ad1d403..c7fc1a942 100644 --- a/src/winml/modelkit/sysinfo/device.py +++ b/src/winml/modelkit/sysinfo/device.py @@ -188,6 +188,11 @@ def resolve_device(device: str = "auto") -> tuple[str, list[str]]: for dev in available_devices: compatible_eps = _DEVICE_EP_MAP.get(dev, []) if any(ep in available_eps for ep in compatible_eps): + logger.info( + "Auto-selected device '%s' with compatible EPs: %s for auto device", + dev, + sorted(ep for ep in compatible_eps if ep in available_eps), + ) return dev, available_devices # Fallback: CPU is always valid return "cpu", available_devices diff --git a/tests/unit/commands/test_perf_module.py b/tests/unit/commands/test_perf_module.py index 46ae2c2c3..b0fee4bc3 100644 --- a/tests/unit/commands/test_perf_module.py +++ b/tests/unit/commands/test_perf_module.py @@ -6,12 +6,19 @@ from __future__ import annotations +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + from click.testing import CliRunner from winml.modelkit.cli import main from winml.modelkit.commands.perf import generate_output_path +if TYPE_CHECKING: + from pathlib import Path + + class TestPerfModuleFlag: """Tests for --module flag on winml perf.""" @@ -44,3 +51,85 @@ def test_module_default_output_includes_class_name(self) -> None: module_path = Path(f"{slug}_{module_class}_perf.json") assert module_class in str(module_path) assert str(module_path) != str(path) + + +class TestPerfModuleParameterForwarding: + """Verify --device/--ep/--precision flow from CLI through _perf_modules + into generate_hf_build_config, build_hf_model, and WinMLSession. + + Regression guard: these kwargs were silently dropped before. + """ + + def test_device_and_ep_forwarded_through_module_path(self, tmp_path: Path) -> None: + # Fake module config -- only the attributes _perf_modules touches + fake_cfg = MagicMock() + fake_cfg.loader.model_type = "bert" + fake_cfg.loader.module_path = "encoder.layer.0" + + fake_build_result = MagicMock() + fake_build_result.final_onnx_path = tmp_path / "model.onnx" + + # Make WinMLSession.perf() raise so the benchmark loop is short-circuited + # via the existing try/except in _perf_modules. We still capture the + # constructor kwargs, which is what we care about. + fake_session = MagicMock() + fake_session.perf.side_effect = RuntimeError("test-skip-benchmark") + + with ( + patch( + "winml.modelkit.sysinfo.resolve_device", + return_value=("npu", "qnn"), + ), + patch( + "winml.modelkit.config.generate_hf_build_config", + return_value=[fake_cfg], + ) as mock_gen, + patch( + "winml.modelkit.commands.build._instantiate_parent_model", + return_value=MagicMock(), + ), + patch( + "winml.modelkit.build.build_hf_model", + return_value=fake_build_result, + ) as mock_build, + patch( + "winml.modelkit.session.WinMLSession", + return_value=fake_session, + ) as mock_session_cls, + ): + runner = CliRunner() + result = runner.invoke( + main, + [ + "perf", + "-m", + "fake/model", + "--module", + "BertLayer", + "--device", + "npu", + "--ep", + "qnn", + "--iterations", + "1", + "--warmup", + "0", + "-o", + str(tmp_path / "out.json"), + ], + ) + + assert result.exit_code == 0, result.output + + gen_kwargs = mock_gen.call_args.kwargs + assert gen_kwargs["device"] == "npu" + assert gen_kwargs["ep"] == "qnn" + assert gen_kwargs["precision"] == "auto" + + build_kwargs = mock_build.call_args.kwargs + assert build_kwargs["ep"] == "qnn" + assert build_kwargs["device"] == "npu" + + session_kwargs = mock_session_cls.call_args.kwargs + assert session_kwargs["device"] == "npu" + assert session_kwargs["ep"] == "qnn" diff --git a/tests/unit/session/test_winml_session.py b/tests/unit/session/test_winml_session.py index a0cc54705..c82e0ddd6 100644 --- a/tests/unit/session/test_winml_session.py +++ b/tests/unit/session/test_winml_session.py @@ -680,3 +680,96 @@ def test_perf_stats_accessible_after_context( assert stats.count == 3 # 5 - 2 warmup assert stats.mean_ms > 0 assert stats.p99_ms > 0 + + +class TestFindEpDevice: + """Tests for WinMLSession._find_ep_device combined ep_name + device filter. + + Regression guard: when both filters are set, both must match (AND logic). + Previously these were two separate methods and the explicit-EP path + ignored device, so `--ep qnn --device cpu` could return QNN-on-NPU. + """ + + @staticmethod + def _ep_dev(name: str, dev_type) -> object: + """Build a fake OrtEpDevice-like object.""" + from types import SimpleNamespace + + return SimpleNamespace(ep_name=name, device=SimpleNamespace(type=dev_type)) + + def _patch_devices(self, devices: list) -> object: + """Return a contextmanager that patches ort.get_ep_devices.""" + from unittest.mock import patch + + return patch( + "winml.modelkit.session.session.ort.get_ep_devices", + return_value=devices, + ) + + def test_ep_name_only(self) -> None: + """ep_name filter with device='auto' returns first matching name.""" + import onnxruntime as ort + + devs = [ + self._ep_dev("DmlExecutionProvider", ort.OrtHardwareDeviceType.GPU), + self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU), + ] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(device="auto", ep_name="QNNExecutionProvider") + assert match is not None + assert match.ep_name == "QNNExecutionProvider" + + def test_device_only(self) -> None: + """device filter alone returns first matching device type.""" + import onnxruntime as ort + + devs = [ + self._ep_dev("CPUExecutionProvider", ort.OrtHardwareDeviceType.CPU), + self._ep_dev("DmlExecutionProvider", ort.OrtHardwareDeviceType.GPU), + ] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(device="gpu") + assert match is not None + assert match.ep_name == "DmlExecutionProvider" + + def test_ep_name_and_device_both_required(self) -> None: + """When both filters are set, both must match (AND logic).""" + import onnxruntime as ort + + devs = [ + # QNN-on-NPU comes first; user asks for QNN-on-CPU + self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU), + self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.CPU), + ] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(ep_name="QNNExecutionProvider", device="cpu") + assert match is not None + assert match.device.type == ort.OrtHardwareDeviceType.CPU + + def test_no_match_returns_none(self) -> None: + """Non-matching combination returns None even if individual filters would match.""" + import onnxruntime as ort + + devs = [self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU)] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(ep_name="QNNExecutionProvider", device="cpu") + assert match is None + + def test_auto_device_acts_as_no_filter(self) -> None: + """device='auto' falls back to ep_name-only matching.""" + import onnxruntime as ort + + devs = [self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU)] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(ep_name="QNNExecutionProvider", device="auto") + assert match is not None + assert match.ep_name == "QNNExecutionProvider" + + def test_ep_none_and_device_auto_returns_none(self) -> None: + """ep_name=None and device='auto' → None (no effective filter).""" + import onnxruntime as ort + + devs = [self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU)] + with self._patch_devices(devs): + assert WinMLSession._find_ep_device(device="auto") is None + assert WinMLSession._find_ep_device(device="auto", ep_name=None) is None