diff --git a/src/winml/modelkit/__init__.py b/src/winml/modelkit/__init__.py index c10fcc9de..e92bf2368 100644 --- a/src/winml/modelkit/__init__.py +++ b/src/winml/modelkit/__init__.py @@ -28,16 +28,20 @@ model = WinMLAutoModel.from_pretrained("facebook/convnext-tiny-224", config=config) """ +from __future__ import annotations + +import importlib from importlib.metadata import PackageNotFoundError, version +from typing import TYPE_CHECKING -from . import _warnings # Configure warning filters before importing subpackages -from .config import WinMLBuildConfig -from .models import ( - WinMLAutoModel, - WinMLModelForImageClassification, - WinMLPreTrainedModel, -) +if TYPE_CHECKING: + from .config import WinMLBuildConfig + from .models import ( + WinMLAutoModel, + WinMLModelForImageClassification, + WinMLPreTrainedModel, + ) try: __version__ = version("winml-modelkit") @@ -51,3 +55,28 @@ "WinMLPreTrainedModel", "__version__", ] + +# Lazy imports — heavy ML dependencies (torch, transformers, optimum, +# diffusers) are only loaded when a symbol is actually accessed, so +# lightweight entry-points like ``winml sys`` stay fast. +_LAZY_IMPORT_MAP: dict[str, str] = { + "WinMLBuildConfig": ".config", + "WinMLAutoModel": ".models", + "WinMLModelForImageClassification": ".models", + "WinMLPreTrainedModel": ".models", +} + + +def __getattr__(name: str) -> object: + module_path = _LAZY_IMPORT_MAP.get(name) + if module_path is not None: + # Configure warning filters once before the first heavy import + if not globals().get("_warnings_loaded"): + globals()["_warnings_loaded"] = True + from . import _warnings + mod = importlib.import_module(module_path, __name__) + attr = getattr(mod, name) + # Cache on the module so __getattr__ is not called again + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/winml/modelkit/cli.py b/src/winml/modelkit/cli.py index 263494c87..434034090 100644 --- a/src/winml/modelkit/cli.py +++ b/src/winml/modelkit/cli.py @@ -31,7 +31,78 @@ logger = logging.getLogger(__name__) -@click.group() +class _LazyGroup(click.Group): + """Click group that discovers and imports commands lazily. + + Command modules under ``commands/`` are imported only when the user + actually invokes them (or asks for ``--help``). This avoids pulling + in heavy ML dependencies (torch, transformers, …) for lightweight + sub-commands like ``winml sys``. + """ + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self._commands_dir = Path(__file__).parent / "commands" + # Lazily-discovered module names (without .py, excluding _private) + self._lazy_names: list[str] | None = None + + # ------------------------------------------------------------------ + def _scan_names(self) -> list[str]: + if self._lazy_names is None: + if self._commands_dir.exists(): + self._lazy_names = [ + f.stem for f in self._commands_dir.glob("*.py") if not f.name.startswith("_") + ] + else: + self._lazy_names = [] + return self._lazy_names + + # ------------------------------------------------------------------ + def list_commands(self, ctx: click.Context) -> list[str]: + # Merge eagerly-registered commands (if any) with lazy ones + eager = set(super().list_commands(ctx)) + lazy = set(self._scan_names()) + return sorted(eager | lazy) + + # ------------------------------------------------------------------ + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: + # Already registered? + cmd = super().get_command(ctx, cmd_name) + if cmd is not None: + return cmd + + # Lazy-load from commands/.py + if cmd_name not in self._scan_names(): + return None + + try: + module = import_module( + f".commands.{cmd_name}", + package=__package__, + ) + except ImportError as exc: + logger.warning("Failed to import command module %s: %s", cmd_name, exc) + return None + except Exception as exc: + logger.error("Error loading command %s: %s", cmd_name, exc) + return None + + # Find the Click command in the module + discovered: click.Command | None = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, click.Group): + discovered = attr + break + if isinstance(attr, click.Command) and discovered is None: + discovered = attr + + if discovered is not None: + self.add_command(discovered, name=cmd_name) + return discovered + + +@click.group(cls=_LazyGroup) @click.version_option(version=__version__, prog_name="winml") @click.option( "--debug", @@ -57,64 +128,5 @@ def main(ctx: click.Context, debug: bool) -> None: ctx.obj["debug"] = debug -def _discover_commands() -> None: - """Auto-discover Click commands from commands/ directory. - - This function scans the commands/ directory for Python modules and - registers any Click commands found. Commands are registered using - the module filename as the command name. - - Command Discovery Rules: - - Skips files starting with underscore (_) - - Looks for any object that is a click.Command instance - - Uses module filename (without .py) as command name - """ - commands_dir = Path(__file__).parent / "commands" - - # Early exit if commands directory doesn't exist - if not commands_dir.exists(): - logger.debug("Commands directory not found: %s", commands_dir) - return - - # Scan for Python modules - for py_file in commands_dir.glob("*.py"): - # Skip private modules - if py_file.name.startswith("_"): - continue - - module_name = py_file.stem - try: - # Import the module - module = import_module( - f".commands.{module_name}", - package=__package__, - ) - - # Find Click command in module - # Prefer click.Group over click.Command for hierarchical commands - discovered_command = None - for attr_name in dir(module): - attr = getattr(module, attr_name) - if isinstance(attr, click.Group): - discovered_command = attr - break - if isinstance(attr, click.Command) and discovered_command is None: - discovered_command = attr - - if discovered_command: - # Register command with module name - main.add_command(discovered_command, name=module_name) - logger.debug("Discovered command: %s", module_name) - - except ImportError as e: - logger.warning("Failed to import command module %s: %s", module_name, e) - except Exception as e: - logger.error("Error loading command %s: %s", module_name, e) - - -# Discover and register commands at module load time -_discover_commands() - - if __name__ == "__main__": main() diff --git a/src/winml/modelkit/commands/sys.py b/src/winml/modelkit/commands/sys.py index 86cb9c9dd..07f5712cf 100644 --- a/src/winml/modelkit/commands/sys.py +++ b/src/winml/modelkit/commands/sys.py @@ -26,7 +26,6 @@ import json import logging import platform -import sys from typing import Any import click @@ -35,7 +34,7 @@ from rich.panel import Panel from rich.table import Table -from ..sysinfo import OS, get_ep_device_map +from ..sysinfo import get_ep_device_map logger = logging.getLogger(__name__) @@ -44,30 +43,39 @@ def _get_python_info() -> dict[str, Any]: """Gather Python environment information.""" + import importlib + + # Explicitly load the standard-library 'sys' module to avoid importing this 'sys' module. + stdlib_sys = importlib.import_module("sys") return { "version": platform.python_version(), - "executable": sys.executable, + "executable": stdlib_sys.executable, "implementation": platform.python_implementation(), } +def _is_windows_11() -> bool: + """Detect Windows 11 from the build number in platform.version(). + + platform.version() returns e.g. '10.0.26200'. Build >= 22000 is Win 11. + This avoids the expensive PowerShell/WMI call that OS.get() requires. + """ + try: + build = int(platform.version().split(".")[2]) + return build >= 22000 + except (IndexError, ValueError): + return False + + def _get_platform_info() -> dict[str, Any]: """Gather OS and platform information.""" system = platform.system() release = platform.release() - # For Windows, use OS class for accurate Windows 11 detection - # platform.release() may incorrectly report '10' on some Python versions - if system == "Windows": - try: - os_info = OS.get() - # Only override if it's actually Windows 11 - # Otherwise keep the original platform.release() value - if os_info.is_windows_11(): - release = "11" - except Exception: - # Fallback to platform.release() if OS detection fails - pass + # platform.release() may incorrectly report '10' on Windows 11. + # Detect via build number (pure Python, no PowerShell) instead. + if system == "Windows" and _is_windows_11(): + release = "11" return { "system": system, @@ -366,40 +374,66 @@ def _output_compact(info: dict[str, Any]) -> None: def _gather_device_info() -> list[dict[str, Any]]: """Gather available device information in priority order. + Runs a single PowerShell process to query all CIM classes (CPU, GPU) + and PnP devices (NPU) with their extra properties at once. + Returns: List of device dicts with type, priority, and details. """ from ..sysinfo import CPU, GPU, NPU + from ..sysinfo.helper import query_all_hardware + hw = query_all_hardware( + cim_class_names=["Win32_Processor", "Win32_VideoController"], + pnp_class_name="ComputeAccelerator", + pnp_extra_keys=NPU._EXTRA_PROPERTY_KEYS, + ) + + cim_map = hw["cim"] + + cpu_items: list[CPU] = [] + try: + cpu_items = [CPU(ci) for ci in cim_map.get("Win32_Processor", [])] + except Exception as e: + logger.warning("Failed to parse CPU info: %s", e) + + gpu_items: list[GPU] = [] + try: + gpu_items = [ + GPU(ci) + for ci in cim_map.get("Win32_VideoController", []) + if ci.try_get_property("PNPDeviceID", str, "").startswith(("PCI\\", "ACPI\\")) + ] + except Exception as e: + logger.warning("Failed to parse GPU info: %s", e) + + npu_items: list[NPU] = [] + try: + npu_items = [NPU(pnp) for pnp in hw["pnp"]] + except Exception as e: + logger.warning("Failed to get NPU details: %s", e) + + # --- build result list in NPU > GPU > CPU priority order --- result: list[dict[str, Any]] = [] priority = 1 - # Query hardware directly in NPU > GPU > CPU priority order. - # This avoids depending on _get_available_devices() and eliminates - # redundant PowerShell queries (we need the details anyway). - hw_queries: list[tuple[str, type]] = [ - ("NPU", NPU), - ("GPU", GPU), - ("CPU", CPU), + hw_groups: list[tuple[str, list]] = [ + ("NPU", npu_items), + ("GPU", gpu_items), + ("CPU", cpu_items), ] - for device_label, hw_class in hw_queries: - try: - items = hw_class.get_all() - except Exception as e: - logger.warning("Failed to get %s details: %s", device_label, e) - # Only append an error entry if this was expected to have results - # CPU always exists, NPU/GPU may not - if device_label == "CPU": - result.append( - { - "priority": priority, - "type": device_label, - "name": "(detection error)", - "details": {"error": str(e)}, - } - ) - priority += 1 + for device_label, items in hw_groups: + if not items and device_label == "CPU": + result.append( + { + "priority": priority, + "type": device_label, + "name": "(detection error)", + "details": {"error": "No CPU detected"}, + } + ) + priority += 1 continue for item in items: @@ -607,91 +641,112 @@ def sysinfo( if use_json: # Combine both into a single JSON object so output is always valid JSON result: dict[str, Any] = {} - if list_device: - try: - result["devices"] = _gather_device_info() - except Exception as e: - logger.exception("Failed to detect devices") - raise click.ClickException(f"Error detecting devices: {e}") from e - if list_ep: - try: - result["executionProviders"] = _gather_ep_info() - except Exception as e: - logger.exception("Failed to detect execution providers") - msg = f"Error detecting execution providers: {e}" - raise click.ClickException(msg) from e + with console.status("[bold blue]Detecting devices...[/bold blue]"): + if list_device: + try: + result["devices"] = _gather_device_info() + except Exception as e: + logger.exception("Failed to detect devices") + raise click.ClickException(f"Error detecting devices: {e}") from e + if list_ep: + try: + result["executionProviders"] = _gather_ep_info() + except Exception as e: + logger.exception("Failed to detect execution providers") + msg = f"Error detecting execution providers: {e}" + raise click.ClickException(msg) from e click.echo(json.dumps(result, indent=2)) elif output_format.lower() == "compact": + parts: list[str] = [] + ep_parts: list[str] = [] + with console.status("[bold blue]Detecting devices...[/bold blue]"): + if list_device: + try: + devices = _gather_device_info() + parts = [f"{d['type']}: {d['name'].strip()}" for d in devices] + except Exception as e: + logger.exception("Failed to detect devices") + raise click.ClickException(f"Error detecting devices: {e}") from e + if list_ep: + try: + eps = _gather_ep_info() + ep_parts = [f"{ep['name']}({ep['device']})" for ep in eps] + except Exception as e: + logger.exception("Failed to detect execution providers") + msg = f"Error detecting execution providers: {e}" + raise click.ClickException(msg) from e if list_device: - try: - devices = _gather_device_info() - parts = [f"{d['type']}: {d['name'].strip()}" for d in devices] - click.echo(" | ".join(parts) if parts else "No devices found") - except Exception as e: - logger.exception("Failed to detect devices") - raise click.ClickException(f"Error detecting devices: {e}") from e + click.echo(" | ".join(parts) if parts else "No devices found") if list_ep: - try: - eps = _gather_ep_info() - parts = [f"{ep['name']}({ep['device']})" for ep in eps] - click.echo("EPs: " + ", ".join(parts) if parts else "EPs: none") - except Exception as e: - logger.exception("Failed to detect execution providers") - msg = f"Error detecting execution providers: {e}" - raise click.ClickException(msg) from e + click.echo(("EPs: " + ", ".join(ep_parts)) if ep_parts else "EPs: none") else: - if list_device: - try: - devices = _gather_device_info() - _output_device_text(devices) - except Exception as e: - console.print(f"[bold red]Error detecting devices:[/bold red] {e}") - logger.exception("Failed to detect devices") - raise click.ClickException(f"Error detecting devices: {e}") from e - if list_ep: - try: - eps = _gather_ep_info() - _output_ep_text(eps) - except Exception as e: - err_msg = f"[bold red]Error detecting execution providers:[/bold red] {e}" - console.print(err_msg) - logger.exception("Failed to detect execution providers") - msg = f"Error detecting execution providers: {e}" - raise click.ClickException(msg) from e + devices = None + eps = None + with console.status("[bold blue]Detecting devices...[/bold blue]"): + if list_device: + try: + devices = _gather_device_info() + except Exception as e: + console.print(f"[bold red]Error detecting devices:[/bold red] {e}") + logger.exception("Failed to detect devices") + raise click.ClickException(f"Error detecting devices: {e}") from e + if list_ep: + try: + eps = _gather_ep_info() + except Exception as e: + err_msg = ( + f"[bold red]Error detecting execution providers:[/bold red] {e}" + ) + console.print(err_msg) + logger.exception("Failed to detect execution providers") + msg = f"Error detecting execution providers: {e}" + raise click.ClickException(msg) from e + if devices is not None: + _output_device_text(devices) + if eps is not None: + _output_ep_text(eps) return - # Default: full sysinfo including devices and EPs + # Default: full sysinfo including devices and EPs. + # Kick off hardware detection (PowerShell) in a background thread + # so it overlaps with the Python-side work (torch import, library + # version scanning, etc.). try: - info = _gather_system_info(verbose=verbose) + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=2) as pool: + hw_future = pool.submit(_gather_device_info) + ep_future = pool.submit(_gather_ep_info) + + # Python-side work runs on the main thread while PS runs + info = _gather_system_info(verbose=verbose) + + # Collect results + try: + devices = hw_future.result() + except Exception: + devices = [] + logger.debug("Device detection failed in default output") + try: + eps = ep_future.result() + except Exception: + eps = [] + logger.debug("EP detection failed in default output") if use_json: - # Add devices and EPs to JSON output - try: - info["devices"] = _gather_device_info() - except Exception: - info["devices"] = [] - try: - info["executionProviders"] = _gather_ep_info() - except Exception: - info["executionProviders"] = [] + info["devices"] = devices + info["executionProviders"] = eps _output_json(info) elif output_format.lower() == "compact": _output_compact(info) else: _output_text(info, verbose=verbose) - # Append devices and EPs to text output console.print() - try: - devices = _gather_device_info() + if devices: _output_device_text(devices) - except Exception: - logger.debug("Device detection failed in default output") console.print() - try: - eps = _gather_ep_info() + if eps: _output_ep_text(eps) - except Exception: - logger.debug("EP detection failed in default output") except Exception as e: console.print(f"[bold red]Error gathering system information:[/bold red] {e}") diff --git a/src/winml/modelkit/onnx/detection.py b/src/winml/modelkit/onnx/detection.py index 007ef9267..e0aa5250b 100644 --- a/src/winml/modelkit/onnx/detection.py +++ b/src/winml/modelkit/onnx/detection.py @@ -14,7 +14,6 @@ import logging from typing import TYPE_CHECKING -from ..compiler.utils import QDQ_OP_TYPES from .persistence import load_onnx @@ -42,6 +41,9 @@ def _load_model_lightweight(model_path: Path, operation: str) -> onnx.ModelProto def is_quantized_onnx(model_path: Path) -> bool: """Check if ONNX model is quantized (contains QuantizeLinear/DequantizeLinear nodes).""" + # Import lazily to break onnx ↔ compiler circular import at module level. + from ..compiler.utils import QDQ_OP_TYPES + model = _load_model_lightweight(model_path, "quantization check") return any(n.op_type in QDQ_OP_TYPES for n in model.graph.node) diff --git a/src/winml/modelkit/session/__init__.py b/src/winml/modelkit/session/__init__.py index 5148da0b3..099984907 100644 --- a/src/winml/modelkit/session/__init__.py +++ b/src/winml/modelkit/session/__init__.py @@ -4,15 +4,22 @@ # -------------------------------------------------------------------------- """WinMLSession - ONNX Runtime session manager with WinML EP integration.""" -from .ep_registry import WinMLEPRegistry -from .monitor.ep_monitor import EPMonitor, NullEPMonitor -from .monitor.hw_monitor import HWMonitor -from .monitor.openvino_monitor import OpenVinoMonitor -from .monitor.qnn_monitor import QNNMonitor -from .monitor.vitisai_monitor import VitisAIMonitor -from .qairt.qairt_session import WinMLQairtSession -from .session import InferenceError, SessionState, WinMLSession -from .stats import PerfStats +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from .ep_registry import WinMLEPRegistry + from .monitor.ep_monitor import EPMonitor, NullEPMonitor + from .monitor.hw_monitor import HWMonitor + from .monitor.openvino_monitor import OpenVinoMonitor + from .monitor.qnn_monitor import QNNMonitor + from .monitor.vitisai_monitor import VitisAIMonitor + from .qairt.qairt_session import WinMLQairtSession + from .session import InferenceError, SessionState, WinMLSession + from .stats import PerfStats __all__ = [ @@ -29,3 +36,29 @@ "WinMLQairtSession", "WinMLSession", ] + +_LAZY_IMPORT_MAP: dict[str, tuple[str, str]] = { + "WinMLEPRegistry": (".ep_registry", "WinMLEPRegistry"), + "EPMonitor": (".monitor.ep_monitor", "EPMonitor"), + "NullEPMonitor": (".monitor.ep_monitor", "NullEPMonitor"), + "HWMonitor": (".monitor.hw_monitor", "HWMonitor"), + "OpenVinoMonitor": (".monitor.openvino_monitor", "OpenVinoMonitor"), + "QNNMonitor": (".monitor.qnn_monitor", "QNNMonitor"), + "VitisAIMonitor": (".monitor.vitisai_monitor", "VitisAIMonitor"), + "WinMLQairtSession": (".qairt.qairt_session", "WinMLQairtSession"), + "InferenceError": (".session", "InferenceError"), + "SessionState": (".session", "SessionState"), + "WinMLSession": (".session", "WinMLSession"), + "PerfStats": (".stats", "PerfStats"), +} + + +def __getattr__(name: str) -> object: + entry = _LAZY_IMPORT_MAP.get(name) + if entry is not None: + module_path, attr_name = entry + mod = importlib.import_module(module_path, __name__) + attr = getattr(mod, attr_name) + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/winml/modelkit/sysinfo/hardware.py b/src/winml/modelkit/sysinfo/hardware.py index 9bcdb44e8..ca46e7df4 100644 --- a/src/winml/modelkit/sysinfo/hardware.py +++ b/src/winml/modelkit/sysinfo/hardware.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import re from enum import Enum +from typing import ClassVar from .helper import CimInstance, PnpDevice @@ -179,10 +180,16 @@ def to_dict(self) -> dict: class NPU: """Represents NPU (Neural Processing Unit) information.""" + # Extra PnP properties that NPU needs from Get-PnpDeviceProperty. + _EXTRA_PROPERTY_KEYS: ClassVar[list[str]] = ["DEVPKEY_Device_DriverVersion"] + @staticmethod def get_all() -> list["NPU"]: """Get all NPUs in the system.""" - pnp_devices = PnpDevice.get_by_class_name("ComputeAccelerator") + pnp_devices = PnpDevice.get_by_class_name( + "ComputeAccelerator", + extra_property_keys=NPU._EXTRA_PROPERTY_KEYS, + ) return [NPU(pnp_device) for pnp_device in pnp_devices] def __init__(self, pnp_device: PnpDevice) -> None: diff --git a/src/winml/modelkit/sysinfo/helper.py b/src/winml/modelkit/sysinfo/helper.py index b2ae19a49..5fd23b340 100644 --- a/src/winml/modelkit/sysinfo/helper.py +++ b/src/winml/modelkit/sysinfo/helper.py @@ -2,9 +2,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import json import subprocess -from typing import TypeVar +from typing import Any, TypeVar _T = TypeVar("_T") @@ -35,31 +37,64 @@ class CimInstance: """Represents a WMI CIM instance retrieved via PowerShell.""" @staticmethod - def get_by_class_name(class_name: str) -> list["CimInstance"]: + def get_by_class_name(class_name: str) -> list[CimInstance]: """Get all CIM instances of the specified class name.""" - output = None + results = CimInstance.get_many_by_class_name([class_name]) + return results.get(class_name, []) + + @staticmethod + def get_many_by_class_name( + class_names: list[str], + ) -> dict[str, list[CimInstance]]: + """Get CIM instances for multiple class names in a single PowerShell call. + + Returns: + Mapping of class_name -> list of CimInstance. + """ + if not class_names: + return {} + + # Build a single PowerShell script that queries all classes and wraps + # each result set with its class name for demuxing. + parts: list[str] = ["[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "] + parts.append("$result = @{}; ") + parts.extend( + f"$result['{name}'] = @(Get-CimInstance -ClassName {name} " + f"-ErrorAction SilentlyContinue); " + for name in class_names + ) + parts.append("$result | ConvertTo-Json -Depth 99") + try: - output = subprocess.check_output( # noqa: S603 - Input is trusted (class_name from code) + output = subprocess.check_output( # noqa: S603 - Input is trusted (class_names from code) [ # noqa: S607 - PowerShell path is standard on Windows "powershell", "-NoProfile", "-Command", - "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; " - + f"Get-CimInstance -ClassName {class_name} | " - + "ConvertTo-Json -Depth 99", + "".join(parts), ], stderr=subprocess.DEVNULL, ) except subprocess.CalledProcessError: - # This will throw if no matching device is found - return [] - json_array = json.loads(output.decode("utf-8")) - if isinstance(json_array, dict): - # Powershell will return a single object as a dict. - json_array = [json_array] - if not isinstance(json_array, list): - raise TypeError(f"Expected a list from Get-CimInstance, got {type(json_array)}") - return [CimInstance(json_obj) for json_obj in json_array] + return {name: [] for name in class_names} + + raw = output.decode("utf-8").strip() + if not raw: + return {name: [] for name in class_names} + + parsed = json.loads(raw) + if not isinstance(parsed, dict): + return {name: [] for name in class_names} + + result: dict[str, list[CimInstance]] = {} + for name in class_names: + items = parsed.get(name, []) + if isinstance(items, dict): + items = [items] + if not isinstance(items, list): + items = [] + result[name] = [CimInstance(obj) for obj in items if isinstance(obj, dict)] + return result def __init__(self, json_obj: dict) -> None: """Initialize a CIM instance from a JSON object.""" @@ -85,8 +120,21 @@ class PnpDevice: """Represents a Plug and Play device retrieved via PowerShell.""" @staticmethod - def get_by_class_name(class_name: str) -> list["PnpDevice"]: - """Get all PnP devices of the specified class name.""" + def get_by_class_name( + class_name: str, + extra_property_keys: list[str] | None = None, + ) -> list[PnpDevice]: + """Get all PnP devices of the specified class name. + + Args: + class_name: PnP device class name (e.g. "ComputeAccelerator"). + extra_property_keys: Optional list of DEVPKEY names to fetch via + Get-PnpDeviceProperty. When provided, a **single** batched + PowerShell call retrieves only the requested keys for every + device, instead of spawning one process per device for all + properties. Pass ``None`` (default) to skip extra properties + entirely — callers that don't need them avoid the cost. + """ output = None try: output = subprocess.check_output( # noqa: S603 - Input is trusted (class_name from code) @@ -108,39 +156,97 @@ def get_by_class_name(class_name: str) -> list["PnpDevice"]: json_array = [json_array] if not isinstance(json_array, list): raise TypeError(f"Expected a list from Get-PnpDevice, got {type(json_array)}") - return [PnpDevice(json_obj) for json_obj in json_array] - def __init__(self, json_obj: dict) -> None: - """Initialize a PnP device from a JSON object.""" - self._pnp_id = _get_property_from_json(json_obj, "PNPDeviceID", str) - self._pnp_device_obj = json_obj - output = b"[]" + # Batch-fetch extra properties for all devices in one PowerShell call. + all_extra: dict[str, dict[str, object]] = {} # pnp_id -> {key: value} + if extra_property_keys and json_array: + pnp_ids = [_get_property_from_json(obj, "PNPDeviceID", str) for obj in json_array] + all_extra = PnpDevice._batch_get_extra_properties(pnp_ids, extra_property_keys) + + return [ + PnpDevice( + json_obj, all_extra.get(_get_property_from_json(json_obj, "PNPDeviceID", str), {}) + ) + for json_obj in json_array + ] + + @staticmethod + def _batch_get_extra_properties( + pnp_ids: list[str], + property_keys: list[str], + ) -> dict[str, dict[str, object]]: + """Fetch specific extra properties for multiple devices in one PowerShell call. + + Returns: + Mapping of pnp_id -> {property_key: value}. + """ + # Build a PowerShell script that queries all devices and returns structured JSON. + # Using -KeyName filters to only the properties we need (much faster than all). + keys_array = ", ".join(f"'{k}'" for k in property_keys) + ids_array = ", ".join(f"'{pid}'" for pid in pnp_ids) + ps_script = ( + "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; " + f"$ids = @({ids_array}); " + f"$keys = @({keys_array}); " + "$result = @(); " + "foreach ($id in $ids) { " + " try { " + " $props = Get-PnpDeviceProperty -InstanceId $id -KeyName $keys " + " -ErrorAction SilentlyContinue; " + " foreach ($p in $props) { " + " $result += @{ InstanceId = $id; KeyName = $p.KeyName; Data = $p.Data } " + " } " + " } catch { } " + "} " + "$result | ConvertTo-Json -Depth 99" + ) + try: - output = subprocess.check_output( # noqa: S603 - Input is trusted (pnp_id from WMI) - [ # noqa: S607 - PowerShell path is standard on Windows + output = subprocess.check_output( # noqa: S603 - Input is trusted + [ # noqa: S607 "powershell", "-NoProfile", "-Command", - "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; " - + f"Get-PnpDeviceProperty -InstanceId '{self._pnp_id}' | " - + "ConvertTo-Json -Depth 99", + ps_script, ], stderr=subprocess.DEVNULL, ) except subprocess.CalledProcessError: - # This may happen if the device has no extra properties - pass + return {} - property_list = json.loads(output.decode("utf-8")) - if not isinstance(property_list, list): - raise TypeError( - f"Expected a list from Get-PnpDeviceProperty, got {type(property_list)}" - ) - self._extra_properties: dict[str, object] = {} - for prop in property_list: - key = _get_property_from_json(prop, "KeyName", str) - value = _get_property_from_json(prop, "Data", object) - self._extra_properties[key] = value + raw = output.decode("utf-8").strip() + if not raw: + return {} + + parsed = json.loads(raw) + if isinstance(parsed, dict): + parsed = [parsed] + if not isinstance(parsed, list): + return {} + + result: dict[str, dict[str, object]] = {} + for item in parsed: + inst_id = _get_property_from_json(item, "InstanceId", str) + key = _get_property_from_json(item, "KeyName", str) + value = _get_property_from_json(item, "Data", object) + result.setdefault(inst_id, {})[key] = value + return result + + def __init__( + self, + json_obj: dict, + extra_properties: dict[str, object] | None = None, + ) -> None: + """Initialize a PnP device from a JSON object. + + Args: + json_obj: Device data from Get-PnpDevice. + extra_properties: Pre-fetched extra properties (key -> value). + When ``None``, no extra properties are available. + """ + self._pnp_id = _get_property_from_json(json_obj, "PNPDeviceID", str) + self._pnp_device_obj = json_obj + self._extra_properties: dict[str, object] = extra_properties or {} def get_property(self, property_name: str, property_type: type[_T]) -> _T: """Get a property value from the PnP device.""" @@ -175,7 +281,7 @@ class AppxPackage: """Represents an AppX package retrieved via PowerShell.""" @staticmethod - def get_by_hint(hint: str) -> list["AppxPackage"]: + def get_by_hint(hint: str) -> list[AppxPackage]: """Get AppX packages matching the given hint string.""" output = None try: @@ -218,3 +324,113 @@ def try_get_property(self, property_name: str, property_type: type[_T], default: property_type=property_type, default=default, ) + + +def query_all_hardware( + cim_class_names: list[str], + pnp_class_name: str | None = None, + pnp_extra_keys: list[str] | None = None, +) -> dict[str, Any]: + """Query CIM instances and PnP devices in a single PowerShell process. + + This avoids paying the PowerShell cold-start cost multiple times. + + Args: + cim_class_names: WMI class names to query (e.g. Win32_Processor). + pnp_class_name: Optional PnP device class (e.g. ComputeAccelerator). + pnp_extra_keys: DEVPKEY names to fetch for each PnP device. + + Returns: + ``{"cim": {class_name: [CimInstance, ...]}, + "pnp": [PnpDevice, ...]}`` + """ + # Build a single PowerShell script + parts: list[str] = ["[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; "] + parts.append("$r = @{}; ") + + # CIM queries + parts.extend( + f"$r['{name}'] = @(Get-CimInstance -ClassName {name} -ErrorAction SilentlyContinue); " + for name in cim_class_names + ) + + # PnP device + extra properties + if pnp_class_name: + parts.append( + f"$pnp = @(Get-PnpDevice -Class {pnp_class_name} -ErrorAction SilentlyContinue); " + ) + parts.append("$r['_pnp'] = $pnp; ") + + if pnp_extra_keys: + keys_str = ", ".join(f"'{k}'" for k in pnp_extra_keys) + parts.append( + "$pnpProps = @(); " + "foreach ($d in $pnp) { " + " try { " + f" $props = Get-PnpDeviceProperty -InstanceId $d.InstanceId " + f"-KeyName {keys_str} -ErrorAction SilentlyContinue; " + " foreach ($p in $props) { " + " $pnpProps += @{ InstanceId = $d.InstanceId; " + "KeyName = $p.KeyName; Data = $p.Data } " + " } " + " } catch { } " + "} " + "$r['_pnpProps'] = $pnpProps; " + ) + + parts.append("$r | ConvertTo-Json -Depth 99") + + try: + output = subprocess.check_output( # noqa: S603 + ["powershell", "-NoProfile", "-Command", "".join(parts)], # noqa: S607 + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError: + return {"cim": {n: [] for n in cim_class_names}, "pnp": []} + + raw = output.decode("utf-8").strip() + if not raw: + return {"cim": {n: [] for n in cim_class_names}, "pnp": []} + + parsed = json.loads(raw) + if not isinstance(parsed, dict): + return {"cim": {n: [] for n in cim_class_names}, "pnp": []} + + # Demux CIM results + cim_result: dict[str, list[CimInstance]] = {} + for name in cim_class_names: + items = parsed.get(name, []) + if isinstance(items, dict): + items = [items] + if not isinstance(items, list): + items = [] + cim_result[name] = [CimInstance(obj) for obj in items if isinstance(obj, dict)] + + # Demux PnP results + pnp_result: list[PnpDevice] = [] + if pnp_class_name: + # Build extra-properties lookup + extra_map: dict[str, dict[str, object]] = {} + for item in parsed.get("_pnpProps", []) or []: + if not isinstance(item, dict): + continue + inst_id = item.get("InstanceId", "") + key = item.get("KeyName", "") + if inst_id and key: + extra_map.setdefault(inst_id, {})[key] = item.get("Data") + + pnp_raw = parsed.get("_pnp", []) + if isinstance(pnp_raw, dict): + pnp_raw = [pnp_raw] + for obj in pnp_raw or []: + if not isinstance(obj, dict): + continue + # Get-PnpDevice returns "InstanceId" (matches the PS property + # name), while Win32_PnPEntity / CIM returns "PNPDeviceID". + # Both contain the same device path string, so we can use + # PNPDeviceID from the JSON to look up batched properties + # keyed by InstanceId. + pnp_id = obj.get("PNPDeviceID", "") or obj.get("InstanceId", "") + pnp_result.append(PnpDevice(obj, extra_map.get(pnp_id, {}))) + + return {"cim": cim_result, "pnp": pnp_result} diff --git a/tests/unit/sysinfo/test_sysinfo.py b/tests/unit/sysinfo/test_sysinfo.py index 2b8da839c..6e22410d2 100644 --- a/tests/unit/sysinfo/test_sysinfo.py +++ b/tests/unit/sysinfo/test_sysinfo.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- """Unit tests for sysinfo command module. -Tests the _get_platform_info function with Windows version detection. +Tests the _get_platform_info and _is_windows_11 functions. """ from __future__ import annotations @@ -12,120 +12,121 @@ from unittest.mock import MagicMock, patch +class TestIsWindows11: + """Test _is_windows_11 helper function.""" + + @patch("winml.modelkit.commands.sys.platform") + def test_build_26200_is_windows_11(self, mock_platform: MagicMock) -> None: + from winml.modelkit.commands.sys import _is_windows_11 + + mock_platform.version.return_value = "10.0.26200" + assert _is_windows_11() is True + + @patch("winml.modelkit.commands.sys.platform") + def test_build_22000_is_windows_11(self, mock_platform: MagicMock) -> None: + from winml.modelkit.commands.sys import _is_windows_11 + + mock_platform.version.return_value = "10.0.22000" + assert _is_windows_11() is True + + @patch("winml.modelkit.commands.sys.platform") + def test_build_21999_is_not_windows_11(self, mock_platform: MagicMock) -> None: + from winml.modelkit.commands.sys import _is_windows_11 + + mock_platform.version.return_value = "10.0.21999" + assert _is_windows_11() is False + + @patch("winml.modelkit.commands.sys.platform") + def test_build_19045_is_not_windows_11(self, mock_platform: MagicMock) -> None: + from winml.modelkit.commands.sys import _is_windows_11 + + mock_platform.version.return_value = "10.0.19045" + assert _is_windows_11() is False + + @patch("winml.modelkit.commands.sys.platform") + def test_malformed_version_returns_false(self, mock_platform: MagicMock) -> None: + from winml.modelkit.commands.sys import _is_windows_11 + + mock_platform.version.return_value = "10.0" + assert _is_windows_11() is False + + @patch("winml.modelkit.commands.sys.platform") + def test_non_numeric_build_returns_false(self, mock_platform: MagicMock) -> None: + from winml.modelkit.commands.sys import _is_windows_11 + + mock_platform.version.return_value = "10.0.abc" + assert _is_windows_11() is False + + class TestGetPlatformInfo: """Test _get_platform_info function.""" - @patch("winml.modelkit.commands.sys.OS") @patch("winml.modelkit.commands.sys.platform") - def test_windows_11_detection(self, mock_platform: MagicMock, mock_os_class: MagicMock) -> None: - """Test Windows 11 is correctly detected.""" + def test_windows_11_detection(self, mock_platform: MagicMock) -> None: + """Test Windows 11 is correctly detected via build number.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks mock_platform.system.return_value = "Windows" mock_platform.release.return_value = "10" # Platform reports wrong version + mock_platform.version.return_value = "10.0.26200" mock_platform.machine.return_value = "AMD64" mock_platform.processor.return_value = "Intel64 Family 6" - mock_os_instance = MagicMock() - mock_os_instance.is_windows_11.return_value = True - mock_os_class.get.return_value = mock_os_instance - result = _get_platform_info() assert result["system"] == "Windows" assert result["release"] == "11" # Should be corrected to 11 assert result["machine"] == "AMD64" - mock_os_class.get.assert_called_once() - @patch("winml.modelkit.commands.sys.OS") @patch("winml.modelkit.commands.sys.platform") - def test_windows_10_detection(self, mock_platform: MagicMock, mock_os_class: MagicMock) -> None: + def test_windows_10_detection(self, mock_platform: MagicMock) -> None: """Test Windows 10 is correctly detected.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks mock_platform.system.return_value = "Windows" mock_platform.release.return_value = "10" + mock_platform.version.return_value = "10.0.19045" mock_platform.machine.return_value = "AMD64" mock_platform.processor.return_value = "Intel64 Family 6" - mock_os_instance = MagicMock() - mock_os_instance.is_windows_11.return_value = False - mock_os_class.get.return_value = mock_os_instance - result = _get_platform_info() assert result["system"] == "Windows" assert result["release"] == "10" assert result["machine"] == "AMD64" - @patch("winml.modelkit.commands.sys.OS") @patch("winml.modelkit.commands.sys.platform") - def test_windows_7_preserved(self, mock_platform: MagicMock, mock_os_class: MagicMock) -> None: - """Test Windows 7 version is preserved (not changed to 10).""" + def test_windows_7_preserved(self, mock_platform: MagicMock) -> None: + """Test Windows 7 version is preserved.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks mock_platform.system.return_value = "Windows" mock_platform.release.return_value = "7" + mock_platform.version.return_value = "6.1.7601" mock_platform.machine.return_value = "AMD64" mock_platform.processor.return_value = "Intel64 Family 6" - mock_os_instance = MagicMock() - mock_os_instance.is_windows_11.return_value = False - mock_os_class.get.return_value = mock_os_instance - result = _get_platform_info() assert result["system"] == "Windows" - assert result["release"] == "7" # Should keep original value + assert result["release"] == "7" assert result["machine"] == "AMD64" - @patch("winml.modelkit.commands.sys.OS") @patch("winml.modelkit.commands.sys.platform") - def test_windows_81_preserved(self, mock_platform: MagicMock, mock_os_class: MagicMock) -> None: - """Test Windows 8.1 version is preserved (not changed to 10).""" + def test_windows_81_preserved(self, mock_platform: MagicMock) -> None: + """Test Windows 8.1 version is preserved.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks mock_platform.system.return_value = "Windows" mock_platform.release.return_value = "8.1" + mock_platform.version.return_value = "6.3.9600" mock_platform.machine.return_value = "AMD64" mock_platform.processor.return_value = "Intel64 Family 6" - mock_os_instance = MagicMock() - mock_os_instance.is_windows_11.return_value = False - mock_os_class.get.return_value = mock_os_instance - result = _get_platform_info() assert result["system"] == "Windows" - assert result["release"] == "8.1" # Should keep original value - assert result["machine"] == "AMD64" - - @patch("winml.modelkit.commands.sys.OS") - @patch("winml.modelkit.commands.sys.platform") - def test_windows_detection_fallback_on_exception( - self, mock_platform: MagicMock, mock_os_class: MagicMock - ) -> None: - """Test fallback to platform.release() when OS detection fails.""" - from winml.modelkit.commands.sys import _get_platform_info - - # Setup mocks - mock_platform.system.return_value = "Windows" - mock_platform.release.return_value = "10" - mock_platform.machine.return_value = "AMD64" - mock_platform.processor.return_value = "Intel64 Family 6" - - # OS.get() raises exception - mock_os_class.get.side_effect = RuntimeError("WMI error") - - result = _get_platform_info() - - # Should use fallback value from platform.release() - assert result["system"] == "Windows" - assert result["release"] == "10" + assert result["release"] == "8.1" assert result["machine"] == "AMD64" @patch("winml.modelkit.commands.sys.platform") @@ -133,7 +134,6 @@ def test_non_windows_platform(self, mock_platform: MagicMock) -> None: """Test non-Windows platforms pass through unchanged.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks for Linux mock_platform.system.return_value = "Linux" mock_platform.release.return_value = "5.15.0" mock_platform.machine.return_value = "x86_64" @@ -150,7 +150,6 @@ def test_macos_platform(self, mock_platform: MagicMock) -> None: """Test macOS platforms pass through unchanged.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks for macOS mock_platform.system.return_value = "Darwin" mock_platform.release.return_value = "21.6.0" mock_platform.machine.return_value = "arm64" @@ -167,11 +166,10 @@ def test_processor_unknown_fallback(self, mock_platform: MagicMock) -> None: """Test processor defaults to 'Unknown' when empty.""" from winml.modelkit.commands.sys import _get_platform_info - # Setup mocks mock_platform.system.return_value = "Linux" mock_platform.release.return_value = "5.15.0" mock_platform.machine.return_value = "x86_64" - mock_platform.processor.return_value = "" # Empty string + mock_platform.processor.return_value = "" result = _get_platform_info()