Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/winml/modelkit/commands/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,9 @@ def _perf_modules(
verbose: bool,
console: Console,
monitor: bool = False,
device: str = "auto",
ep: str | None = None,
precision: str = "auto",
) -> None:
"""Run per-module build and benchmark for matching submodules.

Expand All @@ -523,21 +526,31 @@ def _perf_modules(
verbose: If True, log exceptions at DEBUG level.
console: Rich console for output.
monitor: If True, wrap each per-module benchmark with HWMonitor.
device: Target device policy ("auto", "cpu", "gpu", "npu").
ep: Explicit execution provider (e.g., "qnn", "dml"). Overrides
device-to-provider mapping when set.
precision: Precision mode passed through to the build stage.
"""
import json as json_mod
import tempfile

from ..build import build_hf_model
from ..config import generate_hf_build_config
from ..sysinfo import resolve_device
from .build import _instantiate_parent_model

resolved_device, _ = resolve_device(device=device)

console.print(f"[dim]Generating module configs for {module_class}...[/dim]")

try:
module_configs = generate_hf_build_config(
model_id=hf_model,
task=task,
module=module_class,
device=resolved_device,
precision=precision,
ep=ep,
)
except Exception as e:
console.print(f"[red]Error generating module configs: {e}[/red]")
Expand Down Expand Up @@ -590,12 +603,18 @@ def _perf_modules(
config=cfg,
output_dir=Path(tmpdir),
pytorch_model=submodule,
ep=ep,
device=resolved_device,
)

# Benchmark using WinMLSession
from ..session import WinMLSession

session = WinMLSession(str(build_result.final_onnx_path))
session = WinMLSession(
str(build_result.final_onnx_path),
device=resolved_device,
ep=ep,
)
io_cfg = session.io_config
inputs = generate_random_inputs(io_cfg, batch_size=batch_size)

Expand Down Expand Up @@ -1250,6 +1269,9 @@ def perf(
verbose=verbose,
console=console,
monitor=monitor,
device=device.lower(),
ep=ep.lower() if ep else None,
precision=precision.lower(),
)
return

Expand Down
83 changes: 50 additions & 33 deletions src/winml/modelkit/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def compile(self) -> None:
# Log which providers were selected by ORT (based on policy)
actual_providers = session.get_providers()
logger.info(
"Session created with policy %s, providers: %s",
"Session created with device %s, providers: %s",
target_device,
actual_providers,
)
Expand Down Expand Up @@ -432,27 +432,40 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:
Note: Returns a **fresh** SessionOptions when using explicit EP to
avoid "already registered" errors from repeated calls.
"""
# Explicit EP targeting: create fresh opts to avoid double-registration
# Don't filter by device type — trust the user's --ep choice
# (e.g., QNN reports as NPU in get_ep_devices but can target GPU)
# Explicit EP targeting: create fresh opts to avoid double-registration.
# When device is also specified (non-"auto"), narrow by both EP name
# and device type so e.g. `--ep qnn --device cpu` finds QNN-on-CPU
# instead of the first QNN ep_device (which may report as NPU).
if self._ep and self._ep != "cpu":
target_name = self._EP_NAME_MAP.get(self._ep)
if target_name:
matched = self._find_ep_device(target_name)
matched = self._find_ep_device(ep_name=target_name, device=device)
if matched:
from ..utils.constants import DEVICE_TYPE_TO_DEVICE

opts = ort.SessionOptions()
opts.add_provider_for_devices([matched], self._provider_options)
logger.info("Explicit EP: %s (%s)", self._ep, target_name)
resolved = DEVICE_TYPE_TO_DEVICE.get(
matched.device.type, str(matched.device.type)
)
logger.info(
"Explicit EP: %s (%s) device=%s -> %s",
self._ep,
target_name,
device,
resolved,
)
return opts
logger.warning(
"EP '%s' (%s) not found in available devices",
"EP '%s' (%s) not found for device '%s'",
self._ep,
target_name,
device,
)

# No explicit EP — discover available EP for this device type
if not self._ep and device.lower() != "cpu":
matched = self._find_ep_for_device(device)
matched = self._find_ep_device(device=device)
if matched:
opts = ort.SessionOptions()
opts.add_provider_for_devices([matched], self._provider_options)
Expand All @@ -465,48 +478,52 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:
device.lower(), ort.OrtExecutionProviderDevicePolicy.PREFER_NPU
)
opts.set_provider_selection_policy(policy)
logger.info("Using provider selection policy %s for device %s", policy, device)

return opts

@staticmethod
def _find_ep_device(ep_name: str) -> Any:
"""Find the first OrtEpDevice matching the given EP name.

Args:
ep_name: Full EP name (e.g., "DmlExecutionProvider").

Returns:
The matching OrtEpDevice, or None if not found.
"""
for ep_dev in ort.get_ep_devices():
if ep_dev.ep_name == ep_name:
return ep_dev
return None

@staticmethod
def _find_ep_for_device(device: str) -> Any:
"""Find the first available OrtEpDevice for the given device type.

Queries ``ort.get_ep_devices()`` and returns the first EP whose
hardware device type matches (e.g., device="gpu" matches GPU EPs).
def _find_ep_device(device: str, ep_name: str | None = None) -> Any:
"""Find the first OrtEpDevice matching the given filters.

Behavior:
- ``ep_name`` set, ``device == "auto"`` → first ep_device
matching ``ep_name`` (or None).
- ``ep_name`` unset, ``device == "auto"`` → ``None`` (no
effective filter — refuse to pick an arbitrary ep_device).
- ``ep_name`` unset, ``device`` is a concrete type → first
ep_device matching that device type (or None).
- Both set → ep_device must satisfy both (or None).

Note: Selection order is determined by the ORT EP registry, which is
not part of any documented contract. On systems where multiple EPs
match the same device type (e.g., QNN and DML both appear as GPU),
the result is registry-order dependent. When a specific EP is
required, use ``self._ep`` to bypass this discovery path entirely.
a device-only query returns the first one in registry order. Pass
``ep_name`` to disambiguate.

Args:
device: Device policy ("cpu", "gpu", "npu", "auto"). ``"auto"``
and unknown strings act as no-op device filters.
ep_name: Full EP name (e.g., "DmlExecutionProvider"), or None
to skip EP-name filtering.

Returns:
The matching OrtEpDevice, or None if not found.
"""
from ..utils.constants import DEVICE_TO_DEVICE_TYPE

device_type = DEVICE_TO_DEVICE_TYPE.get(device.upper())
if device_type is None:

# No effective filter — refuse to pick an arbitrary ep_device.
if not ep_name and device_type is None:
return None

for ep_dev in ort.get_ep_devices():
if ep_dev.device.type == device_type:
return ep_dev
if ep_name and ep_dev.ep_name != ep_name:
continue
if device_type is not None and ep_dev.device.type != device_type:
continue
return ep_dev
return None

def _validate_inputs(self, inputs: dict[str, Any]) -> None:
Expand Down
5 changes: 5 additions & 0 deletions src/winml/modelkit/sysinfo/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def resolve_device(device: str = "auto") -> tuple[str, list[str]]:
for dev in available_devices:
compatible_eps = _DEVICE_EP_MAP.get(dev, [])
if any(ep in available_eps for ep in compatible_eps):
logger.info(
Comment thread
xieofxie marked this conversation as resolved.
"Auto-selected device '%s' with compatible EPs: %s for auto device",
dev,
sorted(ep for ep in compatible_eps if ep in available_eps),
)
return dev, available_devices
# Fallback: CPU is always valid
return "cpu", available_devices
Expand Down
89 changes: 89 additions & 0 deletions tests/unit/commands/test_perf_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

from click.testing import CliRunner

from winml.modelkit.cli import main
from winml.modelkit.commands.perf import generate_output_path


if TYPE_CHECKING:
from pathlib import Path


class TestPerfModuleFlag:
"""Tests for --module flag on winml perf."""

Expand Down Expand Up @@ -44,3 +51,85 @@ def test_module_default_output_includes_class_name(self) -> None:
module_path = Path(f"{slug}_{module_class}_perf.json")
assert module_class in str(module_path)
assert str(module_path) != str(path)


class TestPerfModuleParameterForwarding:
"""Verify --device/--ep/--precision flow from CLI through _perf_modules
into generate_hf_build_config, build_hf_model, and WinMLSession.

Regression guard: these kwargs were silently dropped before.
"""

def test_device_and_ep_forwarded_through_module_path(self, tmp_path: Path) -> None:
# Fake module config -- only the attributes _perf_modules touches
fake_cfg = MagicMock()
fake_cfg.loader.model_type = "bert"
fake_cfg.loader.module_path = "encoder.layer.0"

fake_build_result = MagicMock()
fake_build_result.final_onnx_path = tmp_path / "model.onnx"

# Make WinMLSession.perf() raise so the benchmark loop is short-circuited
# via the existing try/except in _perf_modules. We still capture the
# constructor kwargs, which is what we care about.
fake_session = MagicMock()
fake_session.perf.side_effect = RuntimeError("test-skip-benchmark")

with (
patch(
"winml.modelkit.sysinfo.resolve_device",
return_value=("npu", "qnn"),
),
patch(
"winml.modelkit.config.generate_hf_build_config",
return_value=[fake_cfg],
) as mock_gen,
patch(
"winml.modelkit.commands.build._instantiate_parent_model",
return_value=MagicMock(),
),
patch(
"winml.modelkit.build.build_hf_model",
return_value=fake_build_result,
) as mock_build,
patch(
"winml.modelkit.session.WinMLSession",
return_value=fake_session,
) as mock_session_cls,
):
runner = CliRunner()
result = runner.invoke(
main,
[
"perf",
"-m",
"fake/model",
"--module",
"BertLayer",
"--device",
"npu",
"--ep",
"qnn",
"--iterations",
"1",
"--warmup",
"0",
"-o",
str(tmp_path / "out.json"),
],
)

assert result.exit_code == 0, result.output

gen_kwargs = mock_gen.call_args.kwargs
assert gen_kwargs["device"] == "npu"
assert gen_kwargs["ep"] == "qnn"
assert gen_kwargs["precision"] == "auto"

build_kwargs = mock_build.call_args.kwargs
assert build_kwargs["ep"] == "qnn"
assert build_kwargs["device"] == "npu"

session_kwargs = mock_session_cls.call_args.kwargs
assert session_kwargs["device"] == "npu"
assert session_kwargs["ep"] == "qnn"
Loading
Loading