Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def _run_onnx_benchmark(
"""
from ..session import WinMLSession

session = WinMLSession(onnx_path=onnx_path, device=device)
session = WinMLSession(onnx_path=onnx_path, device=device, ep=config.ep)

# Generate random inputs from session's I/O config
io_cfg = session.io_config
Expand Down
3 changes: 3 additions & 0 deletions src/winml/modelkit/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def from_onnx(
config=None,
device=device,
session_options=session_options,
ep=ep,
)

# Resolve output directory
Expand Down Expand Up @@ -228,6 +229,7 @@ def from_onnx(
config=None, # No HF PretrainedConfig for bare ONNX builds
device=device,
session_options=session_options,
ep=ep,
)

@classmethod
Expand Down Expand Up @@ -425,6 +427,7 @@ def from_pretrained(
onnx_path=onnx_path,
config=hf_config, # HF PretrainedConfig for pipeline compatibility
device=device, # pass user's original device string; WinMLSession handles "auto"
ep=resolved_ep,
)
model._build_config = config # resolved build config (task, quant, compile)
return model
Expand Down
3 changes: 3 additions & 0 deletions src/winml/modelkit/models/winml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
config: PretrainedConfig | None = None,
device: str = "auto",
session_options: Any | None = None,
ep: str | None = None,
) -> None:
"""Initialize inference model.

Expand All @@ -73,6 +74,7 @@ def __init__(
config: HuggingFace PretrainedConfig (num_labels, id2label, etc.)
device: Target device ("auto", "npu", "gpu", "cpu")
session_options: ORT SessionOptions (e.g., for graph_optimization_level)
ep: Explicit EP short name (e.g., "dml", "qnn"). Forwarded to WinMLSession.
"""
self._onnx_path = Path(onnx_path)
self.config = config
Expand All @@ -86,6 +88,7 @@ def __init__(
onnx_path=self._onnx_path,
device=device,
session_options=session_options,
ep=ep,
)

@property
Expand Down
74 changes: 60 additions & 14 deletions src/winml/modelkit/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,16 @@ def compile(self) -> None:
logger.warning("ModelCompiler failed, using original: %s", e)

try:
# Create InferenceSession
# Create InferenceSession.
# EP is either configured via add_provider_for_devices (WinML EP
# registry, e.g. QNN) or left to ORT's device policy (fallback).
# Never pass providers= — WinML-registered EPs don't support it.
sess_options = self._build_session_options(target_device)
with _suppress_native_output(compile_log):
session = ort.InferenceSession(str(model_path), sess_options=sess_options)
session = ort.InferenceSession(
str(model_path),
sess_options=sess_options,
)

# Log which providers were selected by ORT (based on policy)
actual_providers = session.get_providers()
Expand Down Expand Up @@ -415,34 +421,45 @@ def _is_verbose(self) -> bool:
def _build_session_options(self, device: str) -> ort.SessionOptions:
"""Build ORT SessionOptions from instance session_options and device.

When ``self._ep`` is set, uses ``add_provider_for_devices`` to
explicitly bind a specific EP (e.g., MIGraphX, NvTensorRTRTX). Otherwise
falls back to policy-based selection via DEVICE_POLICY_MAP.
When ``self._ep`` is set (and not ``"cpu"``), uses
``add_provider_for_devices`` to explicitly bind that EP.
``"cpu"`` falls through to policy-based selection so ORT handles
CPU-only inference without any EP registration.
When ``self._ep`` is not set, queries ``get_ep_devices()`` to
discover an available EP for the target device type. Falls back to
policy-based selection only as a last resort.

Note: Returns a **fresh** SessionOptions when using explicit EP to
avoid "already registered" errors from repeated calls.
"""
# Explicit EP targeting: create fresh opts to avoid double-registration
# Don't filter by device type — trust the user's --ep choice
# (e.g., QNN reports as NPU in get_ep_devices but can target GPU)
if self._ep and self._ep != "cpu":
Comment thread
DingmaomaoBJTU marked this conversation as resolved.
target_name = self._EP_NAME_MAP.get(self._ep)
if target_name:
matched = self._find_ep_device(target_name)
if matched:
opts = ort.SessionOptions()
opts.add_provider_for_devices([matched], self._provider_options)
logger.info(
"Explicit EP: %s (%s)",
self._ep,
target_name,
)
logger.info("Explicit EP: %s (%s)", self._ep, target_name)
return opts
logger.warning(
"EP '%s' (%s) not found in available devices; falling back to policy",
"EP '%s' (%s) not found in available devices",
self._ep,
target_name,
)

# Policy-based selection (default path)
# No explicit EP — discover available EP for this device type
if not self._ep and device.lower() != "cpu":
Comment thread
DingmaomaoBJTU marked this conversation as resolved.
matched = self._find_ep_for_device(device)
if matched:
opts = ort.SessionOptions()
opts.add_provider_for_devices([matched], self._provider_options)
logger.info("Discovered EP for %s: %s", device, matched.ep_name)
return opts

# Policy-based selection (last resort)
opts = self._session_options
policy = DEVICE_POLICY_MAP.get(
device.lower(), ort.OrtExecutionProviderDevicePolicy.PREFER_NPU
Expand All @@ -453,16 +470,45 @@ def _build_session_options(self, device: str) -> ort.SessionOptions:

@staticmethod
def _find_ep_device(ep_name: str) -> Any:
    """Look up an OrtEpDevice by its execution-provider name.

    Args:
        ep_name: Full EP name (e.g., "DmlExecutionProvider").

    Returns:
        The first entry yielded by ``ort.get_ep_devices()`` whose
        ``ep_name`` equals ``ep_name``, or None when nothing matches.
    """
    # Lazy generator: enumeration stops at the first device that matches.
    candidates = (dev for dev in ort.get_ep_devices() if dev.ep_name == ep_name)
    return next(candidates, None)

@staticmethod
def _find_ep_for_device(device: str) -> Any:
    """Pick an OrtEpDevice whose hardware type corresponds to ``device``.

    The device string is upper-cased and looked up in
    ``DEVICE_TO_DEVICE_TYPE``. The first entry yielded by
    ``ort.get_ep_devices()`` whose ``device.type`` equals the looked-up
    value is returned (e.g., device="gpu" matches GPU EPs).

    Note: When several EPs report the same device type (e.g., QNN and DML
    both appearing as GPU), the winner is simply whichever one
    ``ort.get_ep_devices()`` yields first. Callers that need one specific
    EP should request it explicitly by name instead of relying on this
    discovery path.

    Args:
        device: Device name such as "npu", "gpu" or "cpu".

    Returns:
        The first matching OrtEpDevice, or None when the device name has
        no entry in ``DEVICE_TO_DEVICE_TYPE`` or no enumerated EP device
        has the corresponding type.
    """
    from ..utils.constants import DEVICE_TO_DEVICE_TYPE

    wanted_type = DEVICE_TO_DEVICE_TYPE.get(device.upper())
    if wanted_type is not None:
        for candidate in ort.get_ep_devices():
            if candidate.device.type == wanted_type:
                return candidate
    # Unknown device name, or nothing enumerated for that hardware type.
    return None

def _validate_inputs(self, inputs: dict[str, Any]) -> None:
"""Validate inputs against model expectations.

Expand Down
10 changes: 5 additions & 5 deletions src/winml/modelkit/sysinfo/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
# AMD
"MIGraphXExecutionProvider": "gpu",
"VitisAIExecutionProvider": "npu",
# Qualcomm
"QNNExecutionProvider": "npu",
# Qualcomm (QNN supports both NPU and GPU via Adreno backend)
"QNNExecutionProvider": "npu/gpu",
# Microsoft
"DmlExecutionProvider": "gpu",
# Intel
Expand All @@ -51,11 +51,11 @@
"CPUExecutionProvider": "cpu",
}

# Derived inverse mapping (excludes multi-device EPs like OpenVINO)
# Derived inverse mapping (multi-device EPs are included in each device)
_DEVICE_EP_MAP: dict[str, list[str]] = {}
for _ep, _device in _EP_DEVICE_MAP.items():
if "/" not in _device:
_DEVICE_EP_MAP.setdefault(_device, []).append(_ep)
for _d in _device.split("/"):
_DEVICE_EP_MAP.setdefault(_d, []).append(_ep)

# Valid explicit device values
_VALID_DEVICES = frozenset({"npu", "gpu", "cpu"})
Expand Down
13 changes: 8 additions & 5 deletions tests/unit/sysinfo/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,13 @@ def test_ep_device_map_values_are_lowercase(self) -> None:
for ep, device in _EP_DEVICE_MAP.items():
assert device == device.lower(), f"{ep} maps to non-lowercase '{device}'"

def test_device_ep_map_excludes_openvino(self) -> None:
"""_DEVICE_EP_MAP should not contain OpenVINO entries."""
all_eps = [ep for eps in _DEVICE_EP_MAP.values() for ep in eps]
assert "OpenVINOExecutionProvider" not in all_eps
def test_device_ep_map_includes_multi_device_eps(self) -> None:
    """Multi-device EPs (QNN, OpenVINO) must be listed under each device."""
    # (EP name, devices it is expected to be registered under), checked in
    # the same order as a flat list of assertions would be.
    expected_membership = (
        ("QNNExecutionProvider", ("npu", "gpu")),
        ("OpenVINOExecutionProvider", ("npu", "gpu", "cpu")),
    )
    for ep_name, devices in expected_membership:
        for device in devices:
            assert ep_name in _DEVICE_EP_MAP[device]

def test_device_ep_map_derived_from_ep_device_map(self) -> None:
"""_DEVICE_EP_MAP should be consistent with _EP_DEVICE_MAP."""
Expand All @@ -148,7 +151,7 @@ def test_device_ep_map_derived_from_ep_device_map(self) -> None:
assert ep in _EP_DEVICE_MAP, (
f"EP '{ep}' in _DEVICE_EP_MAP but not in _EP_DEVICE_MAP"
)
assert _EP_DEVICE_MAP[ep] == device
assert device in _EP_DEVICE_MAP[ep].split("/")

def test_nv_tensorrt_rtx_is_gpu_ep(self) -> None:
"""NvTensorRTRTXExecutionProvider should map to gpu."""
Expand Down
Loading