diff --git a/README.md b/README.md index ee64a0de2..1546a6309 100644 --- a/README.md +++ b/README.md @@ -1,65 +1,454 @@ # ModelKit -Accelerate Model Deployment on WinML. +![Status](https://img.shields.io/badge/status-early%20access-blue) +![Python](https://img.shields.io/badge/python-3.10%2B-blue?logo=python&logoColor=white) +![License](https://img.shields.io/badge/license-MIT-green) -ModelKit is a Python toolkit for converting and optimizing PyTorch models to ONNX format, targeting deployment on the [Windows ML](https://learn.microsoft.com/en-us/windows/ai/windows-ml/) runtime. It supports multiple hardware backends including QNN (Qualcomm Neural Processing SDK) and OpenVINO. +**ModelKit** is a CLI toolkit to build **portable, performant, and high-quality** models for Windows ML. It covers the entire journey from pretrained model to on-device inference — export, optimization, quantization, compilation, and benchmarking — across **all execution providers**, regardless of silicon. -## Features +--- -- **Universal ONNX Export** — Convert PyTorch and Hugging Face models to ONNX with hierarchy preservation -- **Model Analysis** — Validate ONNX models for operator support, shape inference, and backend compatibility -- **Quantization** — INT8/INT16 quantization with calibration dataset support -- **Optimization** — Graph optimizations tailored for target execution providers -- **Performance Profiling** — Operation-level tracing and hardware monitoring -- **Multi-Backend Support** — QNN, OpenVINO, DirectML, and ONNX Runtime CPU/GPU +## :dart: ModelKit Is Right for You If -## Getting Started +- [x] You want to build models that run on **any Windows device** — Qualcomm, Intel, AMD, NVIDIA, or CPU +- [x] You want to benchmark a model with **one command** — latency, throughput, and live hardware utilization +- [x] You want to catch compatibility issues **ahead of time** — unsupported ops, shape mismatches, EP gaps +- [x] You want **deep insights** into your model — I/O 
shapes, task mapping, operator coverage per EP +- [x] You want a **repeatable and traceable** model building process — config-driven, inspectable at every stage +- [x] You want **AI agents** to build and profile models for you — agent-ready skills for coding assistants -### Prerequisites +--- -- Windows 10/11 -- Python 3.10 -- [uv](https://github.com/astral-sh/uv) package manager +## :desktop_computer: Supported Hardware -### Installation +| Execution Provider | Hardware | Status | EP Flag | Device Flag | +|:-------------------|:---------|:------:|:--------|:------------| +| **QNN** | Qualcomm NPU (Snapdragon X Elite) | 🟢 Ready | `--ep qnn` | `--device npu` | +| **OpenVINO** | Intel NPU (Meteor Lake / Lunar Lake) | 🟢 Ready | `--ep openvino` | `--device npu` | +| **VitisAI** | AMD NPU (Ryzen AI) | 🟢 Ready | `--ep vitisai` | `--device npu` | +| **TensorRT** | NVIDIA discrete GPUs | 🔶 Planned | `--ep tensorrt` | `--device gpu` | +| **MIGraphX** | AMD discrete GPUs | 🔶 Planned | `--ep migraphx` | `--device gpu` | +| **DirectML** | Hardware-agnostic GPU backend | 🔶 Planned | `--ep dml` | `--device gpu` | +| **CPU** | Cross-platform fallback | ⚪ Always available | `--ep cpu` | `--device cpu` | + +> **Tip:** Use `--device auto` and ModelKit picks the best available device — NPU first, then GPU, then CPU. 
+ +--- + +## :clipboard: Prerequisites + +### Required Software + +| **Component** | **How to Get It** | +|-----------|--------------| +| **Windows 11** (x64 or ARM64) | Windows 11 24H2+ required for NPU support | +| **UV** | Install [UV](https://github.com/astral-sh/uv) | +| **Windows App SDK Runtime 1.8** | [Latest Windows App SDK downloads](https://learn.microsoft.com/en-us/windows/apps/windows-app-sdk/downloads) | +| **ModelKit** (Python wheel) | See release instructions | + +### Required Hardware + +**ModelKit targets NPU.** We recommend testing on one of the following NPU devices: + +| Device | EP | Flag | +|--------|-----|------| +| Snapdragon X Elite (Qualcomm) | QNN | `--ep qnn --device npu` | +| Intel AI Boost (Meteor Lake / Lunar Lake) | OpenVINO | `--ep openvino --device npu` | +| AMD Ryzen AI (Phoenix / Hawk Point / Strix) | VitisAI | `--ep vitisai --device npu` | + +**No NPU?** Use `--device auto` — ModelKit will fall back to the best available device (GPU → CPU). Note that `winml compile` requires NPU and cannot run without one. + +### Accepted Inputs + +- **HuggingFace model ID** (e.g., `microsoft/resnet-50`) — weights are downloaded on first run +- **Local ONNX file** (e.g., `model.onnx`) — from `winml export`, `winml build`, or any ONNX you already have + +### The Golden Rule: Inspect First + +Before running any pipeline command, always verify the model is supported: + +```bash +winml inspect -m +``` + +If `inspect` prints an error or shows `Unsupported`, **skip that model**. Only models that pass inspect are valid inputs for export, analyze, build, perf, and eval. + +--- + +## :package: Installation + +ModelKit requires **Python 3.10** and is distributed as a Python wheel. We recommend [uv](https://docs.astral.sh/uv/) for fast, reproducible environment setup. + +**1. 
Create a Python 3.10 environment** + +```bash +uv venv --python 3.10 +``` + +Activate it: + +```bash +# Windows (PowerShell) +.venv\Scripts\activate + +# Windows (Git Bash / WSL) +source .venv/Scripts/activate +``` + +**2. Install from wheel** + +```bash +uv pip install winml_modelkit--py3-none-any.whl +``` + +**3. Verify your environment** + +```bash +winml sys --list-device --list-ep +``` + +Confirm that your target device and EP appear in the output: + +- **Snapdragon X Elite** — look for `QNNExecutionProvider` +- **Intel AI Boost** — look for `OpenVINOExecutionProvider` +- **AMD Ryzen AI** — look for `VitisAIExecutionProvider` + +If no NPU is detected, you can still use ModelKit with `--device auto` for most commands. The only exception is `winml compile`, which requires an NPU device. + +--- + +## :wrench: Commands + +| Category | Commands | Purpose | +|:---------|:---------|:--------| +| **Primitives** | `inspect` `export` `optimize` `quantize` `compile` | Single-stage building blocks | +| **Pipeline** | `config` `build` `perf` `eval` `run`\* | End-to-end orchestration | +| **Insights** | `analyze` `debug`\* | Diagnostics and compatibility | +| **Utilities** | `hub` `cache`\* `doctor`\* `setting`\* `sys` | Catalog, cache, and environment | + +\* = coming soon + +
+Primitives — one stage at a time + +**`winml inspect`** — Discover model metadata. Prints the task, model class, input/output tensor names and shapes, and execution provider compatibility. No weights are loaded — this reads only the model configuration, making it fast and lightweight. Always run inspect first to verify a model is supported. + +**`winml export`** — Convert a source model to ONNX. Takes a Hugging Face model ID (or local checkpoint) and produces a standards-compliant ONNX file with hierarchy-preserving metadata. + +**`winml optimize`** — Fuse operators, simplify graphs, and prepare for target EPs. Takes an ONNX model and an optimization config (typically generated by `winml analyze`) and applies graph-level transformations: operator fusion, constant folding, shape inference, and EP-specific rewrites. + +**`winml quantize`** — Compress to low-bit precision. Reduces model size and inference latency by converting weights and activations from FP32 to INT8 (or other low-bit formats). After quantization, the model is portable — it can run on any ONNX Runtime backend. + +**`winml compile`** — Generate device-specific binaries. Takes a quantized ONNX model and produces EP-specific compiled artifacts (for example, QNN context binaries for Qualcomm NPU). This step locks the model to a specific device but delivers the lowest possible inference latency. + +
+ +
+Pipeline — orchestrated workflows + +**`winml config`** — Auto-detect optimal settings into a JSON config. Inspects the model and generates a complete build specification: task, I/O shapes, optimization flags, quantization parameters, and target EP settings. The config file is reviewable, editable, and version-controllable — the single source of truth for your build. + +**`winml build`** — Orchestrate the full pipeline. Takes a config file and executes every stage in sequence: export, analyze, optimize, quantize, and compile. Two commands (`config` + `build`) replace eight manual steps. + +**`winml perf`** — Benchmark latency, throughput, and hardware utilization. Runs inference on the target device and reports latency percentiles (p50, p90, p99), throughput (inferences per second), and optionally live hardware monitoring (CPU, RAM, NPU utilization) with the `--monitor` flag. Can accept a local ONNX file or a Hugging Face model ID. + +**`winml eval`** — Measure model accuracy against reference datasets. Compares the output of your optimized/quantized model against the original to quantify any accuracy loss introduced by the pipeline. + +**`winml run`** — End-to-end inference with pre/post processing. *(Coming soon.)* + +
+ +
+Insights — understand what is happening inside + +**`winml analyze`** — Lint operators, check EP compatibility, and generate optimization config. The analyzer has two components: the **Linter** (like ESLint for ONNX) checks every operator against target EPs and classifies each as supported, partial, or unsupported. **AutoConf** detects suboptimal patterns and generates the optimization config that the optimizer consumes. Together they form the analyze-optimize loop. + +**`winml debug`** — Interactive model debugging and layer-by-layer inspection. *(Coming soon.)* + +
+ +
+Utilities — catalog, cache, and environment + +**`winml hub`** — Browse the curated built-in model catalog. + +**`winml cache`** — Manage built model artifacts and pipeline outputs. View, clean, or selectively remove cached models and intermediate files. + +**`winml doctor`** — Diagnose environment issues. Checks runtimes, execution providers, and dependencies to identify configuration problems. + +**`winml setting`** — Configure ModelKit preferences. Set default EPs, output directories, and other global options. + +**`winml sys`** — System information and capability reporting. Prints detected hardware, available EPs, Python version, and installed package versions. + +
+ +--- + +## :rocket: Quick Start + +### Inspect a Model + +The fastest way to get started is to inspect a model. Let's look at ResNet-50: ```bash -git clone https://github.com/microsoft/ModelKit.git -cd ModelKit -uv python install 3.10 -uv sync +winml inspect -m microsoft/resnet-50 ``` -### Usage +This prints the model's metadata without downloading weights: + +- **Task**: `image-classification` — what the model does +- **Model class**: `ResNetForImageClassification` — the architecture +- **Input tensors**: names, data types, and shapes (e.g., `pixel_values: float32 [1, 3, 224, 224]`) +- **Output tensors**: names, data types, and shapes (e.g., `logits: float32 [1, 1000]`) + +If inspect succeeds, the model is supported and you can proceed with the rest of the pipeline. + +> **Golden rule: always inspect first.** Before running export, build, perf, or any other pipeline command, verify the model is supported with `winml inspect`. -ModelKit provides a CLI tool `winml`: +### Build with Primitive Commands + +This walkthrough builds **ConvNeXT** (`facebook/convnext-base-224`) step by step using primitive commands. ConvNeXT is a family of CNN models inspired by Vision Transformers, introduced by Meta in 2022 — it offers high accuracy while retaining the efficiency of CNNs. 
+ +#### Phase 1: Inspect + +```bash +winml inspect -m facebook/convnext-base-224 +``` + +#### Phase 2: Build a Portable Model + +**Export** from PyTorch to ONNX: ```bash -# Export a Hugging Face model to ONNX -uv run winml export --model microsoft/resnet-50 --output ./output +winml export -m facebook/convnext-base-224 -o convnext/model.onnx -v +``` -# Analyze an ONNX model -uv run winml analyze --model ./output/model.onnx +**Analyze** for EP compatibility: -# Quantize an ONNX model -uv run winml quantize --model ./output/model.onnx +```bash +winml analyze -m convnext/model.onnx --optim-config optim.json ``` -## Contributions and Feedback +**Optimize** the graph using the analyzer's config: + +```bash +winml optimize -m convnext/model.onnx -c optim.json -o convnext/model_opt.onnx +``` + +**Quantize** to INT8: + +```bash +winml quantize -m convnext/model_opt.onnx -o convnext/model_opt_int8.onnx +``` + +#### Phase 3: Benchmark on Device + +**Compile** for NPU (generates device-specific binaries): + +```bash +winml compile -m convnext/model_opt_int8.onnx --ep qnn -o convnext/model_compiled.onnx +``` + +**Benchmark on NPU** — note the latency: + +```bash +winml perf -m convnext/model_compiled.onnx --ep qnn --iterations 100 +``` + +**Benchmark on CPU** for comparison: + +```bash +winml perf -m convnext/model_opt.onnx --ep cpu --iterations 100 +``` + +Compare the two numbers to see the performance difference between NPU and CPU inference. + +### Build with Config + Build + +Same model, different approach. Instead of running each command manually, use the config-driven pipeline. Think of it like CMake: `config` generates a build plan, `build` executes it. + +**Generate the build config:** + +```bash +winml config -m facebook/convnext-base-224 -o convnext_config.json +``` + +This creates a JSON file containing all settings for every pipeline step — task, I/O shapes, optimization flags, quantization parameters — all auto-detected from the model. 
+ +**Build the model:** + +```bash +winml build -c convnext_config.json -m facebook/convnext-base-224 -o convnext_build/ +``` + +This orchestrates the full pipeline — export, analyze, optimize, quantize, compile — all in one go. Same result as the manual steps above, but in two commands. + +**Benchmark the result:** + +```bash +winml perf -m convnext_build/model.onnx --ep qnn --iterations 100 +``` + +The config file is the single source of truth for your build. Version-control it, share it with teammates, edit it to override settings, and replay builds deterministically on any machine. + +### Benchmark in One Command + +The simplest way to evaluate a model — one command, zero setup: + +```bash +winml perf -m facebook/convnext-base-224 --device npu --monitor +``` + +ModelKit handles everything behind the scenes: download the model from Hugging Face, export to ONNX, optimize the graph, and run the benchmark on your NPU. The `--monitor` flag enables live hardware monitoring — real-time CPU utilization, RAM usage, and NPU activity alongside the latency results. + +This is ideal for quick smoke tests: does the model run on this device, and how fast is it? + +--- + +## :arrows_counterclockwise: The BYOM Workflow + +The **Build Your Own Model** (BYOM) workflow is the philosophy behind ModelKit. It defines how a source model becomes a production-ready, device-optimized artifact. + +### The Pipeline + +``` +Source Model --> Export --> Analyze --> Optimize --> Quantize --> Compile --> Benchmark +``` + +![BYOM Workflow](docs/assets/workflow-only.svg) + +Each arrow is a ModelKit command. You can enter the pipeline at any stage (for example, start with a local ONNX file and skip export), exit early (stop after optimization if you do not need quantization), or loop back to repeat a stage with different settings. + +### Primitive Commands vs. 
Config-Driven Pipeline + +| | **Primitive Commands** | **Config-Driven Pipeline** | +|:--|:--|:--| +| **Steps** | One command **per stage** | Two steps: **config** + **build** | +| **Control** | Start from any stage; try different settings to fix errors or tweak performance | Repeatable, tweakable, version-controllable | +| **Best for** | **Flexible** workflow | Production-ready **delivery** | +| **When to use** | Exploring, debugging, prototyping | CI/CD, batch builds, team workflows | +| **Lifecycle** | "Coding" phase | Polish | + +--- + +## :clipboard: Built-in Models + +Run `winml hub` to browse the full catalog interactively. + +
+Click to expand the full model catalog + +| Model ID | Task | Architecture | +|:---------|:-----|:-------------| +| `microsoft/resnet-50` | image-classification | ResNet | +| `google/vit-base-patch16-224` | image-classification | ViT | +| `microsoft/swin-large-patch4-window7-224` | image-classification | Swin | +| `facebook/convnext-tiny-224` | image-classification | ConvNeXT | +| `rizvandwiki/gender-classification` | image-classification | ViT | +| `ProsusAI/finbert` | text-classification | BERT | +| `Intel/bert-base-uncased-mrpc` | text-classification | BERT | +| `cardiffnlp/twitter-roberta-base-sentiment-latest` | text-classification | RoBERTa | +| `dslim/bert-base-NER` | token-classification | BERT | +| `dbmdz/bert-large-cased-finetuned-conll03-english` | token-classification | BERT | +| `Babelscape/wikineural-multilingual-ner` | token-classification | BERT | +| `w11wo/indonesian-roberta-base-posp-tagger` | token-classification | RoBERTa | +| `microsoft/table-transformer-detection` | object-detection | Table Transformer | +| `mattmdjaga/segformer_b2_clothes` | image-segmentation | SegFormer | +| `nvidia/segformer-b1-finetuned-ade-512-512` | image-segmentation | SegFormer | +| `nvidia/segformer-b2-finetuned-ade-512-512` | image-segmentation | SegFormer | +| `nvidia/segformer-b5-finetuned-ade-640-640` | image-segmentation | SegFormer | + +
+ +These models are verified against ModelKit's full pipeline and serve as reliable starting points. You are not limited to this list — any Hugging Face model that passes `winml inspect` is a valid input. + +For models not in this table, run `winml inspect -m ` to verify support before proceeding. + +--- + +## :warning: Scope & Limitations + +### What ModelKit supports + +ModelKit targets **classic deep learning models** — CNNs, encoders, vision transformers, NLP classifiers, token classifiers, object detection models, and segmentation models. + +Supported tasks include: +- Image classification (ResNet, ViT, Swin, ConvNeXT) +- Text classification (BERT, RoBERTa) +- Token classification / NER (BERT, RoBERTa) +- Object detection (Table Transformer) +- Image segmentation (SegFormer) + +### What ModelKit does not support + +**LLMs and generative models are not in scope.** Do not use ModelKit with GPT, LLaMA, Phi, Mistral, Stable Diffusion, or any model with a decoder-only or sequence-to-sequence generative architecture. LLM support (with LoRA) is planned for Q3-Q4 2026. + +### Known constraints + +- `winml compile` requires an NPU device. If no NPU is available, skip the compile step and use `--device auto` for benchmarking. +- Some models may export successfully but fail during optimization or quantization due to unsupported operator patterns. The analyzer will flag these issues. +- Performance numbers vary by device, driver version, and EP version. Always benchmark on your target hardware. + +--- + +## :world_map: Roadmap + +| Milestone | Target | Highlights | +|:----------|:-------|:-----------| +| 🟡 **Kickoff** | Q4 2025 | Internal prototype, core primitive commands | +| 🟢 **Early Access** | Q1 2026 | First external testers, config + build pipeline, hub catalog | +| 🔵 **Public Beta** | Q2 2026 | Open source, agent skills, AI Toolkit integration | +| 🟣 **RC** | Q3-Q4 2026 | **LLM support** (with LoRA), broader device coverage, MLIR | + +
+Click to expand roadmap details + +**Q4 2025 — Kickoff** +- Primitive commands: `inspect`, `export`, `optimize`, `quantize`, `compile` +- QNN, OpenVINO, and VitisAI execution provider support +- Internal validation with ResNet, BERT, ViT, SegFormer families + +**Q1 2026 — Early Access** +- Pipeline commands: `config`, `build`, `perf`, `eval` +- Analyzer with auto-configuration loop +- Built-in model catalog (`winml hub`) +- Live hardware monitoring (`--monitor`) + +**Q2 2026 — Public Beta** +- Open source release +- Agent-ready skills for coding assistants (Claude Code, Cursor, Copilot) +- AI Toolkit for VS Code integration + +**Q3-Q4 2026 — Release Candidate** +- LLM support (decoder-only architectures with LoRA adapters) +- TensorRT, MIGraphX, and DirectML execution providers +- MLIR-based optimization backend +- Public SDK and framework APIs + +
+ +--- + +## :handshake: Contributions and Feedback We welcome contributions! Please see the [contribution guidelines](CONTRIBUTING.md). For feature requests or bug reports, please file a [GitHub Issue](https://github.com/microsoft/ModelKit/issues). +--- -## Code of Conduct +## :balance_scale: Code of Conduct See [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). -## License +--- + +## :page_facing_up: License This project is licensed under the [MIT License](LICENSE.txt). +--- + ## Trademarks This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
diff --git a/docs/assets/workflow-only.svg b/docs/assets/workflow-only.svg new file mode 100644 index 000000000..c7ce33600 --- /dev/null +++ b/docs/assets/workflow-only.svg @@ -0,0 +1,608 @@
[SVG markup for the BYOM workflow diagram not shown: 608 added lines]
diff --git a/src/winml/modelkit/__init__.py b/src/winml/modelkit/__init__.py index c10fcc9de..59a45c313 100644 --- a/src/winml/modelkit/__init__.py +++ b/src/winml/modelkit/__init__.py @@ -28,15 +28,13 @@ model = WinMLAutoModel.from_pretrained("facebook/convnext-tiny-224", config=config) """ +import logging from importlib.metadata import PackageNotFoundError, version + +logging.getLogger(__name__).addHandler(logging.NullHandler()) + from . import _warnings # Configure warning filters before importing subpackages -from .config import WinMLBuildConfig -from .models import ( - WinMLAutoModel, - WinMLModelForImageClassification, - WinMLPreTrainedModel, -) try: @@ -51,3 +49,33 @@ "WinMLPreTrainedModel", "__version__", ] + + +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "WinMLBuildConfig": (".config", "WinMLBuildConfig"), + "WinMLAutoModel": (".models", "WinMLAutoModel"), + "WinMLPreTrainedModel": (".models", "WinMLPreTrainedModel"), + "WinMLModelForImageClassification": (".models", "WinMLModelForImageClassification"), +} + + +def __getattr__(name: str): + """Lazy-load heavy exports on first access (PEP 562). + + This avoids importing torch/transformers/optimum (~30s) when only + lightweight operations are needed (e.g., ``winml --help``).
+ """ + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, attr_name) + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + """Include lazy attributes in dir() for debugger/IPython compatibility.""" + return list(set(list(globals()) + __all__)) diff --git a/src/winml/modelkit/_warnings.py b/src/winml/modelkit/_warnings.py index 530be88ab..68de9ab99 100644 --- a/src/winml/modelkit/_warnings.py +++ b/src/winml/modelkit/_warnings.py @@ -44,38 +44,41 @@ class _DiffusersDistributionFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: return "Multiple distributions found" not in record.getMessage() - logging.getLogger("diffusers.utils.import_utils").addFilter( - _DiffusersDistributionFilter() - ) + logging.getLogger("diffusers.utils.import_utils").addFilter(_DiffusersDistributionFilter()) - class _HFPipelineFalsePositiveFilter(logging.Filter): - """Filter false-positive HF pipeline warnings when using WinML models. + class _PipelineNoiseFilter(logging.Filter): + """Filter noisy HF Pipeline warnings. - HF pipeline emits these because WinMLModel wraps ONNX via ORT, not a - native HF model class. These are expected and not actionable. + - 'The model X is not supported for Y' — WinML models are duck-type + compatible but not in HF's supported list. + - 'Device set to use cpu' — HF Pipeline forces CPU, we handle device. + - 'Using a slow image processor' — cosmetic deprecation notice. """ - _FALSE_POSITIVES = ( - "WinMLModel", # False positive warning which says WinML is not native HF model class - "Device set to use", # PyTorch tensor device, not ONNX device - "Using a slow image processor", # expected when using processor with pipeline. 
+ _SUPPRESSED = ( + "is not supported for", + "Device set to use cpu", + "Using a slow image processor", ) def filter(self, record: logging.LogRecord) -> bool: msg = record.getMessage() - return not any(phrase in msg for phrase in self._FALSE_POSITIVES) + return not any(s in msg for s in self._SUPPRESSED) - for _name in ( - "transformers.pipelines.base", - "transformers.models.auto.image_processing_auto", - ): - logging.getLogger(_name).addFilter(_HFPipelineFalsePositiveFilter()) + logging.getLogger("transformers.pipelines.base").addFilter(_PipelineNoiseFilter()) # ========================================================================= # Warning filters (for warnings.warn() calls) # ========================================================================= - warnings.filterwarnings("ignore", category=FutureWarning, module=r"transformers.*") - warnings.filterwarnings("ignore", category=UserWarning, module=r"torch.*") + # Transformers: suppress cosmetic warnings (not RuntimeWarning/ResourceWarning) + for _cat in (FutureWarning, DeprecationWarning, UserWarning): + warnings.filterwarnings("ignore", category=_cat, module=r"transformers\..*") + + # PyTorch: suppress cosmetic warnings (not RuntimeWarning/ResourceWarning) + for _cat in (FutureWarning, DeprecationWarning, UserWarning): + warnings.filterwarnings("ignore", category=_cat, module=r"torch\..*") + + # Diffusers warnings.filterwarnings( "ignore", message=r".*CUDA.*", category=UserWarning, module=r"diffusers.*" ) diff --git a/src/winml/modelkit/analyze/analyzer.py b/src/winml/modelkit/analyze/analyzer.py index edb119412..e44a804e3 100644 --- a/src/winml/modelkit/analyze/analyzer.py +++ b/src/winml/modelkit/analyze/analyzer.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + import onnx from .models.information import Action @@ -492,6 +494,8 @@ def analyze( htp_metadata_path: str | None = None, run_unknown_op: bool = True, save_node_types: set[str] | None = None, + on_node_result: 
Callable | None = None, + on_ep_start: Callable | None = None, ) -> AnalysisResult: """Analyze ONNX model for runtime support. @@ -590,6 +594,8 @@ def analyze( htp_metadata_path=htp_metadata_path, run_unknown_op=run_unknown_op, save_node_types=save_node_types, + on_node_result=on_node_result, + on_ep_start=on_ep_start, ) def analyze_from_proto( @@ -602,6 +608,8 @@ def analyze_from_proto( htp_metadata_path: str | None = None, run_unknown_op: bool = True, save_node_types: set[str] | None = None, + on_node_result: Callable | None = None, + on_ep_start: Callable | None = None, ) -> AnalysisResult: """Analyze ONNX model from ModelProto object. @@ -691,6 +699,11 @@ def analyze_from_proto( for current_ep in eps_to_analyze: logger.info("Checking runtime support for %s...", current_ep) + if on_ep_start: + try: + on_ep_start(current_ep, metadata.operator_counts) + except Exception: + logger.debug("on_ep_start callback failed", exc_info=True) runtime_checker = RuntimeChecker( ep=current_ep, @@ -708,6 +721,7 @@ def analyze_from_proto( patterns=pattern_matches, run_unknown_op=run_unknown_op_for_ep, save_node_types=save_node_types, + on_node_result=on_node_result, ) # Convert runtime summary to expected format @@ -727,7 +741,6 @@ def analyze_from_proto( ep=current_ep, model=onnx_model, device=device_to_use, - shape_inferred_model_proto=runtime_checker.get_shape_inferred_model_proto(), ) information_list[current_ep] = engine.summary() # Use EP name as key @@ -786,6 +799,8 @@ def analyze_onnx( ep: str | None = None, device: str | None = None, autoconf: bool = True, + on_ep_start: Callable | None = None, + on_node_result: Callable | None = None, ) -> AnalyzeResult: """Analyze an ONNX model and return lint + autoconf results. 
@@ -841,6 +856,8 @@ def analyze_onnx( ep=ep, device=device, enable_information=autoconf, + on_ep_start=on_ep_start, + on_node_result=on_node_result, ) # Extract lint result (always computed — uses RuntimeChecker classification) diff --git a/src/winml/modelkit/analyze/core/runtime_checker.py b/src/winml/modelkit/analyze/core/runtime_checker.py index 87145e07c..4c8e65054 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker.py +++ b/src/winml/modelkit/analyze/core/runtime_checker.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + import onnx from winml.modelkit.pattern.match import PatternMatchResult @@ -142,21 +144,35 @@ def _get_query(self) -> RuntimeCheckerQuery: return self._query - def get_shape_inferred_model_proto(self) -> onnx.ModelProto | None: - """Return the shape-inferred model proto from the cached query, if available.""" - if self._query is not None: - return self._query.model_proto - return None - def op_support( self, run_unknown_op: bool = True, save_node_types: set[str] | None = None, + on_node_result: Callable | None = None, ) -> list[PatternRuntime]: """Check operator-level runtime support. Returns operator-level runtime check results for each operator. + Args: + on_node_result: Optional per-node progress callback. + When provided, tqdm progress bar is suppressed (caller + handles progress display via Rich Live). + + Signature:: + + (result: PatternRuntime) -> None + + The ``PatternRuntime`` passed to the callback has: + + - ``pattern_id`` (str): Full pattern ID, e.g. + ``"OP/ai.onnx/Conv"``. Use ``split("/")[-1]`` to get + the display name (``"Conv"``). + - ``result.classification`` (SupportLevel): The support + level enum. Call ``.value`` to get the string, e.g. + ``"supported"``, ``"partial"``, ``"unsupported"``, + ``"unknown"``. 
+ Returns: List[PatternRuntime]: Runtime results for each operator pattern @@ -177,15 +193,21 @@ def op_support( model_proto = self._model.get_model() # Get cached RuntimeCheckerQuery query = self._get_query() - for node in tqdm.tqdm(model_proto.graph.node): - # Run runtime check for node - results.append( # noqa: PERF401 - query.run_for_node( - node, - run_unknown_op=run_unknown_op, - save_node_types=save_node_types, - ) + # Use tqdm for progress unless caller provides a callback + nodes = model_proto.graph.node + iterator = nodes if on_node_result else tqdm.tqdm(nodes) + for node in iterator: + result = query.run_for_node( + node, + run_unknown_op=run_unknown_op, + save_node_types=save_node_types, ) + results.append(result) + if on_node_result: + try: + on_node_result(result) + except Exception: + logger.debug("on_node_result callback failed", exc_info=True) logger.info("Checked %d operators", len(results)) @@ -302,6 +324,7 @@ def summary( patterns: list[PatternMatchResult] | None = None, run_unknown_op: bool = True, save_node_types: set[str] | None = None, + on_node_result: Callable | None = None, ) -> dict[str, list[PatternRuntime]]: """Combine operator-level & pattern-level runtime results. 
@@ -325,6 +348,7 @@ def summary( op_results = self.op_support( run_unknown_op=run_unknown_op, save_node_types=save_node_types, + on_node_result=on_node_result, ) summary_dict["op_runtime_check_result"] = op_results diff --git a/src/winml/modelkit/build/common.py b/src/winml/modelkit/build/common.py index d64cddbfd..096700cf3 100644 --- a/src/winml/modelkit/build/common.py +++ b/src/winml/modelkit/build/common.py @@ -35,6 +35,11 @@ def run_optimize_analyze_loop( ep: str | None = None, device: str | None = None, max_optim_iterations: int = 0, + on_ep_start: Any = None, + on_node_result: Any = None, + on_iteration_start: Any = None, + on_patterns_discovered: Any = None, + on_reoptimize: Any = None, **onnx_kwargs: Any, ) -> tuple[Path, float, int, int, dict]: """Optimize an ONNX model, analyze, and optionally re-optimize via autoconf. @@ -72,18 +77,77 @@ def run_optimize_analyze_loop( **onnx_kwargs, **config.optim, ) + current_path = optimized_path + + # Autoconf: analyze model, discover missing optimizations, re-optimize + if max_optim_iterations > 0: + analyze_iterations, analyze_black_nodes, analyze_details = _run_analyze_loop( + optimized_path=optimized_path, + ep=ep, + device=device, + max_optim_iterations=max_optim_iterations, + config=config, + on_ep_start=on_ep_start, + on_node_result=on_node_result, + on_iteration_start=on_iteration_start, + on_patterns_discovered=on_patterns_discovered, + on_reoptimize=on_reoptimize, + **onnx_kwargs, + ) + else: + analyze_iterations, analyze_black_nodes, analyze_details = 0, 0, {} + + elapsed = time.monotonic() - t0 + return current_path, elapsed, analyze_iterations, analyze_black_nodes, analyze_details + - # 2. 
Analyze - analysis = analyze_onnx(optimized_path, ep=ep, device=device) - analyze_count = 1 +def _run_analyze_loop( + *, + optimized_path: Path, + ep: str | None, + device: str | None, + max_optim_iterations: int, + config: WinMLBuildConfig, + on_ep_start: Any = None, + on_node_result: Any = None, + on_iteration_start: Any = None, + on_patterns_discovered: Any = None, + on_reoptimize: Any = None, + **kwargs: Any, +) -> tuple[int, int, dict]: + """Run iterative analyzer autoconf loop in a temp folder. + + Each iteration applies ONLY the autoconf flags (not merged with original). + A separate dict accumulates all discovered flags for persistence. + """ + analyze_iterations = 0 + analyze_black_nodes = 0 discovered_optim: dict[str, bool] = {} + analysis = None + _not_converged = False # 3. Autoconf re-optimization loop with tempfile.TemporaryDirectory() as tmp: iter_model = Path(tmp) / "iter.onnx" - copied = False + copy_onnx_model(optimized_path, iter_model) for _iteration in range(max_optim_iterations): + # Notify: iteration starting + if on_iteration_start is not None: + on_iteration_start( + _iteration + 1, + max_optim_iterations, + ) + + analysis = analyze_onnx( + iter_model, + ep=ep, + device=device, + on_ep_start=on_ep_start, + on_node_result=on_node_result, + ) + analyze_iterations += 1 + if not analysis.autoconf: break @@ -93,23 +157,41 @@ def run_optimize_analyze_loop( analysis.optimization_config.to_dict(), ) - if not copied: - copy_onnx_model(optimized_path, iter_model) - copied = True + # Notify: patterns discovered + if on_patterns_discovered is not None: + on_patterns_discovered(analysis.optimization_config) + # Notify: re-optimizing with discovered flags + if on_reoptimize is not None: + on_reoptimize(analysis.optimization_config) + + # Re-optimize with ONLY the autoconf flags (not merged with original) optimize_onnx( model=iter_model, output=iter_model, - **onnx_kwargs, + **kwargs, **analysis.optimization_config, ) 
discovered_optim.update(analysis.optimization_config) + else: + logger.warning( + "Autoconf did not converge after %d iteration(s)", + max_optim_iterations, + ) + _not_converged = True + + # Always analyze final state (validates after last optimize). + # Pass a no-op on_node_result to suppress tqdm (which would + # break the Rich Live display). No on_ep_start to avoid + # duplicate EP bars. + analysis = analyze_onnx( + iter_model, + ep=ep, + device=device, + on_node_result=lambda _: None, + ) - analysis = analyze_onnx(iter_model, ep=ep, device=device) - analyze_count += 1 - - if copied: - copy_onnx_model(iter_model, optimized_path) + copy_onnx_model(iter_model, optimized_path) # 4. Wrap up if discovered_optim: @@ -122,22 +204,27 @@ def run_optimize_analyze_loop( analysis.optimization_config.to_dict(), ) - if analysis.has_errors: + if analysis is not None and analysis.has_errors: raise RuntimeError( - f"Unsupported nodes persist after {analyze_count} analyze " + f"Unsupported nodes persist after {analyze_iterations} analyze " f"pass(es): {analysis.lint.error_patterns}" ) - details = { - "lint": { - "errors": analysis.lint.errors, - "warnings": analysis.lint.warnings, - "passed": analysis.lint.passed, - "error_patterns": analysis.lint.error_patterns, - "warning_patterns": analysis.lint.warning_patterns, - }, - "autoconf": discovered_optim or {}, - } - - elapsed = time.monotonic() - t0 - return optimized_path, elapsed, analyze_count, analysis.lint.errors, details + analyze_black_nodes = analysis.lint.errors if analysis else 0 + + # Build details for manifest + details: dict = {} + if analysis: + details = { + "lint": { + "errors": analysis.lint.errors, + "warnings": analysis.lint.warnings, + "passed": analysis.lint.passed, + "error_patterns": analysis.lint.error_patterns, + "warning_patterns": analysis.lint.warning_patterns, + }, + "autoconf": discovered_optim or {}, + "autoconf_not_converged": _not_converged, + } + + return analyze_iterations, analyze_black_nodes, 
details diff --git a/src/winml/modelkit/cli.py b/src/winml/modelkit/cli.py index 263494c87..bc4f89c54 100644 --- a/src/winml/modelkit/cli.py +++ b/src/winml/modelkit/cli.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- """WinML ModelKit CLI - Universal ONNX export from command line. -This module provides the main CLI entry point for ModelKit with automatic +This module provides the main CLI entry point for ModelKit with lazy command discovery from the commands/ directory. Usage: @@ -19,6 +19,7 @@ from __future__ import annotations +import ast import logging from importlib import import_module from pathlib import Path @@ -26,94 +27,130 @@ import click from . import __version__ +from .utils.logging import configure_logging logger = logging.getLogger(__name__) +_COMMANDS_DIR = Path(__file__).parent / "commands" -@click.group() -@click.version_option(version=__version__, prog_name="winml") -@click.option( - "--debug", - is_flag=True, - default=False, - help="Enable debug logging", -) -@click.pass_context -def main(ctx: click.Context, debug: bool) -> None: - """WML ModelKit - Accelerate Model Deployment on WinML. - - Universal ONNX export with QNN and OpenVINO backend support. - """ - # Configure logging based on debug flag - log_level = logging.DEBUG if debug else logging.INFO - logging.basicConfig( - level=log_level, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - # Store debug flag in context for subcommands - ctx.ensure_object(dict) - ctx.obj["debug"] = debug - - -def _discover_commands() -> None: - """Auto-discover Click commands from commands/ directory. - This function scans the commands/ directory for Python modules and - registers any Click commands found. Commands are registered using - the module filename as the command name. +def _parse_click_help(path: Path) -> str: + """Extract short help from a command module without importing it. 
- Command Discovery Rules: - - Skips files starting with underscore (_) - - Looks for any object that is a click.Command instance - - Uses module filename (without .py) as command name + Parses the module's AST to find the first decorated function's docstring, + which Click uses as the command help text. + """ + try: + tree = ast.parse(path.read_text(encoding="utf-8")) + except (SyntaxError, OSError): + return "" + + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef) and node.decorator_list: + docstring = ast.get_docstring(node) + if docstring: + # Return first line only (Click's short help) + return docstring.split("\n")[0] + return "" + + +class LazyGroup(click.Group): + """Click group that defers command module imports until invoked. + + Instead of importing every command module at startup, this group reads + command names from the filesystem and only imports a module when the + user actually invokes that command. Help text is extracted via AST + parsing (no module execution). 
""" - commands_dir = Path(__file__).parent / "commands" - - # Early exit if commands directory doesn't exist - if not commands_dir.exists(): - logger.debug("Commands directory not found: %s", commands_dir) - return - # Scan for Python modules - for py_file in commands_dir.glob("*.py"): - # Skip private modules - if py_file.name.startswith("_"): - continue + def list_commands(self, ctx: click.Context) -> list[str]: + """Return command names from filesystem — no module imports.""" + if not _COMMANDS_DIR.exists(): + return [] + return sorted(p.stem for p in _COMMANDS_DIR.glob("*.py") if not p.name.startswith("_")) - module_name = py_file.stem + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: + """Import command module only when the command is actually invoked.""" try: - # Import the module module = import_module( - f".commands.{module_name}", + f".commands.{cmd_name}", package=__package__, ) - - # Find Click command in module - # Prefer click.Group over click.Command for hierarchical commands - discovered_command = None - for attr_name in dir(module): - attr = getattr(module, attr_name) - if isinstance(attr, click.Group): - discovered_command = attr - break - if isinstance(attr, click.Command) and discovered_command is None: - discovered_command = attr - - if discovered_command: - # Register command with module name - main.add_command(discovered_command, name=module_name) - logger.debug("Discovered command: %s", module_name) - except ImportError as e: - logger.warning("Failed to import command module %s: %s", module_name, e) + logger.warning("Failed to import command module %s: %s", cmd_name, e) + return None except Exception as e: - logger.error("Error loading command %s: %s", module_name, e) + logger.error("Error loading command %s: %s", cmd_name, e) + return None + + # Find Click command in module (prefer Group over Command) + discovered = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if 
isinstance(attr, click.Group): + return attr + if isinstance(attr, click.Command) and discovered is None: + discovered = attr + return discovered + + def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + """Format command list using AST-parsed help (no module imports).""" + commands = [] + for cmd_name in self.list_commands(ctx): + help_text = _parse_click_help(_COMMANDS_DIR / f"{cmd_name}.py") + commands.append((cmd_name, help_text)) + + if commands: + limit = max(1, formatter.width - 6 - max(len(name) for name, _ in commands)) + rows = [] + for name, help_text in commands: + short = help_text[:limit].rstrip() if help_text else "" + rows.append((name, short)) + + with formatter.section("Commands"): + formatter.write_dl(rows) + + +@click.group(cls=LazyGroup) +@click.version_option(version=__version__, prog_name="winml") +@click.option( + "--verbose", + "-v", + count=True, + help="Increase verbosity (-v=INFO, -vv=DEBUG)", +) +@click.option( + "--quiet", + "-q", + is_flag=True, + default=False, + help="Quiet mode - errors only", +) +@click.option( + "--debug", + is_flag=True, + default=False, + help="Alias for -vv (DEBUG logging)", + hidden=True, +) +@click.pass_context +def main(ctx: click.Context, verbose: int, quiet: bool, debug: bool) -> None: + """WML ModelKit - Accelerate Model Deployment on WinML. + Universal ONNX export with QNN and OpenVINO backend support. 
+ """ + # --debug is a backward-compat alias for -vv + if debug: + verbose = max(verbose, 2) -# Discover and register commands at module load time -_discover_commands() + configure_logging(verbosity=verbose, quiet=quiet) + + # Store verbosity in context for subcommands + ctx.ensure_object(dict) + ctx.obj["debug"] = debug or verbose >= 2 + ctx.obj["verbosity"] = verbose + ctx.obj["quiet"] = quiet if __name__ == "__main__": diff --git a/src/winml/modelkit/commands/live_chart.py b/src/winml/modelkit/commands/_live_chart.py similarity index 100% rename from src/winml/modelkit/commands/live_chart.py rename to src/winml/modelkit/commands/_live_chart.py diff --git a/src/winml/modelkit/commands/analyze.py b/src/winml/modelkit/commands/analyze.py index 7c628a884..0c4874d23 100644 --- a/src/winml/modelkit/commands/analyze.py +++ b/src/winml/modelkit/commands/analyze.py @@ -4,17 +4,11 @@ # -------------------------------------------------------------------------- """Analyze command for winml CLI. -This module provides the analyze command that analyzes ONNX models -for runtime support across NPU execution providers. +Analyzes ONNX models for runtime support with a Rich Live stacked bar +visualization that shows real-time per-node progress. 
Usage: - winml analyze --model MODEL --ep EP --device DEVICE [OPTIONS] - -Examples: - winml analyze --model model.onnx --ep QNNExecutionProvider --device NPU - winml analyze --model model.onnx --ep qnn --device NPU - winml analyze --model model.onnx --ep ov --device GPU --information - winml analyze --model model.onnx --ep vitis --device GPU --output results.json + winml analyze --model MODEL [--ep EP] [--device DEVICE] [OPTIONS] """ from __future__ import annotations @@ -24,6 +18,11 @@ from pathlib import Path import click +from rich.console import Console +from rich.live import Live +from rich.logging import RichHandler +from rich.table import Table +from rich.text import Text from ..utils import cli as cli_utils from ..utils.constants import normalize_ep_name @@ -32,8 +31,354 @@ logger = logging.getLogger(__name__) - -@click.command(name="analyze") # type: ignore[misc] +# ── Rich visualization helpers ──────────────────────────────────────────── + +MAX_BAR_WIDTH = 40 + +_COLORS = { + "supported": "green", + "partial": "yellow", + "unsupported": "red", + "unknown": "bright_black", +} + + +def _display_name(pattern_id: str) -> str: + """Extract operator display name from pattern_id ('OP/ai.onnx/Conv' -> 'Conv').""" + return pattern_id.split("/")[-1] + + +_LEVEL_ICONS = [ + ("unsupported", "🔴"), + ("partial", "🟡"), + ("unknown", "🔵"), +] + + +def _worst_level_icon(counts: dict[str, int]) -> str: + """Return icon for the worst support level present (lower bound).""" + for level, icon in _LEVEL_ICONS: + if counts.get(level, 0) > 0: + return icon + return "🟢" + + +def _build_stacked_bar(counts: dict[str, int], max_count: int) -> Text: + """Build a stacked bar where total width is proportional to max_count.""" + total = sum(counts.values()) + if total == 0: + return Text() + + bar_width = max(1, round(total / max_count * MAX_BAR_WIDTH)) + # Ensure bar can fit all non-zero segments + nonzero = sum(1 for v in counts.values() if v > 0) + bar_width = max(bar_width, 
nonzero) + + bar = Text() + chars_used = 0 + + for level in ("supported", "partial", "unsupported", "unknown"): + count = counts.get(level, 0) + if count == 0: + continue + width = max(1, round(count / total * bar_width)) + width = min(width, bar_width - chars_used) + bar.append("█" * width, style=_COLORS[level]) + chars_used += width + + return bar + + +def _build_analyzed_text(counts: dict[str, int]) -> Text: + """Build colored 'S/P/U[/unknown]' counts, e.g. '53/0/0' or '12/5/1'.""" + w = counts.get("supported", 0) + g = counts.get("partial", 0) + b = counts.get("unsupported", 0) + u = counts.get("unknown", 0) + + text = Text() + text.append(str(w), style="bold green") + text.append("/", style="dim") + text.append(str(g), style="bold yellow" if g > 0 else "dim") + text.append("/", style="dim") + text.append(str(b), style="bold red" if b > 0 else "dim") + if u > 0: + text.append("/", style="dim") + text.append(str(u), style="bold bright_black") + return text + + +def _build_analysis_table( + data: dict[str, dict[str, int]], + ep_name: str = "", + complete: bool = False, + all_ops: dict[str, int] | None = None, +) -> Table: + """Build the analysis table with variable-width stacked bars. + + Args: + data: Per-op instance counts (filled in as analysis progresses). + Ops with data show colored bars (partial or complete). + Ops in all_ops but not in data show dim pending rows. 
+ ep_name: EP name for title + complete: Show complete marker + all_ops: All op types with total counts (for showing pending rows) + """ + # Build display order: all_ops sorted by count, or just data if no all_ops + if all_ops: + display_order = sorted(all_ops, key=lambda x: all_ops[x], reverse=True) + else: + display_order = sorted(data, key=lambda x: sum(data[x].values()), reverse=True) + + # Max count for bar width scaling (anchored to all_ops for stable bars during animation) + if all_ops: + max_count = max(all_ops.values(), default=1) + else: + max_count = max((sum(v.values()) for v in data.values()), default=1) + + title = "📊 OP CHECK" + if ep_name: + title += f" — [bold cyan]{ep_name}[/bold cyan]" + if complete: + title += " [bold green]✅ Complete[/bold green]" + + table = Table( + title=title, + show_header=True, + header_style="bold", + box=None, + padding=(0, 1), + expand=False, + ) + + table.add_column("Op Type", width=28, no_wrap=True) + table.add_column("S/P/U", width=14, no_wrap=True) + table.add_column("", no_wrap=True) + + agg: dict[str, int] = {"supported": 0, "partial": 0, "unsupported": 0, "unknown": 0} + + for op_type in display_order: + total = all_ops.get(op_type, 0) if all_ops else sum(data.get(op_type, {}).values()) + counts = data.get(op_type) + + if not counts: + # No data yet — fully pending + bar_width = max(1, round(total / max_count * MAX_BAR_WIDTH)) if max_count else 1 + table.add_row( + Text(f" {op_type} ({total})", style="dim"), + Text("...", style="dim"), + Text("░" * bar_width, style="dim"), + ) + else: + # Has data — show progress (partial or complete) + analyzed_for_op = sum(counts.values()) + for level in agg: + agg[level] += counts.get(level, 0) + + icon = _worst_level_icon(counts) + op_label = Text() + op_label.append(f"{icon} ") + op_label.append(op_type, style="cyan") + if analyzed_for_op < total: + op_label.append(f" ({analyzed_for_op}/{total})", style="dim") + else: + op_label.append(f" ({total})", style="dim") + + # 
Build bar: colored portion (analyzed) + dim portion (remaining) + bar = _build_stacked_bar(counts, max_count) + remaining = total - analyzed_for_op + if remaining > 0: + remaining_width = max(1, round(remaining / max_count * MAX_BAR_WIDTH)) + bar.append("░" * remaining_width, style="dim") + + table.add_row(op_label, _build_analyzed_text(counts), bar) + + # Summary row + table.add_section() + total_ops = sum(all_ops.values()) if all_ops else sum(agg.values()) + analyzed_count = sum(agg.values()) + total_label = Text() + total_label.append("TOTAL", style="bold") + if analyzed_count < total_ops: + total_label.append(f" ({analyzed_count}/{total_ops})", style="dim") + else: + total_label.append(f" ({total_ops})", style="dim") + + # TOTAL bar: colored portion + dim remainder + total_bar = _build_stacked_bar(agg, max(total_ops, 1)) + total_remaining = total_ops - analyzed_count + if total_remaining > 0: + total_remaining_width = max(1, round(total_remaining / max(total_ops, 1) * MAX_BAR_WIDTH)) + total_bar.append("░" * total_remaining_width, style="dim") + + table.add_row( + total_label, + _build_analyzed_text(agg), + total_bar, + ) + + return table + + +_STATUS_ICONS = {"s": "🟢", "p": "🟡", "u": "🔴", "uk": "🔵"} +_PATTERN_STATUS_LABELS = {"s": "supported", "p": "partial", "u": "unsupported", "uk": "unknown"} +_SUPPORT_LEVEL_TO_SHORT = { + "supported": "s", + "partial": "p", + "unsupported": "u", + "unknown": "uk", +} + + +_PAT_COLORS = {"s": "green", "p": "yellow", "u": "red", "uk": "bright_black"} + + +def _render_pattern_matching( + console: Console, + ep_patterns: dict[str, dict[str, dict]], +) -> None: + """Render the PATTERN MATCHING section — per-EP pattern support.""" + if not any(ep_patterns.values()): + return + + console.print("═" * 80) + console.print("🔍 [bold]PATTERN MATCHING[/bold]") + console.print("═" * 80) + + for ep_name, patterns in ep_patterns.items(): + if not patterns: + continue + + console.print(f" 💻 [bold cyan]{ep_name}[/bold cyan]") + + for pat_id, 
pat_info in sorted(patterns.items(), key=lambda x: x[1]["count"], reverse=True): + status = pat_info["status"] + count = pat_info["count"] + icon = _STATUS_ICONS.get(status, "❓") + label = _PATTERN_STATUS_LABELS.get(status, "unknown") + console.print( + f" {icon} [cyan]{pat_id}[/cyan] [dim]({count} instances)[/dim]" + f" — [{_PAT_COLORS.get(status, 'dim')}]{label}[/{_PAT_COLORS.get(status, 'dim')}]" + ) + + console.print() + + +def _extract_ep_patterns( + results: list, +) -> dict[str, dict[str, dict]]: + """Extract per-EP subgraph pattern support from analysis results. + + Args: + results: List of EPSupport objects from AnalysisOutput. + + Returns: + Dict keyed by EP name, containing dicts of pattern_id to + ``{"count": int, "status": str}`` where status is one of + ``"s"`` (supported), ``"p"`` (partial), ``"u"`` (unsupported), + ``"uk"`` (unknown). + """ + ep_patterns: dict[str, dict[str, dict]] = {} + for ep_support in results: + patterns: dict[str, dict] = {} + for info in ep_support.information: + if info.pattern_id and info.pattern_id.startswith("SUBGRAPH/"): + status = ( + _SUPPORT_LEVEL_TO_SHORT.get(info.status.value, "uk") if info.status else "uk" + ) + patterns[info.pattern_id] = { + "count": len(info.pattern_node_list), + "status": status, + } + ep_patterns[ep_support.ep_type] = patterns + return ep_patterns + + +def _render_analysis_summary( + console: Console, + results: list, + ep_instance_counts: dict[str, dict[str, dict[str, int]]], + ep_patterns: dict[str, dict[str, dict]] | None = None, +) -> None: + """Render the Analysis Summary section after pattern detection. + + Args: + console: Rich console for output. + results: List of EPSupport objects from AnalysisOutput. + ep_instance_counts: Per-EP instance counts accumulated during analysis, + keyed by EP name, then op name, then support level. + ep_patterns: Per-EP subgraph pattern support extracted from results. 
+ """ + from ..analyze.models.support_level import SupportLevel + + console.print("═" * 80) + console.print("\U0001f4c8 [bold]ANALYSIS SUMMARY[/bold]") + console.print("═" * 80) + + for ep_support in results: + ep_name = ep_support.ep_type + + # Aggregate instance counts for this EP + ep_data = ep_instance_counts.get(ep_name, {}) + agg: dict[str, int] = {"supported": 0, "partial": 0, "unsupported": 0, "unknown": 0} + for counts in ep_data.values(): + for level in agg: + agg[level] += counts.get(level, 0) + + icon = _worst_level_icon(agg) + + # EP name style based on worst level + if agg.get("unsupported", 0) > 0: + ep_style = "bold red" + elif agg.get("partial", 0) > 0: + ep_style = "bold yellow" + elif agg.get("unknown", 0) > 0 and agg.get("supported", 0) == 0: + ep_style = "bold bright_black" + else: + ep_style = "bold green" + + analyzed = _build_analyzed_text(agg) + console.print(f" {icon} [{ep_style}]{ep_name}[/{ep_style}]: ", end="") + console.print(analyzed) + + # List ops by non-white support level + classification = ep_support.classification + _issue_sections = [ + (SupportLevel.UNSUPPORTED, "red", "\u26d4 Unsupported"), + (SupportLevel.PARTIAL, "yellow", "\u26a0\ufe0f Partial"), + (SupportLevel.UNKNOWN, "bright_black", "\u2753 Unknown"), + ] + for level, color, heading in _issue_sections: + ops = classification.get(level, []) + if ops: + console.print(f" [{color}]{heading}:[/{color}]") + for op in sorted(ops): + console.print(f" \u2022 [dim]{op}[/dim]") + + # List non-supported patterns for this EP + patterns = (ep_patterns or {}).get(ep_name, {}) + bad_patterns = {pid: p for pid, p in patterns.items() if p["status"] != "s"} + if bad_patterns: + console.print(" [dim]Patterns:[/dim]") + for pid, p in sorted(bad_patterns.items(), key=lambda x: x[1]["count"], reverse=True): + status = p["status"] + icon_p = _STATUS_ICONS.get(status, "\u2753") + label = _PATTERN_STATUS_LABELS.get(status, "unknown") + console.print( + f" {icon_p} [dim]{pid}[/dim] ({p['count']} 
instances, {label})" + ) + + has_issues = any(classification.get(lvl) for lvl, _, _ in _issue_sections) or bad_patterns + if not has_issues: + console.print(" [green]Ready to deploy[/green]") + + console.print() + + +# ── Click command ───────────────────────────────────────────────────────── + + +@click.command(name="analyze") @cli_utils.model_option(required=True) @cli_utils.ep_option( required=False, optional_message="If not specified, analyzes all supported EPs" @@ -42,176 +387,308 @@ required=False, optional_message="If not specified, uses NPU as default", default="NPU" ) @cli_utils.verbosity_options -@click.option( # type: ignore[misc] +@click.option( "--output", type=click.Path(path_type=Path), default=None, - help="Save JSON output to file (default: console display)", + help="Save JSON output to file", ) -@click.option( # type: ignore[misc] +@click.option( "--information/--no-information", default=True, - help="Include detailed recommendations in output (default: enabled)", + help="Include detailed recommendations (default: enabled)", ) -@click.option( # type: ignore[misc] +@click.option( "--htp-metadata", type=click.Path(exists=True, path_type=Path), default=None, help="Path to HTP metadata JSON file for enhanced pattern extraction", ) -@click.option( # type: ignore[misc] +@click.option( "--run-unknown-op/--no-run-unknown-op", default=True, help="Run unknown operators on local machine if possible (default: enabled)", ) -@click.option( # type: ignore[misc] +@click.option( "--save-node", multiple=True, type=click.Choice(["partial", "unsupported"], case_sensitive=False), help="Save specific node types for further analysis. 
Can be specified multiple times " "(e.g., --save-node partial --save-node unsupported).", ) +@click.option( + "--optim-config", + type=click.Path(path_type=Path), + default=None, + help="Save auto-discovered optimization config to JSON file", +) def analyze( model: Path, ep: str | None, device: str | None, output: Path | None, information: bool, - verbose: bool, + verbose: int, quiet: bool, htp_metadata: Path | None, run_unknown_op: bool, save_node: tuple[str, ...], + optim_config: Path | None, ) -> None: - r"""Analyze ONNX model for runtime support. + r"""Analyze ONNX model for runtime support with live progress. - Analyze ONNX model to determine runtime support status for the specified - execution provider and device. Performs static analysis to detect patterns - and check operator compatibility. + Performs static analysis to detect patterns and check operator + compatibility, showing real-time per-operator results. Exit Codes: - 0: Success - execution provider supports model + 0: Model fully supported - 1: Partial support - some unsupported operators + 1: Partial support — some unsupported operators - 2: Error - invalid input or analysis failure + 2: Error — invalid input or analysis failure Examples: - Analyze all supported EPs with default device: - - winml analyze --model model.onnx - - Check QNN NPU support (full name): - - winml analyze --model model.onnx --ep QNNExecutionProvider --device NPU - - Check QNN NPU support (using alias): - - winml analyze --model model.onnx --ep qnn --device NPU - - Check Intel OpenVINO GPU support with recommendations (using alias): - - winml analyze --model model.onnx --ep ov --device GPU --information - - Analyze all EPs and save results to file: - + \b + winml analyze --model model.onnx --ep qnn + winml analyze --model model.onnx --ep ov --device GPU winml analyze --model model.onnx --output results.json - - Use HTP metadata for enhanced pattern extraction: - - winml analyze --model model.onnx - --ep 
OpenVINOExecutionProvider --driver GPU --information --htp-metadata metadata.json """ - # Configure logging - configure_logging(verbose=verbose, quiet=quiet) + configure_logging(verbosity=verbose, quiet=quiet) try: - # Import core components - logger.debug("Importing static analyzer components...") - from ..analyze import ONNXStaticAnalyzer, __version__ - - logger.info("Using analyzer version: %s", __version__) + from ..analyze import ONNXStaticAnalyzer - # Validate model file + # Validate model if not model.exists(): logger.error("ONNX model file not found: %s", model) sys.exit(2) - logger.debug("Model path: %s", model) - logger.debug("Execution provider: %s", ep) - logger.debug("Device: %s", device) - logger.debug("Information: %s", information) - if htp_metadata: - logger.debug("HTP metadata path: %s", htp_metadata) - - # Normalize EP name (convert aliases to full names) ep_normalized = normalize_ep_name(ep) - if ep != ep_normalized: - logger.debug("EP alias '%s' normalized to '%s'", ep, ep_normalized) - # Run static analysis using ONNXStaticAnalyzer - logger.info("Running static analysis...") + logger.info("Analyzing model: %s", model) + logger.info("Target: %s on %s", ep_normalized or "all EPs", device) + analyzer = ONNXStaticAnalyzer() - save_node_types = set(save_node) - result = analyzer.analyze( - model_path=model, - ep=ep_normalized, - device=device, - enable_information=information, - htp_metadata_path=str(htp_metadata) if htp_metadata else None, - run_unknown_op=run_unknown_op, - save_node_types=save_node_types, - ) - - logger.info( - "Analysis complete: Model is %s", - "fully supported" if result.is_fully_supported() else "partially supported", - ) - - # Serialize to JSON - json_output = result.to_json() - - # Parse JSON for console display - import json - - from ..analyze.console_writer import ( - display_analysis_results, - ) - from ..analyze.models.output import AnalysisOutput - - data = json.loads(json_output) - analysis = 
AnalysisOutput.model_validate(data) - - # Save JSON to file if output path specified - if output: - output.write_text(json_output, encoding="utf-8") - logger.info("JSON results saved to: %s", output) - - # Always display friendly console output - display_analysis_results(analysis, verbose=verbose) - - # Determine exit code based on support level - unsupported_ops = result.get_unsupported_operators() - is_model_supported = result.is_fully_supported() - if is_model_supported: - # Full support - logger.info("Model is fully supported") - sys.exit(0) + + # Console for Rich output (stderr so stdout stays clean for JSON) + console = Console(stderr=True) + + # Model info header + if not quiet: + console.print() + console.print("═" * 80) + console.print("📊 [bold]OP CHECK[/bold]") + console.print("═" * 80) + console.print(f" 📦 Model: [bold cyan]{model.name}[/bold cyan]") + + # Load model metadata for header + try: + import onnx + + _proto = onnx.load(str(model), load_external_data=False) + _opset = _proto.opset_import[0].version if _proto.opset_import else "?" 
+ _producer = _proto.producer_name or "unknown" + if _proto.producer_version: + _producer += f" v{_proto.producer_version}" + _total_ops = len(_proto.graph.node) + _unique_ops = len({n.op_type for n in _proto.graph.node}) + console.print( + f" 🔧 Opset: [green]{_opset}[/green] Producer: [green]{_producer}[/green]" + ) + console.print( + f" 📋 Operators: [cyan]{_total_ops}[/cyan] total, " + f"[cyan]{_unique_ops}[/cyan] unique types" + ) + console.print() + del _proto # free memory + except Exception: + logger.debug("Could not load model metadata for header display") + + # Per-EP state for Live display + current_ep_name = "" + all_op_counts: dict[str, int] = {} + instance_counts: dict[str, dict[str, int]] = {} + ep_instance_counts: dict[str, dict[str, dict[str, int]]] = {} + live: Live | None = None + ep_counter = 0 + + def _finalize_live(mark_complete: bool = True) -> None: + """Stop the active Live display, optionally marking it complete.""" + nonlocal live + if live is None: + return + try: + if mark_complete and current_ep_name: + ep_instance_counts[current_ep_name] = { + k: dict(v) for k, v in instance_counts.items() + } + live.update( + _build_analysis_table( + instance_counts, + ep_name=current_ep_name, + complete=True, + all_ops=all_op_counts, + ) + ) + except Exception: + logger.debug("Failed to render final table", exc_info=True) + finally: + live.stop() + live = None + + def on_ep_start(ep_name, operator_counts): + """Called when analysis starts for a new EP.""" + nonlocal current_ep_name, instance_counts, all_op_counts, ep_counter, live + ep_counter += 1 + + # Finalize previous EP's Live display + if current_ep_name: + _finalize_live() + console.print() # blank line between EP tables + + # Reset for new EP (normalize keys to display names) + current_ep_name = ep_name + all_op_counts = {_display_name(k): v for k, v in operator_counts.items()} + instance_counts = {} + + # EP section header + console.print("─" * 80) + console.print(f"💻 [bold]EP 
{ep_counter}[/bold]: [bold cyan]{ep_name}[/bold cyan]") + console.print("─" * 80) + + # Start new Live display — all ops shown as pending + live = Live( + _build_analysis_table( + instance_counts, + ep_name=ep_name, + all_ops=all_op_counts, + ), + console=console, + refresh_per_second=30, + ) + live.start() + + def on_node_result(pattern_runtime): + """Callback invoked per-node during analysis.""" + op = _display_name(pattern_runtime.pattern_id) + level = pattern_runtime.result.classification.value + op_counts = instance_counts.setdefault(op, {}) + op_counts[level] = op_counts.get(level, 0) + 1 + + if live is not None: + live.update( + _build_analysis_table( + instance_counts, + ep_name=current_ep_name, + all_ops=all_op_counts, + ) + ) + + if not quiet: + # Redirect logging through Rich console so log messages render + # above the Live table instead of breaking it + root_logger = logging.getLogger() + old_handlers = root_logger.handlers[:] + rich_handler = RichHandler( + console=console, + show_path=False, + show_time=True, + rich_tracebacks=False, + ) + rich_handler.setLevel(root_logger.level) + root_logger.handlers = [rich_handler] + + try: + save_node_types = set(save_node) + result = analyzer.analyze( + model_path=str(model), + ep=ep_normalized, + device=device, + enable_information=information, + htp_metadata_path=str(htp_metadata) if htp_metadata else None, + run_unknown_op=run_unknown_op, + save_node_types=save_node_types, + on_node_result=on_node_result, + on_ep_start=on_ep_start, + ) + + # Extract per-EP pattern support (available now) + ep_patterns = _extract_ep_patterns(result.output.results) + + # Finalize last EP's Live display + _finalize_live() + finally: + # Safety: stop Live if still running (e.g. 
on exception) + _finalize_live(mark_complete=False) + root_logger.handlers = old_handlers + + console.print() + + # Pattern Matching section (per-EP) + _render_pattern_matching(console, ep_patterns) + + # Analysis Summary section + _render_analysis_summary( + console, + result.output.results, + ep_instance_counts, + ep_patterns=ep_patterns, + ) + + # Legend (at the very bottom) + console.print( + " [dim]S/P/U = Supported/Partial/Unsupported[/dim]" + " [green]██[/green] supported" + " [yellow]██[/yellow] partial" + " [red]██[/red] unsupported" + " [bright_black]██[/bright_black] unknown" + ) + console.print() else: - # Partial or no support - logger.warning("Model has %d unsupported operators", len(unsupported_ops)) - if verbose: - for op_name in unsupported_ops[:5]: # Show first 5 - logger.warning(" - %s", op_name) - if len(unsupported_ops) > 5: - logger.warning(" ... and %d more", len(unsupported_ops) - 5) - sys.exit(1) + # Quiet mode — no live display + save_node_types = set(save_node) + result = analyzer.analyze( + model_path=str(model), + ep=ep_normalized, + device=device, + enable_information=information, + htp_metadata_path=str(htp_metadata) if htp_metadata else None, + run_unknown_op=run_unknown_op, + save_node_types=save_node_types, + ) + + # Save JSON if requested + if output: + try: + output.write_text(result.to_json(), encoding="utf-8") + logger.info("JSON results saved to: %s", output) + except OSError as e: + logger.error("Failed to write JSON output to %s: %s", output, e) + except Exception as e: + logger.error("Failed to serialize results to JSON: %s", e) + logger.debug("JSON serialization traceback:", exc_info=True) + + # Save optimization config if requested + if optim_config: + import json + + try: + config = result.get_optimization_config(ep=ep_normalized) + optim_config.write_text(json.dumps(config.to_dict(), indent=2), encoding="utf-8") + logger.info("Optimization config saved to: %s", optim_config) + except OSError as e: + logger.error("Failed 
to write config to %s: %s", optim_config, e) + except Exception as e: + logger.error("Failed to generate optimization config: %s", e) + logger.debug("Config generation traceback:", exc_info=True) + + # Exit code: 0 = fully supported, 1 = partial support + sys.exit(0 if result.is_fully_supported() else 1) except FileNotFoundError as e: logger.error("File not found: %s", e) sys.exit(2) - except Exception as e: logger.error("Analysis failed: %s", e) if verbose: @@ -219,11 +696,4 @@ def analyze( sys.exit(2) -# Register the command -# This will be auto-discovered by the CLI framework -# Export only the command for CLI discovery __all__ = ["analyze"] - - -if __name__ == "__main__": - analyze() diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index ee9a9159c..79a879da1 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -12,7 +12,7 @@ winml build -c config.json -m microsoft/resnet-50 -o output/ winml build -c config.json -m model.onnx -o output/ winml build -c config.json -m bert-base-uncased -o output/ --no-quant --no-compile - winml build -c config.json -m microsoft/resnet-50 --random-init -o output/ + winml build -c config.json -o output/ --use-cache winml build -c config.json -m microsoft/resnet-50 -o output/ --rebuild -v """ @@ -20,11 +20,22 @@ import json import logging +import time from pathlib import Path from typing import TYPE_CHECKING import click -from rich.console import Console +from rich.logging import RichHandler + +from ..utils.console import ( + detect_model_source, + get_console, + print_error, + print_final, + print_setup, + print_stage_skip, + print_stages_header, +) if TYPE_CHECKING: @@ -36,7 +47,7 @@ from ..config import WinMLBuildConfig logger = logging.getLogger(__name__) -console = Console(stderr=True) +console = get_console() # ============================================================================= @@ -115,7 +126,7 @@ def 
_instantiate_parent_model(model_type: str, task: str | None = None) -> nn.Mo Returns: PyTorch model in eval mode with random/init weights. """ - from ..loader.config import resolve_loader_config + from ..loader import resolve_loader_config _, hf_config, resolved_class = resolve_loader_config( model_type=model_type, @@ -206,7 +217,7 @@ def _build_modules( # ============================================================================= -@click.command() +@click.command("build") @click.option( "-c", "--config", @@ -219,22 +230,8 @@ def _build_modules( "-m", "--model", "model_id", - required=True, - help="HuggingFace model ID or path to .onnx file.", - # --model is mandatory because random-weight builds (omitting --model) are - # unreliable: AutoConfig.for_model() returns architecture class defaults - # which can differ from pretrained configs in ways that cause silent - # runtime failures. E.g. MPNet/Roberta-family models set - # max_position_embeddings = usable_length + pad_token_id + 1 (514) in the - # pretrained config, but the class default is only 512. The smaller - # embedding table causes "index out of range in self" during ONNX export - # tracing -- a position-offset OOB that the OnnxConfig-level fix (PR #415) - # cannot reach because HTPExporter uses pre-populated input_tensors, not - # Optimum's input generation path. Supporting random-init reliably would - # require storing the full pretrained HF config (or at least the model ID) - # in the build config so _load_model can call AutoConfig.from_pretrained() - # instead of AutoConfig.for_model(). Until that plumbing exists, require - # --model to guarantee correct model instantiation. + default=None, + help="HuggingFace model ID or path to .onnx file. Omit for random-weight build.", ) @click.option( "-o", @@ -250,12 +247,6 @@ def _build_modules( default=False, help="Use ModelKit global cache (~/.cache/winml/). 
Mutually exclusive with -o.", ) -@click.option( - "--random-init", - is_flag=True, - default=False, - help="Skip weight download; use model config with random weights.", -) @click.option( "--rebuild", is_flag=True, @@ -274,12 +265,6 @@ def _build_modules( default=False, help="Skip compilation (overrides config)", ) -@click.option( - "--no-optimize", - is_flag=True, - default=False, - help="Skip optimization (for pre-quantized ONNX models)", -) @click.option( "--ep", default=None, @@ -297,6 +282,12 @@ def _build_modules( default=False, help="Skip analyzer loop during build", ) +@click.option( + "--no-optimize", + is_flag=True, + default=False, + help="Skip optimization (for pre-quantized ONNX models)", +) @click.option( "--max-optim-iterations", "max_optim_iterations", @@ -318,7 +309,6 @@ def build( model_id: str | None, output_dir: str | None, use_cache: bool, - random_init: bool, rebuild: bool, no_quant: bool, no_compile: bool, @@ -349,8 +339,8 @@ def build( # Export + optimize only winml build -c config.json -m bert-base-uncased -o output/ --no-quant --no-compile - # Random-weight build (no weight download) - winml build -c config.json -m microsoft/resnet-50 --random-init -o output/ + # Random-weight build (no download) + winml build -c config.json -o output/ # Use global cache winml build -c config.json -m microsoft/resnet-50 --use-cache @@ -425,11 +415,15 @@ def build( if not configs: raise click.UsageError("Module config array is empty -- nothing to build.") - console.print() - console.print("[bold]winml build[/bold] (module mode)") - console.print(f" Config: {Path(config_file).name}") - console.print(f" Modules: {len(configs)}") - console.print(f" Output: {resolved_dir}") + print_setup( + console, + model=model_id or "random-init", + config=Path(config_file).name, + output=str(resolved_dir), + source="HuggingFace", + ) + print_stages_header(console) + console.print(f" \U0001f9e9 [bold]Modules:[/bold] {len(configs)}") console.print() results = _build_modules( 
@@ -451,6 +445,8 @@ def build( ) # Write module summary + from ..build.module_summary import write_module_summary + summary_instances = [] for cfg, result in zip(configs, results, strict=True): summary_instances.append( @@ -462,15 +458,13 @@ def build( } ) - summary_path = resolved_dir / "module_summary.json" - summary = { - "model_id": model_id or "random-init", - "module_class": configs[0].loader.model_class or "unknown", - "instance_count": len(summary_instances), - "instances": summary_instances, - } - summary_path.write_text(json.dumps(summary, indent=2)) - console.print(f" Summary: {summary_path}") + write_module_summary( + output_path=resolved_dir / "module_summary.json", + model_id=model_id or "random-init", + module_class=configs[0].loader.model_class or "unknown", + instances=summary_instances, + ) + console.print(f" Summary: {resolved_dir / 'module_summary.json'}") console.print() @@ -501,66 +495,705 @@ def build( else: resolved_dir = Path(output_dir) - # Report build plan - model_label = f"{model_id} (random-init)" if random_init else model_id + _run_single_build( + config=config, + config_file=config_file, + model_id=model_id, + resolved_dir=resolved_dir, + rebuild=rebuild, + cache_key=cache_key, + ep=ep, + device=device, + extra_kwargs=extra_kwargs, + ) + + except click.UsageError: + raise # Let click handle its own errors + except ValueError as e: + raise click.UsageError(str(e)) from e + except Exception as e: + if verbose: + logger.exception("Build failed") + + # Map common errors to actionable hints + err_str = str(e) + hint = None + if "Quantization failed" in err_str: + hint = "Try: --no-quant to skip quantization" + elif "Compilation failed" in err_str: + hint = "Try: --no-compile to skip compilation" + elif "Black nodes persist" in err_str: + hint = "Try: winml analyze -m --ep to investigate operator support" + elif isinstance(e, FileNotFoundError): + hint = "Check: model path or HuggingFace model ID" + + if hint: console.print() - 
console.print("[bold]winml build[/bold]") - console.print(f" Config: {Path(config_file).name}") - console.print(f" Model: {model_label}") - console.print(f" Output: {resolved_dir}") + print_error(console, f"Build failed: {e}", hint=hint) console.print() - # Call build API (late import to speed up CLI startup) - from .config import _is_onnx_file + raise click.ClickException(f"Build failed: {e}") from e + + +# ============================================================================= +# SINGLE MODEL BUILD — CLI-level stage orchestration +# ============================================================================= + + +def _run_single_build( + *, + config: WinMLBuildConfig, + config_file: str, + model_id: str | None, + resolved_dir: Path, + rebuild: bool, + cache_key: str | None, + ep: str | None, + device: str | None, + extra_kwargs: dict[str, Any], +) -> None: + """Run single-model build with Rich Live progress per stage.""" + from .config import _is_onnx_file + + _is_onnx = model_id is not None and _is_onnx_file(model_id) + # Derive source from _is_onnx to guarantee header label matches pipeline + source = "ONNX" if _is_onnx else detect_model_source(model_id) + + # Gap 1: (pretrained) suffix; Gap 2: ONNX file size + if model_id is None: + model_label = "random-init" + elif _is_onnx: + _sz = _safe_size(Path(model_id)) + from ..utils.console import fmt_size + + model_label = f"{model_id} [dim]({fmt_size(_sz)})[/dim]" if _sz else model_id + else: + model_label = f"{model_id} [dim](pretrained)[/dim]" + + # ── 🔧 Setup section ──────────────────────────────────────── + print_setup( + console, + model=model_label, + config=Path(config_file).name, + output=str(resolved_dir), + source=source, + ) + print_stages_header(console) + + # ── Redirect logging + warnings through Rich during Live stages ── + # This ensures log messages and warnings.warn() render above the + # Live area instead of breaking it (same pattern as winml analyze). 
+ root_logger = logging.getLogger() + old_handlers = root_logger.handlers[:] + rich_handler = RichHandler( + console=console, + show_path=False, + show_time=True, + rich_tracebacks=False, + ) + rich_handler.setLevel(root_logger.level) + root_logger.handlers = [rich_handler] + # Route warnings.warn() (e.g., TracerWarning) through logging → Rich + logging.captureWarnings(True) + + start_time = time.monotonic() + + try: + if _is_onnx: + stage_timings = _build_onnx_pipeline( + config=config, + onnx_path=Path(model_id), + output_dir=resolved_dir, + rebuild=rebuild, + ep=ep, + device=device, + extra_kwargs=extra_kwargs, + ) + else: + stage_timings = _build_hf_pipeline( + config=config, + model_id=model_id, + output_dir=resolved_dir, + rebuild=rebuild, + cache_key=cache_key, + ep=ep, + device=device, + extra_kwargs=extra_kwargs, + ) + + elapsed = time.monotonic() - start_time + final_path = resolved_dir / "model.onnx" + if final_path.exists() and stage_timings: + print_final( + console, + elapsed, + str(final_path), + stage_timings=stage_timings, + ) + finally: + logging.captureWarnings(False) + root_logger.handlers = old_handlers + + +def _print_reused(artifact_path: Path) -> None: + """Print reused artifact message.""" + console.print() + console.print( + f" \u267b\ufe0f [bold cyan]Existing artifact found:[/bold cyan] {artifact_path}" + ) + console.print(" \U0001f4a1 [dim]Use --rebuild to force rebuild.[/dim]") + console.print() + + +def _safe_size(path: Path) -> int: + """Get file size including ONNX external data, return 0 if unavailable.""" + try: + if path.suffix == ".onnx": + from ..utils.console import get_onnx_total_size + + return get_onnx_total_size(path) + return path.stat().st_size + except OSError: + return 0 + + +def _show_io(sl: Any, config: WinMLBuildConfig) -> None: + """Show I/O tensors in a StageLive.""" + export_cfg = config.export + if not export_cfg: + return + inputs = export_cfg.input_tensors or [] + outputs = export_cfg.output_tensors or [] + for 
i, t in enumerate(inputs): + name = t.name or "(unnamed)" + shape = str(list(t.shape)) if getattr(t, "shape", None) else "dynamic" + dtype = getattr(t, "dtype", None) or "?" + sl.io_input(name, shape, dtype, first=(i == 0)) + for i, t in enumerate(outputs): + name = t.name or "(unnamed)" + # OutputTensorSpec has name only — show name, no shape/dtype + label = "Output: " if i == 0 else " " + sl.detail(f"{label}[cyan]{name}[/cyan]") + + +# ============================================================================= +# SHARED PIPELINE STAGE HELPERS +# ============================================================================= + + +def _run_optimize_stage( + *, + config: WinMLBuildConfig, + model_path: Path, + optimized_path: Path, + ep: str | None, + device: str | None, + max_iters: int, + stage_timings: list[tuple[str, float | None]], + show_io_first: bool = False, +) -> tuple[Path, float]: + """Run the optimize stage inside a StageLive context. - if model_id and _is_onnx_file(model_id): - from ..build import build_onnx_model + Creates all 5 analyzer callbacks bound to the live display, calls + run_optimize_analyze_loop, shows convergence message and artifact. - result = build_onnx_model( - onnx_path=Path(model_id), - config=config, - output_dir=resolved_dir, - rebuild=rebuild, - ep=ep, - device=device, - **extra_kwargs, + Args: + config: Build configuration. + model_path: Input model path. + optimized_path: Output path for optimized model. + ep: Execution provider for analyzer. + device: Target device for analyzer. + max_iters: Maximum analyzer iterations. + stage_timings: List to append (stage_name, elapsed) tuple to. + show_io_first: If True, show I/O tensors at the start of the stage + (used in ONNX mode where there is no export stage). + + Returns: + Tuple of (current_path, opt_elapsed). 
+ """ + from ..build.common import run_optimize_analyze_loop + from ..utils.console import StageLive + + with StageLive("optimize", console) as sl: + sl.set_status("Optimizing ONNX graph...") + + if show_io_first: + _show_io(sl, config) + + # Analyzer callback state for live EP bars + _ep_bars: dict[str, int] = {} + _ep_counts: dict[str, dict[str, int]] = {} + _ep_totals: dict[str, int] = {} + _current_ep = [""] + _current_iter = [0, 0] # [iteration, max_iter] + _header_shown = [False] + + def _on_iteration_start(iteration: int, max_iter: int) -> None: + _ep_bars.clear() + _ep_counts.clear() + _ep_totals.clear() + _current_iter[0] = iteration + _current_iter[1] = max_iter + _header_shown[0] = False + + def _on_ep_start(ep_name: str, operator_counts: dict) -> None: + _current_ep[0] = ep_name + _ep_counts[ep_name] = {} + total = sum(operator_counts.values()) + _ep_totals[ep_name] = total + # Show "Analyzing N nodes (iter X/Y)" on first EP of each iter + if not _header_shown[0]: + _header_shown[0] = True + sl.detail( + f"[bold]Analyzing[/bold] [cyan]{total}[/cyan] nodes " + f"[dim](iter {_current_iter[0]}/{_current_iter[1]})[/dim]" ) - else: - from ..build import build_hf_model - - result = build_hf_model( - config=config, - output_dir=resolved_dir, - model_id=model_id, - rebuild=rebuild, - random_init=random_init, - cache_key=cache_key, - ep=ep, - device=device, - **extra_kwargs, + _ep_bars[ep_name] = sl.ep_bar_add(ep_name, total=total) + + def _on_node_result(pattern_runtime: Any) -> None: + ep_name = _current_ep[0] + level = pattern_runtime.result.classification.value + counts = _ep_counts.setdefault(ep_name, {}) + counts[level] = counts.get(level, 0) + 1 + s = counts.get("supported", 0) + p = counts.get("partial", 0) + u = counts.get("unsupported", 0) + idx = _ep_bars.get(ep_name) + if idx is not None: + sl.ep_bar_update( + idx, + ep_name, + s, + p, + u, + total=_ep_totals.get(ep_name, 0), ) - # Report results - if result.reused: - console.print(f" Existing 
artifact: {result.final_onnx_path}") - console.print(" Use --rebuild to force rebuild.") - else: - for stage in result.stages_completed: - t = result.stage_timings.get(stage, 0) - console.print(f" {stage:<12} done ({t:.1f}s)") - for stage in result.stages_skipped: - console.print(f" {stage:<12} skipped") - console.print() - console.print(f" Build complete in {result.elapsed:.1f}s") - console.print(f" Final artifact: {result.final_onnx_path}") + def _on_patterns(autoconf_dict: dict) -> None: + sl.detail("[bold]Patterns[/bold]") + for key in autoconf_dict: + name = key.replace("disable_", "").replace("_fusion", "").replace("_", " ").title() + sl.detail(f" [yellow]{name}[/yellow] [dim]\u2192 {key}[/dim]") + + def _on_reoptimize(autoconf_dict: dict) -> None: + sl.detail("[bold]Optimizing[/bold] [dim](applying autoconf)[/dim]") + sl.detail(f" [dim]{autoconf_dict}[/dim]") + + t0 = time.monotonic() + current_path, _, analyze_iters, _, analyze_details = run_optimize_analyze_loop( + model_path=model_path, + optimized_path=optimized_path, + config=config, + ep=ep, + device=device, + max_optim_iterations=max_iters, + on_ep_start=_on_ep_start, + on_node_result=_on_node_result, + on_iteration_start=_on_iteration_start, + on_patterns_discovered=_on_patterns, + on_reoptimize=_on_reoptimize, + use_external_data=True, + ) + opt_elapsed = time.monotonic() - t0 - console.print() + if analyze_iters > 0: + converged = not analyze_details.get("autoconf_not_converged", False) + conv_str = "converged" if converged else "NOT converged" + # Show pattern result even when none found + autoconf = analyze_details.get("autoconf", {}) + if not autoconf: + sl.detail("[bold]Patterns[/bold]") + sl.detail(" [dim]No optimization patterns found[/dim]") + sl.detail(f"[dim]Autoconf {conv_str} after {analyze_iters} iteration(s)[/dim]") - except click.UsageError: - raise # Let click handle its own errors + sl.set_done(opt_elapsed) + sl.artifact(str(optimized_path), _safe_size(optimized_path)) + sl.blank() 
+ + stage_timings.append(("Optimize", opt_elapsed)) + return current_path, opt_elapsed + + +def _run_quantize_stage( + *, + config: WinMLBuildConfig, + current_path: Path, + quantized_path: Path, + stage_timings: list[tuple[str, float | None]], +) -> Path: + """Run the quantize stage inside a StageLive context (if quant is configured). + + Handles QDQ skip detection, shows dataset/calibration/precision details, + and appends timing to stage_timings. + + Args: + config: Build configuration. + current_path: Input model path. + quantized_path: Output path for quantized model. + stage_timings: List to append (stage_name, elapsed) tuple to. + + Returns: + Updated current_path (quantized_path if quantization ran, else unchanged). + """ + from ..onnx import is_quantized_onnx + from ..quant import quantize_onnx + from ..utils.console import StageLive + + if config.quant is None: + return current_path + + if is_quantized_onnx(current_path): + print_stage_skip(console, "quantize", "(QDQ nodes already present)") + stage_timings.append(("Quantize", None)) + return current_path + + with StageLive("quantize", console) as sl: + wt = config.quant.weight_type or "?" 
+ sl.set_status(f"Quantizing ({wt})...") + # Calibration info before blocking call + ds = config.quant.dataset_name or "default" + sl.kv( + "Dataset:", + f"[cyan]{ds}[/cyan] [dim]({config.quant.task or 'unknown'})[/dim]", + ) + sl.kv( + "Calibration:", + f"[cyan]{config.quant.samples}[/cyan] samples" + f" [dim]({config.quant.calibration_method})[/dim]", + ) + # Suppress tqdm/datasets progress bars during quantize + # to keep Live display clean + _datasets_available = False + try: + import datasets + + datasets.disable_progress_bars() + _datasets_available = True + except ImportError: + pass # datasets package not installed; progress bar suppression not needed + + t0 = time.monotonic() + try: + quant_result = quantize_onnx( + model_path=current_path, + output_path=quantized_path, + config=config.quant, + use_external_data=True, + ) + finally: + if _datasets_available: + datasets.enable_progress_bars() + if not quant_result.success: + errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown" + sl.set_error(errors) + raise RuntimeError(f"Quantization failed: {errors}") + current_path = quantized_path + _quant_elapsed = time.monotonic() - t0 + sl.set_done(_quant_elapsed) + sl.kv( + "Precision:", + f"[cyan]{config.quant.weight_type}/" + f"{config.quant.activation_type}[/cyan]" + f" [dim](weight/activation)[/dim]", + ) + sl.artifact( + str(quantized_path), + _safe_size(quantized_path), + ) + sl.blank() + stage_timings.append(("Quantize", _quant_elapsed)) + return current_path + + +def _run_compile_stage( + *, + config: WinMLBuildConfig, + current_path: Path, + compiled_path: Path, + stage_timings: list[tuple[str, float | None]], +) -> Path: + """Run the compile stage inside a StageLive context (if compile is configured). + + Shows graph summary after compilation and appends timing to stage_timings. + + Args: + config: Build configuration. + current_path: Input model path. + compiled_path: Output path for compiled model. 
+ stage_timings: List to append (stage_name, elapsed) tuple to. + + Returns: + Updated current_path (compiled_path if compilation ran, else unchanged). + """ + from ..compiler import compile_onnx + from ..onnx import copy_onnx_model + from ..utils.console import StageLive, get_onnx_graph_summary + + if config.compile is None: + return current_path + + with StageLive("compile", console) as sl: + _cp = "" + if hasattr(config.compile, "ep_config") and config.compile.ep_config: + _cp = f" for {config.compile.ep_config.provider.upper()}" + sl.set_status(f"Compiling{_cp}...") + t0 = time.monotonic() + compile_result = compile_onnx( + model_path=current_path, + output_path=compiled_path, + config=config.compile, + ) + if hasattr(compile_result, "success") and not compile_result.success: + errors = ", ".join(compile_result.errors) if compile_result.errors else "Unknown" + sl.set_error(errors) + raise RuntimeError(f"Compilation failed: {errors}") + if ( + compile_result.output_path + and Path(compile_result.output_path).resolve() != compiled_path.resolve() + ): + copy_onnx_model(compile_result.output_path, compiled_path) + current_path = compiled_path + _compile_elapsed = time.monotonic() - t0 + sl.set_done(_compile_elapsed) + + # Graph summary + try: + summary = get_onnx_graph_summary(compiled_path) + op_parts = ", ".join( + f"[cyan]{op}[/cyan] ({count})" + for op, count in list(summary["op_counts"].items())[:8] + ) + sl.detail(f"[bold]Graph:[/bold] {op_parts}") + except Exception: + logger.debug("Could not load graph summary", exc_info=True) + + sl.artifact( + str(compiled_path), + _safe_size(compiled_path), + ) + stage_timings.append(("Compile", _compile_elapsed)) + return current_path + + +# ============================================================================= +# PIPELINE FUNCTIONS +# ============================================================================= + + +def _build_hf_pipeline( + *, + config: WinMLBuildConfig, + model_id: str | None, + output_dir: 
Path, + rebuild: bool, + cache_key: str | None, + ep: str | None, + device: str | None, + extra_kwargs: dict[str, Any], +) -> list[tuple[str, float | None]] | None: + """HF build pipeline with cascading StageLive per stage. + + Returns list of (stage_name, elapsed_seconds | None) for summary, + or None if build was reused. + """ + from ..build.hf import _load_model + from ..export import export_onnx + from ..onnx import copy_onnx_model + from ..utils.console import StageLive + + max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3) + model_label = model_id or "random-init" + + # ── Validate + setup ───────────────────────────────────────── + try: + config.validate() except ValueError as e: - raise click.UsageError(str(e)) from e - except Exception as e: - if verbose: - logger.exception("Build failed") - raise click.ClickException(f"Build failed: {e}") from e + raise ValueError(f"Config validation failed: {e}") from e + + output_dir.mkdir(parents=True, exist_ok=True) + + def _name(base: str) -> str: + return f"{cache_key}_{base}" if cache_key else base + + export_path = output_dir / _name("export.onnx") + optimized_path = output_dir / _name("optimized.onnx") + quantized_path = output_dir / _name("quantized.onnx") + compiled_path = output_dir / _name("compiled.onnx") + final_path = output_dir / _name("model.onnx") + config_path = output_dir / _name("winml_build_config.json") + + # Reuse check + if final_path.exists() and not rebuild: + _print_reused(final_path) + return None + + stage_timings: list[tuple[str, float | None]] = [] + + # Clean old artifacts on rebuild + if rebuild: + pattern = f"{cache_key}_*.onnx" if cache_key else "*.onnx" + for old in output_dir.glob(pattern): + old.unlink() + + current_path = export_path + + # ── Export stage ────────────────────────────────────────────── + import warnings + + with StageLive("export", console) as sl: + sl.set_status("Exporting to ONNX...") + + # Load + export (blocking) + # Suppress TracerWarning and 
other transformer warnings + # during export to keep Live display clean. + pytorch_model = _load_model(config, model_id, trust_remote_code=False) + t0 = time.monotonic() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + export_onnx( + model=pytorch_model, + output_path=export_path, + export_config=config.export, + model_id=model_label, + task=config.loader.task, + verbose=False, + use_external_data=True, + ) + _export_elapsed = time.monotonic() - t0 + sl.set_done(_export_elapsed) + # Meta shown after export completes (avoids duplicate in Live frame) + if config.loader.model_class: + sl.kv("Model class:", f"[cyan]{config.loader.model_class}[/cyan]") + if config.loader.task: + sl.kv("Task:", f"[cyan]{config.loader.task}[/cyan]") + _show_io(sl, config) + sl.artifact(str(export_path), _safe_size(export_path)) + sl.blank() + + stage_timings.append(("Export", _export_elapsed)) + + # ── Optimize stage ─────────────────────────────────────────── + current_path, _ = _run_optimize_stage( + config=config, + model_path=current_path, + optimized_path=optimized_path, + ep=ep, + device=device, + max_iters=max_iters, + stage_timings=stage_timings, + show_io_first=False, + ) + + # Persist config after autoconf + config_path.write_text(json.dumps(config.to_dict(), indent=2)) + + # ── Quantize stage ─────────────────────────────────────────── + current_path = _run_quantize_stage( + config=config, + current_path=current_path, + quantized_path=quantized_path, + stage_timings=stage_timings, + ) + + # ── Compile stage ──────────────────────────────────────────── + current_path = _run_compile_stage( + config=config, + current_path=current_path, + compiled_path=compiled_path, + stage_timings=stage_timings, + ) + + # ── Finalize ───────────────────────────────────────────────── + if current_path != final_path: + copy_onnx_model(current_path, final_path) + + return stage_timings + + +def _build_onnx_pipeline( + *, + config: WinMLBuildConfig, + onnx_path: Path, + 
output_dir: Path, + rebuild: bool, + ep: str | None, + device: str | None, + extra_kwargs: dict[str, Any], +) -> list[tuple[str, float | None]] | None: + """ONNX build pipeline with cascading StageLive per stage. + + Returns list of (stage_name, elapsed_seconds | None) for summary, + or None if build was reused. + """ + from ..onnx import copy_onnx_model + + max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3) + + # ── Validate + setup ───────────────────────────────────────── + if not onnx_path.exists(): + raise FileNotFoundError(f"ONNX file not found: {onnx_path}") + try: + config.validate() + except ValueError as e: + raise ValueError(f"Config validation failed: {e}") from e + + output_dir.mkdir(parents=True, exist_ok=True) + + stem = onnx_path.stem + optimized_path = output_dir / f"{stem}_optimized.onnx" + quantized_path = output_dir / f"{stem}_quantized.onnx" + compiled_path = output_dir / f"{stem}_compiled.onnx" + final_path = output_dir / "model.onnx" + config_path = output_dir / "winml_build_config.json" + + # Reuse check + if final_path.exists() and not rebuild: + _print_reused(final_path) + return None + + stage_timings: list[tuple[str, float | None]] = [] + + if rebuild: + for old in output_dir.glob("*.onnx"): + old.unlink() + + # Copy input ONNX to output dir + current_path = output_dir / onnx_path.name + if current_path.resolve() != onnx_path.resolve(): + copy_onnx_model(onnx_path, current_path) + + # ── Optimize stage (first stage for ONNX — show I/O here) ──── + current_path, _ = _run_optimize_stage( + config=config, + model_path=current_path, + optimized_path=optimized_path, + ep=ep, + device=device, + max_iters=max_iters, + stage_timings=stage_timings, + show_io_first=True, + ) + + config_path.write_text(json.dumps(config.to_dict(), indent=2)) + + # ── Quantize stage ─────────────────────────────────────────── + current_path = _run_quantize_stage( + config=config, + current_path=current_path, + quantized_path=quantized_path, + 
stage_timings=stage_timings, + ) + + # ── Compile stage ──────────────────────────────────────────── + current_path = _run_compile_stage( + config=config, + current_path=current_path, + compiled_path=compiled_path, + stage_timings=stage_timings, + ) + + # ── Finalize ───────────────────────────────────────────────── + if current_path != final_path: + copy_onnx_model(current_path, final_path) + + return stage_timings diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index f37ea22a4..f97aab690 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Config generation command for ModelKit CLI. +"""Config generation command (v2, Rich UI) for ModelKit CLI. Generates WinMLBuildConfig for a HuggingFace model or a pre-exported ONNX file by auto-detecting task, model class, and I/O specifications. @@ -28,11 +28,20 @@ from typing import Any import click -from rich.console import Console + +from ..utils.console import ( + get_console, + print_command_header, + print_error, + print_io_specs_detail, + print_io_specs_na, + print_kv, + print_success, +) logger = logging.getLogger(__name__) -console = Console(stderr=True) +console = get_console() def _apply_stage_overrides(cfg: Any, *, no_quant: bool, no_compile: bool) -> None: @@ -49,7 +58,7 @@ def _is_onnx_file(model_input: str) -> bool: return path.suffix == ".onnx" and path.exists() -@click.command() +@click.command("config") @click.option( "-m", "--model", @@ -97,7 +106,7 @@ def _is_onnx_file(model_input: str) -> bool: type=click.Path(exists=True), default=None, help="JSON file with shape overrides passed to dummy input generation. 
" - "Valid keys -- text: sequence_length; " + "Valid keys — text: sequence_length; " "vision: height, width, num_channels; " "audio: feature_size, nb_max_frames, audio_sequence_length.", ) @@ -230,6 +239,14 @@ def config( # Validate: at least one of -m, --model-type, or --model-class is required if hf_model is None and model_type is None and model_class is None: + # Show header even for errors + print_command_header(console, "\U0001f4cb CONFIG GENERATION") + print_error( + console, + "Missing required input", + hint="Provide one of: -m/--model, --model-type, or --model-class", + ) + console.print() raise click.UsageError( "At least one of -m/--model, --model-type, or --model-class is required." ) @@ -243,6 +260,8 @@ def config( # Load override config from JSON file if provided override = None + _override_file: str | None = None + _shape_config_file: str | None = None if config_file: config_path = Path(config_file) try: @@ -258,7 +277,7 @@ def config( override = WinMLBuildConfig.from_dict(data) except json.JSONDecodeError as e: raise click.UsageError(f"Invalid JSON in config file {config_path}: {e}") from e - console.print(f"[dim]Loaded overrides from {config_path.name}[/dim]") + _override_file = config_path.name # Load shape_config (shape overrides) from JSON file if provided shape_config = None @@ -278,12 +297,15 @@ def config( raise click.UsageError( f"Invalid JSON in I/O config file {shape_config_path}: {e}" ) from e - console.print(f"[dim]Loaded I/O config from {shape_config_path.name}[/dim]") + _shape_config_file = shape_config_path.name # ONNX file detection: generate simpler config without loader/export + if hf_model and _is_onnx_file(hf_model) and module: + raise click.UsageError( + "--module is not supported with ONNX file input. " + "Module discovery requires a HuggingFace model." 
+ ) if hf_model and _is_onnx_file(hf_model): - console.print(f"[dim]Generating ONNX build config for {hf_model}...[/dim]") - config_obj = generate_onnx_build_config( hf_model, task=task, @@ -296,11 +318,15 @@ def config( # Apply --no-quant / --no-compile overrides _apply_stage_overrides(config_obj, no_quant=no_quant, no_compile=no_compile) - console.print("[green]Generated ONNX build config (export=None)[/green]") output_data = config_obj.to_dict() + _is_onnx_mode = True + _resolved_task = None + _resolved_model_class = None + _export_cfg = None + configs: list = [] # defensive — ONNX + module is rejected above + _n_modules = 0 else: - label = hf_model or model_type - console.print(f"[dim]Generating config for {label}...[/dim]") + _is_onnx_mode = False # Generate config(s) - returns single or list based on module parameter result = generate_hf_build_config( @@ -322,39 +348,136 @@ def config( if module: # Module mode: result is list[WinMLBuildConfig] configs = result - # Apply --no-quant / --no-compile overrides to each config for cfg in configs: _apply_stage_overrides(cfg, no_quant=no_quant, no_compile=no_compile) - console.print(f"[green]Found {len(configs)} submodules matching '{module}'[/green]") output_data = [cfg.to_dict() for cfg in configs] + _n_modules = len(configs) + # Use first config for display metadata + config_obj = configs[0] if configs else None else: # Normal mode: result is WinMLBuildConfig config_obj = result - # Apply --no-quant / --no-compile overrides + configs = [] _apply_stage_overrides(config_obj, no_quant=no_quant, no_compile=no_compile) - # B-4: Inform user of auto-selected task when --task not provided - if not task and not module: - auto_task = config_obj.loader.task - source = model_type or hf_model - console.print(f"[dim]Auto-selected task: {auto_task} (from '{source}')[/dim]") + output_data = config_obj.to_dict() + _n_modules = 0 + + _resolved_task = config_obj.loader.task if config_obj else None + _resolved_model_class = 
config_obj.loader.model_class if config_obj else None + _export_cfg = config_obj.export if config_obj else None + + # ── Rich console output ────────────────────────────────────── + subtitle = "ONNX mode" if _is_onnx_mode else ("module mode" if module else None) + print_command_header(console, "\U0001f4cb CONFIG GENERATION", subtitle) + + # Model identity + model_label = hf_model or model_type or model_class or "?" + print_kv(console, "Model:", model_label, icon="\U0001f4e6") + + if _is_onnx_mode: + print_kv(console, "Mode:", "Direct ONNX", note="export=None", icon="\U0001f527") + else: + # Fix #1: Model class before Task + if module: + print_kv(console, "Module:", module, icon="\U0001f9e9") + elif _resolved_model_class: + mc_note = None if model_class else "auto-detected" + print_kv( + console, + "Model class:", + _resolved_model_class, + note=mc_note, + icon="\U0001f9e9", + ) + # Fix #2: no trailing space after 🏷️ + if _resolved_task: + task_note = None if task else "auto-detected" + print_kv( + console, + "Task:", + _resolved_task, + note=task_note, + icon="\U0001f3f7\ufe0f", + ) + + # Override files + if config_file: + console.print( + f" \U0001f4c1 [bold]Overrides:[/bold] {_override_file} [green]\u2713[/green]" + ) + if shape_config_file: + console.print( + f" \U0001f4c1 [bold]Shape config:[/bold] " + f"{_shape_config_file} [green]\u2713[/green]" + ) + + console.print() + + # I/O specs (always full detail) + if _is_onnx_mode: + print_io_specs_na(console) + elif _export_cfg is not None: + print_io_specs_detail(console, _export_cfg) + + console.print() + + # Resolution — read directly from the config object. + # No inference or reverse mapping — display what the config contains. 
+ _ref_config = config_obj if not module else (configs[0] if configs else None) + if _ref_config is not None: + _quant = _ref_config.quant + + console.print(" \u2699\ufe0f [bold]Resolution:[/bold]") + + # Fix #4: Device from resolve_device (existing API) + from ..sysinfo import resolve_device as _rd + + _resolved_dev, _ = _rd() + console.print(f" Device: [cyan]{_resolved_dev.upper()}[/cyan]") + + # EP — only shown when user explicitly passed --ep + if ep: + from ..utils.constants import normalize_ep_name + + _ep_full = normalize_ep_name(ep) or ep + console.print(f" EP: [cyan]{_ep_full}[/cyan]") + + # Quant types — display exactly what config contains + if _quant: console.print( - f"[green]Generated config for task '{config_obj.loader.task}'[/green]" + f" Quant: " + f"[cyan]{_quant.weight_type}/{_quant.activation_type}" + f"[/cyan] [dim](weight/activation)[/dim]" ) - output_data = config_obj.to_dict() + else: + console.print(" Quant: [dim]none[/dim]") + + # Module mode: show submodule list + if module and not _is_onnx_mode and _n_modules > 0: + console.print() + console.print( + f" \U0001f9e9 [bold]Submodules:[/bold] " + f"[green]{_n_modules}[/green] matching '{module}'" + ) - # Serialize to JSON + console.print() + + # ── Serialize and output ───────────────────────────────────── config_json = json.dumps(output_data, indent=2) - # Output to file or stdout if output: output_path = Path(output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(config_json) - console.print(f"[green]Config saved to:[/green] {output}") + suffix = f" [dim]({_n_modules} submodules)[/dim]" if _n_modules else "" + print_success(console, f"Config saved to: [bold]{output}[/bold]{suffix}") else: + print_success(console, "Config written to stdout") # Print to stdout (not stderr where console prints) print(config_json) + console.print() + except click.UsageError: raise # Let click handle its own errors except ValueError as e: diff --git 
a/src/winml/modelkit/commands/export.py b/src/winml/modelkit/commands/export.py index 4e8f0b3f3..707c32208 100644 --- a/src/winml/modelkit/commands/export.py +++ b/src/winml/modelkit/commands/export.py @@ -36,6 +36,28 @@ console = Console() +def _delete_onnx_with_external_data(onnx_path: Path) -> None: + """Delete an ONNX file and its external data files.""" + import onnx + from onnx.external_data_helper import ExternalDataInfo + + try: + model = onnx.load(str(onnx_path), load_external_data=False) + ext_files: set[str] = set() + for tensor in model.graph.initializer: + if tensor.data_location == onnx.TensorProto.EXTERNAL: + ext_files.add(ExternalDataInfo(tensor).location) + for name in ext_files: + data_path = onnx_path.parent / name + if data_path.exists(): + data_path.unlink() + except Exception: + logger.debug("Could not parse external data from %s", onnx_path, exc_info=True) + + if onnx_path.exists(): + onnx_path.unlink() + + @click.command() @click.option( "--model", @@ -322,19 +344,27 @@ def export( else: console.print(f"[dim]Detected task: {detected_task}[/dim]") - # Export using export_onnx() - the single implementation path - result_path = export_onnx( + export_stats = export_onnx( model=pytorch_model, output_path=output_path, export_config=cfg, - model_id=model, # For metadata - task=detected_task, # Use detected task for proper OnnxConfig lookup + model_id=model, + task=detected_task, verbose=verbose, enable_reporting=with_report, ) + logger.debug("Export stats: %s", export_stats) + + # TODO: re-enable post-export optimization (shape inference, constant folding) + # Disabled: needs validation that optimize_onnx preserves HTP hierarchy tags. 
+ # from ..optim.api import optimize_onnx + # raw_path = output_path.with_stem(f"{output_path.stem}_raw") + # output_path.rename(raw_path) + # optimize_onnx(raw_path, output=output_path) + # _delete_onnx_with_external_data(raw_path) # Show results - console.print(f"\n[bold green]Success![/bold green] Model exported to: {result_path}") + console.print(f"\n[bold green]Success![/bold green] Model exported to: {output_path}") # Show report file locations if enabled if with_report: diff --git a/src/winml/modelkit/commands/inspect.py b/src/winml/modelkit/commands/inspect.py index b79f6a649..2737a0386 100644 --- a/src/winml/modelkit/commands/inspect.py +++ b/src/winml/modelkit/commands/inspect.py @@ -2,21 +2,26 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Inspect command for ModelKit CLI. +"""Inspect input model's ModelKit configuration. -Displays detailed information about a HuggingFace model's compatibility -with ModelKit, including loader, exporter, and WinML configurations. +Resolves loader, exporter, and WinML inference class for a given model, +showing what the build pipeline will use. 
Usage: - winml inspect -m openai/clip-vit-base-patch32 + winml inspect -m microsoft/resnet-50 + winml inspect --model-type bert --task fill-mask winml inspect -m google-bert/bert-base-uncased --format json - winml inspect -m facebook/detr-resnet-50 --verbose - winml inspect -m openai/clip-vit-base-patch32 --hierarchy + winml inspect --list-tasks """ from __future__ import annotations import logging +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from ..inspect.types import InspectResult import click from rich.console import Console @@ -26,12 +31,14 @@ console = Console() -@click.command() +@click.command("inspect") @click.option( "-m", "--model", - required=True, - help="HuggingFace model ID (e.g., openai/clip-vit-base-patch32)", + "model_id", + required=False, + default=None, + help="HuggingFace model ID (e.g., microsoft/resnet-50)", ) @click.option( "-f", @@ -61,53 +68,89 @@ default=False, help="Show HF module hierarchy (uses random weights, no weight download)", ) +@click.option( + "--list-tasks", + "list_tasks", + is_flag=True, + default=False, + help="List all known tasks and exit", +) +@click.option( + "--model-type", + "model_type", + default=None, + help="Override model type (e.g., bert, resnet) — can be used without --model", +) +@click.option( + "--model-class", + "model_class", + default=None, + help="Override model class (e.g., BertForMaskedLM) — can be used without --model", +) @click.pass_context def inspect( ctx: click.Context, - model: str, + model_id: str | None, output_format: str, verbose: bool, task: str | None, hierarchy: bool, + list_tasks: bool, + model_type: str | None, + model_class: str | None, ) -> None: - r"""Inspect a HuggingFace model's ModelKit configuration. + r"""Inspect input model's ModelKit configuration. - Shows the loader configuration, exporter configuration, and WinML - inference class that will be used for the specified model. 
+ Shows the loader, exporter, WinML inference class, I/O specs, + and build resolution that the pipeline will use for the given model. - This command helps you understand: - - Which HuggingFace model class will be used for loading - - What ONNX export configuration will be applied - - Which WinML inference class will handle the model - - Overall support status in ModelKit + Supports inspection without a model ID via --model-type or --model-class. \b Examples: # Basic inspection - winml inspect -m openai/clip-vit-base-patch32 + winml inspect -m microsoft/resnet-50 - # JSON output for scripting - winml inspect -m google-bert/bert-base-uncased --format json + # Inspect by model type only (no weight download) + winml inspect --model-type bert --task fill-mask - # Show full build configuration - winml inspect -m facebook/detr-resnet-50 --verbose + # Override model class + winml inspect -m custom-model --model-class BertForCTC - # Include HF module hierarchy (no weight download) - winml inspect -m openai/clip-vit-base-patch32 --hierarchy + # JSON output + winml inspect -m google-bert/bert-base-uncased --format json - # Combined verbose + hierarchy - winml inspect -m google-bert/bert-base-uncased -v -H + # List all known tasks + winml inspect --list-tasks """ - # Import here to defer heavy transformers/torch imports - from ..inspect import ( - InspectError, - ModelNotFoundError, - NetworkError, - inspect_model, - ) + # Handle --list-tasks (no model required) + if list_tasks: + from ..inspect.resolver import get_known_tasks + + for t in sorted(get_known_tasks()): + click.echo(t) + return + + # Validate: need at least one of model_id, model_type, model_class + if model_id is None and model_type is None and model_class is None: + raise click.UsageError( + "At least one of -m/--model, --model-type, or --model-class is required. " + "Use --list-tasks to see available tasks." 
+ ) + + # Handle ONNX file input + from pathlib import Path + + if model_id and model_id.endswith(".onnx") and Path(model_id).is_file(): + raise click.ClickException( + "ONNX file inspection is not yet supported. " + "Use 'winml config -m model.onnx' for ONNX build config." + ) + + from ..inspect import InspectError, ModelNotFoundError, NetworkError from ..inspect.formatter import output_json, output_table - # Inherit debug mode from parent + # Inherit debug mode from parent context if ctx.obj and ctx.obj.get("debug"): verbose = True @@ -116,7 +159,13 @@ def inspect( logging.getLogger("winml.modelkit").setLevel(logging.DEBUG) try: - result = inspect_model(model, include_hierarchy=hierarchy, task_override=task) + result = _inspect_model_v2( + model_id=model_id, + task_override=task, + model_type_override=model_type, + model_class_override=model_class, + include_hierarchy=hierarchy, + ) if output_format.lower() == "json": click.echo(output_json(result, verbose=verbose)) @@ -133,5 +182,245 @@ def inspect( raise click.ClickException(f"Inspection error: {e}") from e except (ValueError, RuntimeError, OSError) as e: - logger.exception("Failed to inspect model: %s", model) + logger.exception("Failed to inspect model") raise click.ClickException(f"Failed to inspect model: {e}") from e + + +def _inspect_model_v2( + model_id: str | None = None, + task_override: str | None = None, + model_type_override: str | None = None, + model_class_override: str | None = None, + include_hierarchy: bool = False, +) -> InspectResult: + """Inspect v2 core — calls shared loader/export modules directly. 
+ + Args: + model_id: HuggingFace model ID (optional when model_type_override set) + task_override: Task to use instead of auto-detected task + model_type_override: Model type override (e.g., "bert") + model_class_override: Model class override (e.g., "BertForMaskedLM") + include_hierarchy: Whether to extract module hierarchy + + Returns: + InspectResult dataclass + """ + import functools + + from transformers import AutoConfig + + from ..export.io import resolve_io_specs + from ..inspect import InspectError, ModelNotFoundError, NetworkError + from ..inspect.resolver import ( + build_tensor_infos_from_io_specs, + compile_support_status, + resolve_cache, + resolve_io_config, + resolve_processor, + resolve_winml, + ) + from ..inspect.types import ( + ExporterInfo, + InspectResult, + LoaderInfo, + SupportLevel, + TensorInfo, + ) + from ..loader.config import resolve_loader_config + from ..loader.task import HF_TASK_DEFAULTS + from ..models import ( + HF_MODEL_CLASS_MAPPING, + MODEL_BUILD_CONFIGS, + ) + + # ========================================================================= + # STEP 1: Preserve parent hf_config before resolve_loader_config narrows it + # for multimodal models (e.g., CLIPConfig → CLIPTextConfig) + # ========================================================================= + parent_hf_config = None + if model_id and not model_type_override: + try: + parent_hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=False) + except Exception: + pass # resolve_loader_config will handle the error properly + + # ========================================================================= + # STEP 2: Shared loader resolution (same call as config command) + # ========================================================================= + try: + loader_config, hf_config, _resolved_class = resolve_loader_config( + model_id, + task=task_override, + model_type=model_type_override, + model_class=model_class_override, + ) + except ValueError as e: + err_str = 
str(e).lower() + if "not found" in err_str or "404" in err_str: + raise ModelNotFoundError(str(e)) from e + raise InspectError(str(e)) from e + except OSError as e: + raise NetworkError(str(e)) from e + + if parent_hf_config is None: + parent_hf_config = hf_config + + model_type = loader_config.model_type + task = loader_config.task + architectures = getattr(parent_hf_config, "architectures", []) or [] + + # ========================================================================= + # STEP 3: Derive task_source by checking registries post-hoc + # ========================================================================= + mt = model_type.lower().replace("_", "-") + task_source = "TasksManager" + for m, t in HF_MODEL_CLASS_MAPPING: + if m == mt and t == task: + task_source = "HF_MODEL_CLASS_MAPPING" + break + + # ========================================================================= + # STEP 4: Derive loader display info + # ========================================================================= + if (mt, task) in HF_MODEL_CLASS_MAPPING: + loader_source = "MODEL_CLASS_MAPPING" + loader_level = SupportLevel.SUPPORTED + elif task in HF_TASK_DEFAULTS: + loader_source = "HF_TASK_DEFAULTS" + loader_level = SupportLevel.DEFAULT + else: + loader_source = "TasksManager" + loader_level = SupportLevel.DEFAULT + + loader_info = LoaderInfo( + hf_model_class=loader_config.model_class or "Auto (TasksManager)", + hf_model_class_source=loader_source, + support_level=loader_level, + ) + + # ========================================================================= + # STEP 5: I/O tensor specs — registry first, then resolve_io_specs + # ========================================================================= + input_tensors: list[TensorInfo] = [] + output_tensors: list[TensorInfo] = [] + onnx_config_class = None + onnx_config_source = "none" + exporter_level = SupportLevel.UNSUPPORTED + opset_version = 17 + + # Path 1: Check MODEL_BUILD_CONFIGS registry for predefined config + 
registered = MODEL_BUILD_CONFIGS.get(mt) + if registered and registered.export and registered.export.input_tensors is not None: + export_cfg = registered.export + input_tensors = [ + TensorInfo(name=s.name or "unknown", dtype=s.dtype, shape=s.shape) + for s in export_cfg.input_tensors + ] + output_tensors = [ + TensorInfo(name=s.name or "unknown") for s in (export_cfg.output_tensors or []) + ] + onnx_config_class = f"{mt.upper()}IOConfig" + onnx_config_source = "MODEL_BUILD_CONFIGS" + exporter_level = SupportLevel.SUPPORTED + opset_version = export_cfg.opset_version + else: + # Path 2: resolve_io_specs (shared with config command) + try: + import optimum.exporters.onnx.model_configs # noqa: F401 + from optimum.exporters.tasks import TasksManager + + onnx_config_cls = TasksManager.get_exporter_config_constructor( + exporter="onnx", + model_type=model_type, + task=task, + library_name="transformers", + ) + if onnx_config_cls: + config_name = ( + onnx_config_cls.func.__name__ + if isinstance(onnx_config_cls, functools.partial) + else onnx_config_cls.__name__ + ) + onnx_config_class = config_name + onnx_config_source = "TasksManager" + exporter_level = SupportLevel.DEFAULT + + if hf_config is not None: + try: + io_specs = resolve_io_specs( + model_type=model_type, + task=task, + hf_config=hf_config, + model_id=model_id, + ) + input_tensors, output_tensors = build_tensor_infos_from_io_specs(io_specs) + except Exception as e: + logger.debug("resolve_io_specs failed for %s/%s: %s", model_type, task, e) + except Exception as e: + logger.debug("TasksManager lookup failed for %s/%s: %s", model_type, task, e) + + exporter_info = ExporterInfo( + onnx_config_class=onnx_config_class, + onnx_config_source=onnx_config_source, + support_level=exporter_level, + input_tensors=input_tensors, + output_tensors=output_tensors, + opset_version=opset_version, + ) + + # ========================================================================= + # STEP 6: WinML class (inspect-only lookup) + 
# ========================================================================= + winml_info = resolve_winml(model_type, task) + + # ========================================================================= + # STEP 7: Module hierarchy (optional, requires model_id) + # ========================================================================= + hierarchy_info = None + if include_hierarchy and model_id: + try: + from ..inspect.hierarchy import extract_hierarchy + + hierarchy_info = extract_hierarchy(model_id) + except Exception as e: + logger.debug("Hierarchy extraction failed for %s: %s", model_id, e) + + # ========================================================================= + # STEP 8: Overall support status + # ========================================================================= + overall_support, support_notes = compile_support_status(loader_info, exporter_info, winml_info) + + # ========================================================================= + # STEP 9: Build config (registry lookup only, no generation) + # ========================================================================= + build_config = registered.to_dict() if registered else None + + # ========================================================================= + # STEP 10: Inspect-only enrichment (conditional on model_id) + # ========================================================================= + cache_info = resolve_cache(model_id) if model_id else None + processor_info = resolve_processor(model_id, model_type=model_type) if model_id else None + io_config_info = resolve_io_config( + parent_hf_config, + model_id=model_id, + model_type=model_type, + task=task, + ) + + return InspectResult( + model_id=model_id or model_type or model_class_override or "unknown", + model_type=model_type, + architectures=architectures, + task=task, + task_source=task_source, + loader=loader_info, + exporter=exporter_info, + winml=winml_info, + overall_support=overall_support, + 
support_notes=support_notes, + build_config=build_config, + hierarchy=hierarchy_info, + cache=cache_info, + processor=processor_info, + io_config=io_config_info, + ) diff --git a/src/winml/modelkit/commands/optimize.py b/src/winml/modelkit/commands/optimize.py index 7e410760a..77bbbbaeb 100644 --- a/src/winml/modelkit/commands/optimize.py +++ b/src/winml/modelkit/commands/optimize.py @@ -30,7 +30,7 @@ import click from rich.console import Console -from ..onnx import is_compiled_onnx, load_onnx, save_onnx +from ..onnx import load_onnx, save_onnx if TYPE_CHECKING: @@ -382,12 +382,6 @@ def optimize( if model is None: raise click.UsageError("Missing option '--model' / '-m'.") - if is_compiled_onnx(model): - raise click.ClickException( - f"{model} is a compiled EPContext model and cannot be optimized. " - "Run 'winml optimize' on the original ONNX model before compilation." - ) - # Inherit debug mode from parent if ctx.obj and ctx.obj.get("debug"): verbose = True @@ -425,6 +419,8 @@ def optimize( # 3. Apply config file if specified (overrides preset/defaults) if config: file_config = load_config(config) + # Normalize snake_case keys to kebab-case (accept both formats) + file_config = {k.replace("_", "-"): v for k, v in file_config.items()} final_config.update(file_config) console.print(f"[dim]Loaded config from: {config}[/dim]") diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index ecfcbf8d3..d5266e523 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -25,10 +25,9 @@ import click import numpy as np from rich.console import Console -from rich.panel import Panel from rich.table import Table -from .live_chart import LiveMonitorDisplay +from ._live_chart import LiveMonitorDisplay if TYPE_CHECKING: @@ -178,8 +177,8 @@ def generate_random_inputs( ) -> dict[str, np.ndarray]: """Generate random inputs based on model io_config. 
- Uses modelkit.core.model_input_generator for spec-driven generation, - then converts torch tensors to numpy for ONNX Runtime. + Uses modelkit.core.model_input_generator for spec-driven generation. + Returns numpy arrays directly (no torch dependency). Args: io_config: Model I/O configuration from WinMLSession.io_config. @@ -217,8 +216,7 @@ def generate_random_inputs( "shape": list(resolved_shape), } - torch_inputs = generate_dummy_inputs_from_specs(specs) - return {name: tensor.numpy() for name, tensor in torch_inputs.items()} + return generate_dummy_inputs_from_specs(specs) def _resolve_shape( @@ -292,6 +290,16 @@ def run(self) -> BenchmarkResult: logger.info("Generating benchmark inputs") self._generate_inputs() + # Compile session early so model.device is resolved for display + self._model._session.compile() + + # Print model info before benchmark starts + _print_model_info( + self._model.io_config, + task=self._model.task or self.config.task, + device=self._model.device, + ) + # [3] Run benchmark logger.info( "Running benchmark: %d iterations + %d warmup", @@ -368,12 +376,7 @@ def _run_benchmark_simple(self) -> PerfStats: total_iterations = self.config.warmup + self.config.iterations with session.perf(warmup=self.config.warmup) as stats: - for i in range(total_iterations): - session.run(self._inputs) - - # Progress logging (every 10%) - if (i + 1) % max(1, total_iterations // 10) == 0: - logger.debug("Progress: %d/%d", i + 1, total_iterations) + _run_simple_loop(session, self._inputs, total_iterations) return stats @@ -416,36 +419,16 @@ def _run_benchmark_monitored(self) -> PerfStats: hw_monitor as hw, ep_monitor as ep_mon, ): - display = LiveMonitorDisplay( + _run_monitored_loop( + session, + self._inputs, + stats, + hw, total_iterations=total_iterations, warmup=self.config.warmup, model_id=self.config.model_id, device=self.config.device, ) - with display: - for i in range(total_iterations): - session.run(self._inputs) - - latest_latency = 
stats.all_samples_ms[-1] if stats.all_samples_ms else 0 - display.update( - iteration=i + 1, - latency_ms=latest_latency, - util_samples=hw.utilization_samples, - memory_local_mb=hw.peak_memory_local_mb, - memory_shared_mb=hw.peak_memory_shared_mb, - cpu_pct=hw.mean_cpu_pct, - ram_mb=hw.ram_used_mb, - cpu_samples=hw.cpu_samples, - ) - - # Print final monitor snapshot - display.print_final_snapshot( - util_samples=hw.utilization_samples, - memory_mb=hw.peak_memory_mb, - latency_ms=stats.mean_ms, - hw_dict=hw.to_dict(), - cpu_samples=hw.cpu_samples, - ) # Store hardware metrics self._hw_metrics = hw.to_dict() @@ -494,9 +477,9 @@ def _collect_results(self, stats: PerfStats) -> BenchmarkResult: # Throughput samples_per_sec=samples_per_sec, batches_per_sec=batches_per_sec, - # Actual values - actual_device=self._model._session.device, - actual_task=self.config.task or "auto-detected", + # Actual values (resolved after build + compile) + actual_device=self._model.device, + actual_task=self._model.task or self.config.task or "auto-detected", # Hardware monitor metrics (only present when --monitor is used) hw_monitor=getattr(self, "_hw_metrics", None), ) @@ -719,24 +702,26 @@ def _perf_modules( def display_console_report(result: BenchmarkResult, console: Console) -> None: """Display benchmark results in formatted console output.""" - # Header + # Info section — show "requested (resolved)" when they differ console.print() - console.print( - Panel.fit( - f"[bold]Benchmark: {result.config.model_id}[/bold]", - border_style="blue", - ) - ) - # Info section - console.print() - console.print(f"[dim]Device:[/dim] {result.actual_device}") - console.print(f"[dim]Precision:[/dim] {result.config.precision}") - console.print(f"[dim]Task:[/dim] {result.actual_task}") - console.print( - f"[dim]Iterations:[/dim] {result.config.iterations} (+ {result.config.warmup} warmup)" - ) - console.print(f"[dim]Batch Size:[/dim] {result.config.batch_size}") + req_device = result.config.device + 
act_device = result.actual_device + device_str = f"{req_device} ({act_device})" if req_device != act_device else act_device + console.print(f"[dim]Device:[/dim] {device_str}") + + # TODO: show resolved precision once WinMLPreTrainedModel.precision + # is implemented (derive from _build_config.quant.weight_type) + + act_task = result.actual_task + if act_task.startswith("n/a"): + task_str = act_task + else: + req_task = result.config.task or "auto" + task_str = f"{req_task} ({act_task})" if req_task != act_task else act_task + console.print(f"[dim]Task:[/dim] {task_str}") + + # I/O tensor info is printed before the benchmark via _print_model_info() # Latency table console.print() @@ -814,6 +799,203 @@ def generate_output_path(model_id: str) -> Path: return Path(f"{slug}_perf.json") +# ============================================================================= +# Shared benchmark helpers +# ============================================================================= + + +def _print_model_info( + io_config: dict, + *, + task: str | None = None, + device: str = "auto", +) -> None: + """Print model I/O metadata before the benchmark starts.""" + console = Console(stderr=True) + console.print() + console.print(f"[dim]Device:[/dim] {device}") + # TODO: show resolved precision once WinMLPreTrainedModel.precision + # is implemented (derive from _build_config.quant.weight_type) + if task: + console.print(f"[dim]Task:[/dim] {task}") + + names = io_config.get("input_names", []) + shapes = io_config.get("input_shapes", []) + types = io_config.get("input_types", []) + if names: + label = "[dim]Inputs:[/dim] " + pad = " " + for i, name in enumerate(names): + shape = shapes[i] if i < len(shapes) else [] + dtype = str(types[i]) if i < len(types) else "" + shape_str = f"{shape!s}" + line = f"{name:<20s} {shape_str:<22s} {dtype}" + console.print(f"{label if i == 0 else pad}{line}") + + out_names = io_config.get("output_names", []) + out_shapes = io_config.get("output_shapes", []) + 
if out_names: + label = "[dim]Outputs:[/dim] " + pad = " " + for i, name in enumerate(out_names): + shape = out_shapes[i] if i < len(out_shapes) else [] + console.print(f"{label if i == 0 else pad}{name:<20s} {shape!s}") + + console.print() + + +def _run_monitored_loop( + session: Any, + inputs: dict[str, Any], + stats: PerfStats, + hw: Any, + *, + total_iterations: int, + warmup: int, + model_id: str, + device: str, +) -> None: + """Run the benchmark iteration loop with live hardware monitoring. + + Shared by both HF-path (PerfBenchmark) and ONNX-path (_run_onnx_benchmark). + """ + display = LiveMonitorDisplay( + total_iterations=total_iterations, + warmup=warmup, + model_id=model_id, + device=device, + ) + with display: + for i in range(total_iterations): + session.run(inputs) + + latest_latency = stats.all_samples_ms[-1] if stats.all_samples_ms else 0 + display.update( + iteration=i + 1, + latency_ms=latest_latency, + util_samples=hw.utilization_samples, + memory_local_mb=hw.peak_memory_local_mb, + memory_shared_mb=hw.peak_memory_shared_mb, + cpu_pct=hw.mean_cpu_pct, + ram_mb=hw.ram_used_mb, + cpu_samples=hw.cpu_samples, + ) + + +def _run_simple_loop( + session: Any, + inputs: dict[str, Any], + total_iterations: int, +) -> None: + """Run the benchmark iteration loop with periodic debug logging. + + Shared by both HF-path (PerfBenchmark) and ONNX-path (_run_onnx_benchmark). 
+ """ + for i in range(total_iterations): + session.run(inputs) + + if (i + 1) % max(1, total_iterations // 10) == 0: + logger.debug("Progress: %d/%d", i + 1, total_iterations) + + +# ============================================================================= +# ONNX Direct Benchmark +# ============================================================================= + + +def _run_onnx_benchmark( + onnx_path: Path, + *, + device: str, + iterations: int, + warmup: int, + batch_size: int, + config: BenchmarkConfig, +) -> BenchmarkResult: + """Benchmark an ONNX file directly via WinMLSession (no HF build). + + Creates a WinMLSession, reads io_config for input shapes, + generates random inputs, and runs the standard benchmark loop. + """ + from ..session import WinMLSession + + session = WinMLSession(onnx_path=onnx_path, device=device) + + # Generate random inputs from session's I/O config + io_cfg = session.io_config + inputs = generate_random_inputs(io_config=io_cfg, batch_size=batch_size) + + # Compile session early so session.device is resolved for display + session.compile() + + # Print model info before benchmark starts + _print_model_info(io_cfg, device=session.device) + + # Run benchmark + total_iterations = warmup + iterations + hw_metrics = None + hw_ctx = None + + # Determine if hardware monitoring is available + if config.monitor: + from ..session.monitor.hw_monitor import HWMonitor + + if HWMonitor.is_available(): + hw_ctx = HWMonitor(poll_interval_ms=_HW_POLL_INTERVAL_MS) + else: + Console(stderr=True).print( + "[yellow]Warning:[/yellow] HWMonitor unavailable. " + "Running ONNX benchmark without monitoring." 
+ ) + + if hw_ctx: + with session.perf(warmup=warmup) as stats, hw_ctx as hw: + _run_monitored_loop( + session, + inputs, + stats, + hw, + total_iterations=total_iterations, + warmup=warmup, + model_id=str(onnx_path.name), + device=device, + ) + hw_metrics = hw.to_dict() + else: + with session.perf(warmup=warmup) as stats: + _run_simple_loop(session, inputs, total_iterations) + + # Collect results + mean_latency_sec = stats.mean_ms / 1000.0 + samples_per_sec = batch_size / mean_latency_sec if mean_latency_sec > 0 else 0 + batches_per_sec = 1.0 / mean_latency_sec if mean_latency_sec > 0 else 0 + samples = stats.samples_ms + std_ms = float(np.std(samples)) if samples else 0.0 + + return BenchmarkResult( + config=config, + input_names=io_cfg["input_names"], + input_shapes=[list(s) if s else [] for s in io_cfg["input_shapes"]], + input_types=[str(t) for t in io_cfg["input_types"]], + output_names=io_cfg["output_names"], + output_shapes=[list(s) if s else [] for s in io_cfg["output_shapes"]], + mean_ms=stats.mean_ms, + min_ms=stats.min_ms, + max_ms=stats.max_ms, + p50_ms=stats.p50_ms, + p90_ms=stats.p90_ms, + p95_ms=stats.p95_ms, + p99_ms=stats.p99_ms, + std_ms=std_ms, + raw_samples_ms=stats.samples_ms, + samples_per_sec=samples_per_sec, + batches_per_sec=batches_per_sec, + actual_device=session.device, + actual_task="n/a (direct ONNX)", + hw_monitor=hw_metrics, + ) + + # ============================================================================= # CLI Command # ============================================================================= @@ -981,7 +1163,7 @@ def perf( from the model's I/O configuration. Accepts both HuggingFace model IDs and local .onnx files. - Both paths go through PerfBenchmark with WinMLAutoModel. + HF models go through PerfBenchmark; .onnx files use _run_onnx_benchmark. 
\b Examples: @@ -1103,10 +1285,11 @@ def perf( ) try: - # Both ONNX and HF go through PerfBenchmark (unified pipeline) model_path = Path(hf_model) is_onnx = model_path.suffix.lower() == ".onnx" + if is_onnx: + # ONNX direct path -- skip HF build, benchmark via WinMLSession if shape_config: console.print( "[yellow]Warning:[/yellow] --shape-config is ignored for " @@ -1115,14 +1298,28 @@ def perf( config.shape_config = None if not model_path.exists(): raise FileNotFoundError(f"ONNX file not found: {model_path}") - console.print(f"[dim]Building + benchmarking ONNX:[/dim] {model_path}") + console.print(f"[dim]Benchmarking ONNX:[/dim] {model_path}") + + from ..sysinfo import resolve_device + + resolved_device, _ = resolve_device(device=config.device) + + result = _run_onnx_benchmark( + model_path, + device=resolved_device, + iterations=iterations, + warmup=warmup, + batch_size=batch_size, + config=config, + ) else: + # HF model path -- full build + benchmark via PerfBenchmark if precision != "auto": console.print(f"[dim]Precision: {precision} (applied during model build)[/dim]") console.print(f"[dim]Loading model:[/dim] {hf_model}") - benchmark = PerfBenchmark(config) - result = benchmark.run() + benchmark = PerfBenchmark(config) + result = benchmark.run() # Display console report display_console_report(result, console) diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index 5dae75b4b..daddb9282 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -25,7 +25,6 @@ import click from rich.console import Console -from ..onnx import is_compiled_onnx from ..utils.logging import configure_logging @@ -92,6 +91,12 @@ default=False, help="Use symmetric quantization", ) +@click.option( + "--task", + type=str, + default=None, + help="Task for calibration dataset selection (e.g., 'image-classification').", +) @click.option( "--verbose", "-v", @@ -111,6 +116,7 @@ def quantize( 
activation_type: str | None, per_channel: bool, symmetric: bool, + task: str | None, verbose: bool, ) -> None: r"""Quantize ONNX model by inserting QDQ nodes. @@ -142,12 +148,6 @@ def quantize( configure_logging(verbose=verbose) - if is_compiled_onnx(model): - raise click.ClickException( - f"{model} is a compiled EPContext model and cannot be quantized. " - "Run 'winml quantize' on the original ONNX model before compilation." - ) - # Import quantizer (late import to speed up CLI) from ..quant import WinMLQuantizationConfig, quantize_onnx @@ -178,8 +178,18 @@ def quantize( activation_type=resolved_activation, per_channel=per_channel, symmetric=symmetric, + task=task, ) + # Display dataset info from config + if config.dataset_name: + _dataset_display = config.dataset_name + elif config.task and config.task != "random": + _dataset_display = f"Default for task '{config.task}'" + else: + _dataset_display = "Random data (synthetic from ONNX I/O specs)" + console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}") + try: console.print("\n[bold]Running quantization...[/bold]") result = quantize_onnx(model, output_path=output, config=config) diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 1eca3bfc4..17da4ab0b 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -56,8 +56,7 @@ WinMLExportConfig, _resolve_export_config_from_specs, ) -from ..loader import resolve_loader_config -from ..loader.config import WinMLLoaderConfig +from ..loader.config import WinMLLoaderConfig, resolve_loader_config from ..optim.config import WinMLOptimizationConfig from ..quant.config import WinMLQuantizationConfig from ..utils.config_utils import merge_config @@ -464,11 +463,10 @@ def generate_hf_build_config( Orchestration Flow: 1. loader.resolve_loader_config() -> (WinMLLoaderConfig, hf_config, resolved_class) (includes sub-config consolidation for multimodal) - 2. 
MODEL_BUILD_CONFIGS.get() — registry lookup - 3. Try Optimum export config; on failure use empty placeholder - 4. Merge registered export on top (registry always wins) - 5. _assemble_config() + merge -> WinMLBuildConfig - 6. If module: specialize for each matching submodule + 2. MODEL_BUILD_CONFIGS.get() — registry lookup (may short-circuit step 3) + 3. export._resolve_export_config_from_specs() OR registered export config + 4. _assemble_config() + merge -> WinMLBuildConfig + 5. If module: specialize for each matching submodule Args: model_id: HuggingFace model ID (e.g., "bert-base-uncased") or local path. @@ -521,9 +519,26 @@ class name. Uses torchinfo to discover submodules and infer # ========================================================================= # STEP 3: Generate export config # ========================================================================= - # Try Optimum first; if model is unsupported, use empty placeholder. - # Then always merge registered export config on top (registry wins). - try: + # Priority: registered config with I/O specs > Optimum lookup. + # Models not in Optimum's TasksManager (e.g., BLIP) crash at + # _resolve_export_config_from_specs(). If the registry already has + # input_tensors, use them directly and skip the Optimum path. + # Note: None means "not configured" (fall through to Optimum); + # [] would mean "explicitly no inputs" (use as-is, skip Optimum). 
+ _registered_export = registered.export if registered else None + if _registered_export is not None and _registered_export.input_tensors is not None: + # deepcopy to avoid mutating the shared registry singleton + export_config = copy.deepcopy(_registered_export) + logger.info( + "Using registered export config for '%s' (skipping Optimum lookup)", + _registry_key, + ) + else: + # Standard path: resolve I/O specs from Optimum's OnnxConfig + logger.debug( + "No registered export config for '%s'; resolving via Optimum", + _registry_key, + ) export_config = _resolve_export_config_from_specs( model_type=loader_config.model_type, task=loader_config.task, @@ -533,31 +548,6 @@ class name. Uses torchinfo to discover submodules and infer batch_size=WinMLExportConfig().batch_size, **(shape_config or {}), ) - except ValueError as e: - # ONNXConfigNotFoundError is a ValueError subclass (from export.io) - # — catch broadly to avoid top-level import of export.io which - # triggers heavy optimum/transformers imports. - from ..export.io import ONNXConfigNotFoundError - - if not isinstance(e, ONNXConfigNotFoundError): - raise - logger.info( - "Optimum has no OnnxConfig for '%s'; using empty export config", - _registry_key, - ) - export_config = WinMLExportConfig() - - # Merge registered export on top — registered always wins. - # Use WinMLExportConfig.merge() to properly handle nested - # InputTensorSpec/OutputTensorSpec lists (merge_config converts - # dataclass lists to dicts which breaks __post_init__). - _registered_export = registered.export if registered else None - if _registered_export is not None: - export_config = _merge_export_config(export_config, _registered_export) - logger.info( - "Merged registered export config for '%s'", - _registry_key, - ) # ========================================================================= # STEP 4: Assemble config + merge override @@ -569,30 +559,59 @@ class name. 
Uses torchinfo to discover submodules and infer model_id=model_id, model_type=hf_config.model_type, ) - # STEP 3.5: Resolve quant + compile based on device/precision - # Only override assembled defaults when user explicitly targets a device/precision. - # When both are "auto", preserve _assemble_config() defaults (registry values). - if device != "auto" or precision != "auto" or ep is not None: - resolved_quant, resolved_compile = resolve_quant_compile_config( - device=device, - precision=precision, - ep=ep, - task=parent_config.loader.task, - ) - if resolved_quant is not None: - # Merge into assembled config to preserve task/model_name + if override: + parent_config = merge_config(parent_config, override) + + # ========================================================================= + # STEP 4.5: Apply device/precision policy (affects quant + compile only) + # ========================================================================= + from ..sysinfo import resolve_device + from .precision import resolve_precision + + # ALWAYS detect hardware — even when device="auto" — so we don't + # blindly default to QNN on machines without an NPU (#412). 
+ resolved_device, available_devices = resolve_device(device=device) + logger.info( + "Device resolved: %s (available: %s)", + resolved_device, + ", ".join(available_devices), + ) + + policy = resolve_precision( + device=resolved_device, + precision=precision, + ep=ep, + available_devices=available_devices, + task=parent_config.loader.task, + ) + + # Apply policy: set compile provider from detected hardware + if policy.device != "auto": + # Quant config (weight_type and activation_type are always both-None or both-set) + if policy.weight_type is not None: if parent_config.quant is None: - parent_config.quant = resolved_quant - else: - parent_config.quant.weight_type = resolved_quant.weight_type - parent_config.quant.activation_type = resolved_quant.activation_type + parent_config.quant = WinMLQuantizationConfig() + parent_config.quant.weight_type = policy.weight_type + parent_config.quant.activation_type = policy.activation_type else: parent_config.quant = None - parent_config.compile = resolved_compile - # User override has highest priority — applied last - if override: - parent_config = merge_config(parent_config, override) + # Compile config + parent_config.compile = WinMLCompileConfig.for_provider( + policy.compile_provider, + ) + else: + # Even in auto/auto mode, set compile provider from detected hardware + # instead of preserving the hardcoded EPConfig default (#412). + from .precision import get_provider_for_device + + hw_provider = get_provider_for_device(resolved_device) + if hw_provider is not None: + parent_config.compile = WinMLCompileConfig.for_provider( + hw_provider, + ) + # When hw_provider is None (CPU-only), keep the default compile config + # so the pipeline still has a valid compile section. # ========================================================================= # STEP 5: Specialize for submodules if requested @@ -873,8 +892,8 @@ def _assemble_config( Args: loader_config: Resolved WinMLLoaderConfig (from resolve_loader_config). 
- export_config: Resolved WinMLExportConfig (Optimum baseline - merged with registered export config). + export_config: Resolved WinMLExportConfig + (from registry or _resolve_export_config_from_specs). registered: Registered config from MODEL_BUILD_CONFIGS (or None). model_id: HuggingFace model ID (for quant model_name), or None. model_type: Parent HF model type (for quant fallback name). diff --git a/src/winml/modelkit/config/precision.py b/src/winml/modelkit/config/precision.py index 69706c5b4..8722f31ef 100644 --- a/src/winml/modelkit/config/precision.py +++ b/src/winml/modelkit/config/precision.py @@ -18,10 +18,12 @@ logger = logging.getLogger(__name__) # Tasks where GPU auto-precision may differ (LLM = w4a16 recommendation) -_LLM_TASKS = frozenset({ - "text-generation", - "text2text-generation", -}) +_LLM_TASKS = frozenset( + { + "text-generation", + "text2text-generation", + } +) # Default auto-precision mapping: device -> precision _AUTO_PRECISION: dict[str, str] = { @@ -66,6 +68,19 @@ "cpu": None, } + +def get_provider_for_device(device: str) -> str | None: + """Get the default compile provider for a resolved device. + + Args: + device: Resolved device name ("npu", "gpu", "cpu"). + + Returns: + Provider name (e.g., "qnn", "dml") or None for CPU. + """ + return _DEVICE_TO_PROVIDER.get(device) + + # EP -> device inference (when --ep is given without --device) _EP_TO_DEVICE: dict[str, str] = { "qnn": "npu", @@ -234,9 +249,7 @@ def resolve_precision( if ep is not None: ep = ep.lower() if ep not in VALID_EPS: - raise ValueError( - f"Unknown EP '{ep}'. Expected one of: {sorted(VALID_EPS)}" - ) + raise ValueError(f"Unknown EP '{ep}'. 
Expected one of: {sorted(VALID_EPS)}") # Infer device from EP when device is "auto" if device == "auto": device = _EP_TO_DEVICE[ep] @@ -263,7 +276,8 @@ def resolve_precision( # Device is "auto" but precision is explicit — pick best device # FIXME: improve device-precision compatibility lookup table later resolved_device = _pick_device_for_precision( - resolved_precision, available_devices or ["cpu"], + resolved_precision, + available_devices or ["cpu"], ) # Resolve "auto" precision for the resolved device diff --git a/src/winml/modelkit/core/__init__.py b/src/winml/modelkit/core/__init__.py index 9bc5c4dfd..8f4cc9d47 100644 --- a/src/winml/modelkit/core/__init__.py +++ b/src/winml/modelkit/core/__init__.py @@ -4,7 +4,7 @@ # -------------------------------------------------------------------------- """Core utilities for ModelKit.""" -# New API - pure torch, no external dependencies +from .model_input_generator import generate_dummy_inputs_from_specs from .node_metadata import ( NodeMetadata, add_metadata_to_node, @@ -15,20 +15,6 @@ query_nodes_by_origin, set_origin_for_graph, ) -from .onnx_utils import ( - get_epcontext_info, - get_io_config, -) - - -def __getattr__(name: str): - """Lazy-load generate_dummy_inputs_from_specs to avoid importing torch at startup.""" - if name == "generate_dummy_inputs_from_specs": - from .model_input_generator import generate_dummy_inputs_from_specs - - globals()["generate_dummy_inputs_from_specs"] = generate_dummy_inputs_from_specs - return generate_dummy_inputs_from_specs - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") __all__ = [ @@ -44,3 +30,26 @@ def __getattr__(name: str): "query_nodes_by_origin", "set_origin_for_graph", ] + + +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "get_epcontext_info": (".onnx_utils", "get_epcontext_info"), + "get_io_config": (".onnx_utils", "get_io_config"), +} + + +def __getattr__(name: str): + """Lazy-load onnx_utils exports on first attribute access.""" + if name in
_LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, attr_name) + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return list(set(list(globals()) + __all__)) diff --git a/src/winml/modelkit/core/model_input_generator.py b/src/winml/modelkit/core/model_input_generator.py index 63191958a..5c6f3fb76 100644 --- a/src/winml/modelkit/core/model_input_generator.py +++ b/src/winml/modelkit/core/model_input_generator.py @@ -3,10 +3,11 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- #!/usr/bin/env python3 -"""Manual Model Input Generator - Pure PyTorch. +"""Manual Model Input Generator - Pure NumPy. -This module provides manual input tensor generation from specifications. -No external dependencies on Optimum or transformers. +This module provides input array generation from specifications using only +NumPy (no torch dependency). Outputs are numpy arrays compatible with +ONNX Runtime session.run(). For Optimum-based automatic input generation, use modelkit.export.io: - resolve_io_specs(model_type, task, hf_config) @@ -21,7 +22,7 @@ ... } >>> inputs = generate_dummy_inputs_from_specs(specs) >>> inputs["input_ids"].shape - torch.Size([1, 128]) + (1, 128) """ from __future__ import annotations @@ -29,7 +30,7 @@ import logging from typing import Any -import torch +import numpy as np logger = logging.getLogger(__name__) @@ -37,11 +38,11 @@ def generate_dummy_inputs_from_specs( input_specs: dict[str, dict[str, Any]], -) -> dict[str, torch.Tensor]: +) -> dict[str, np.ndarray]: """Generate dummy inputs from manual specifications. - This function creates PyTorch tensors based on explicit specifications, - without requiring model loading or Optimum/transformers dependencies. 
+ This function creates NumPy arrays based on explicit specifications, + without requiring model loading or Optimum/transformers/torch dependencies. Args: input_specs: Input specifications with format: @@ -54,7 +55,7 @@ def generate_dummy_inputs_from_specs( } Returns: - Dictionary mapping input names to generated tensors + Dictionary mapping input names to generated numpy arrays Raises: ValueError: If required fields are missing or invalid @@ -70,7 +71,7 @@ def generate_dummy_inputs_from_specs( ... } >>> inputs = generate_dummy_inputs_from_specs(specs) >>> inputs["pixel_values"].shape - torch.Size([1, 3, 224, 224]) + (1, 3, 224, 224) """ inputs = {} @@ -84,9 +85,9 @@ def generate_dummy_inputs_from_specs( # Parse dtype dtype_str = spec["dtype"].lower() if dtype_str in ["int", "long", "int64"]: - dtype = torch.long + dtype = np.int64 elif dtype_str in ["float", "float32"]: - dtype = torch.float32 + dtype = np.float32 else: raise ValueError( f"Unsupported dtype '{spec['dtype']}' for '{name}'. 
Use 'int' or 'float'" @@ -103,16 +104,16 @@ def generate_dummy_inputs_from_specs( raise ValueError(f"Range must have exactly 2 values [min, max] for '{name}'") min_val, max_val = spec["range"] - if dtype == torch.long: - inputs[name] = torch.randint(min_val, max_val + 1, shape, dtype=dtype) + if dtype == np.int64: + inputs[name] = np.random.randint(min_val, max_val + 1, size=shape).astype(dtype) else: - inputs[name] = torch.rand(shape, dtype=dtype) * (max_val - min_val) + min_val + inputs[name] = np.random.rand(*shape).astype(dtype) * (max_val - min_val) + min_val else: # Default ranges - if dtype == torch.long: - inputs[name] = torch.randint(0, 2, shape, dtype=dtype) # Default: 0 or 1 + if dtype == np.int64: + inputs[name] = np.random.randint(0, 2, size=shape).astype(dtype) else: - inputs[name] = torch.rand(shape, dtype=dtype) # Default: [0, 1) + inputs[name] = np.random.rand(*shape).astype(dtype) logger.info( "Generated '%s': shape=%s, dtype=%s", name, list(inputs[name].shape), inputs[name].dtype diff --git a/src/winml/modelkit/core/onnx_utils.py b/src/winml/modelkit/core/onnx_utils.py index 04b479303..0aadbce2c 100644 --- a/src/winml/modelkit/core/onnx_utils.py +++ b/src/winml/modelkit/core/onnx_utils.py @@ -18,8 +18,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -import torch - if TYPE_CHECKING: import onnx @@ -356,7 +354,7 @@ def infer_output_names(outputs: Any) -> list[str] | None: for field_name in outputs.__dataclass_fields__: field_value = getattr(outputs, field_name, None) - if field_value is not None and isinstance(field_value, torch.Tensor): + if field_value is not None and type(field_value).__module__.startswith("torch"): output_names.append(field_name) # Only return names if we found simple tensor outputs diff --git a/src/winml/modelkit/export/__init__.py b/src/winml/modelkit/export/__init__.py index 2e4b2bd63..017f28911 100644 --- a/src/winml/modelkit/export/__init__.py +++ b/src/winml/modelkit/export/__init__.py @@ -19,32 +19,6 
@@ ) -def __getattr__(name: str): - """Lazy-load heavy submodules to avoid importing optimum at startup.""" - _io_names = { - "MaxLengthTextInputGenerator", - "ONNXConfigNotFoundError", - "generate_dummy_inputs", - "register_onnx_overwrite", - "resolve_io_specs", - } - if name in _io_names: - from . import io - - resolved = getattr(io, name) - globals()[name] = resolved - return resolved - - _pytorch_names = {"export_pytorch", "export_onnx"} - if name in _pytorch_names: - from .pytorch import export_pytorch - - globals().update(export_pytorch=export_pytorch, export_onnx=export_pytorch) - return globals()[name] - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - __version__ = "2.1.0" __all__ = [ @@ -60,3 +34,31 @@ def __getattr__(name: str): "resolve_export_config", "resolve_io_specs", ] + + +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "MaxLengthTextInputGenerator": (".io", "MaxLengthTextInputGenerator"), + "ONNXConfigNotFoundError": (".io", "ONNXConfigNotFoundError"), + "generate_dummy_inputs": (".io", "generate_dummy_inputs"), + "register_onnx_overwrite": (".io", "register_onnx_overwrite"), + "resolve_io_specs": (".io", "resolve_io_specs"), + "export_pytorch": (".pytorch", "export_pytorch"), + "export_onnx": (".pytorch", "export_pytorch"), # alias for export_pytorch +} + + +def __getattr__(name: str): + """Lazy-load heavy exports to avoid importing optimum at package init.""" + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, attr_name) + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return list(set(list(globals()) + __all__)) diff --git a/src/winml/modelkit/export/io.py b/src/winml/modelkit/export/io.py index 64c7952be..aa93dac04 100644 --- a/src/winml/modelkit/export/io.py +++ b/src/winml/modelkit/export/io.py @@ 
-62,6 +62,26 @@ class ONNXConfigNotFoundError(ValueError): register_onnx_overwrite = TasksManager.create_register("onnx", overwrite_existing=True) +_hf_models_registered = False + + +def ensure_hf_models_registered() -> None: + """Trigger HF model ONNX config registrations (idempotent). + + With lazy loading in ``modelkit/__init__.py``, the HF model files + (bert.py, clip.py, etc.) and their ``@register_onnx_overwrite`` + decorators are not executed until explicitly imported. This function + forces that import so registrations are in place before any + ``TasksManager.get_exporter_config_constructor()`` call. + """ + global _hf_models_registered + if _hf_models_registered: + return + from ..models import hf as _hf # noqa: F401 + + _hf_models_registered = True + + # ============================================================================= # Task Synonym Extensions (extends Optimum's TasksManager.map_from_synonym) # ============================================================================= @@ -190,6 +210,8 @@ def _get_onnx_config( Raises: ValueError: If no OnnxConfig is registered for the model_type/task combination """ + ensure_hf_models_registered() + normalized_task = _map_task_synonym(task) logger.debug( diff --git a/src/winml/modelkit/inspect/__init__.py b/src/winml/modelkit/inspect/__init__.py index bb481f951..8d5fc7876 100644 --- a/src/winml/modelkit/inspect/__init__.py +++ b/src/winml/modelkit/inspect/__init__.py @@ -23,9 +23,11 @@ from transformers import AutoConfig from .resolver import ( + build_tensor_infos_from_io_specs, compile_support_status, detect_task, get_build_config, + get_known_tasks, resolve_cache, resolve_exporter, resolve_io_config, @@ -57,17 +59,14 @@ class InspectError(Exception): """Base exception for inspect command.""" - class ModelNotFoundError(InspectError): """Model not found on HuggingFace Hub.""" - class NetworkError(InspectError): """Network error while fetching model config.""" - def inspect_model( model_id: str, 
include_hierarchy: bool = False, @@ -130,8 +129,8 @@ def inspect_model( loader_info.hf_model_class_source, ) - # Step 4: Resolve exporter configuration (pass HF config for tensor specs) - exporter_info = resolve_exporter(model_type, task, hf_config=hf_config) + # Step 4: Resolve exporter configuration (pass model_id for correct image sizes) + exporter_info = resolve_exporter(model_type, task, hf_config=hf_config, model_id=model_id) logger.debug( "Exporter: %s (source: %s)", exporter_info.onnx_config_class, @@ -151,9 +150,7 @@ def inspect_model( logger.debug("Hierarchy: %d HF modules", hierarchy_info.hf_module_count) # Step 6: Compile overall support status - overall_support, support_notes = compile_support_status( - loader_info, exporter_info, winml_info - ) + overall_support, support_notes = compile_support_status(loader_info, exporter_info, winml_info) logger.info("Overall support: %s", overall_support.value) # Step 7: Get full build config (for verbose output) @@ -164,15 +161,20 @@ def inspect_model( logger.debug("Cache: %d/%d stages cached", cache_info.total_cached, len(cache_info.stages)) # Step 9: Resolve processor classes - processor_info = resolve_processor(model_id) + processor_info = resolve_processor(model_id, model_type=model_type) logger.debug( "Processor: %s, Tokenizer: %s", processor_info.processor_class, processor_info.tokenizer_class, ) - # Step 10: Extract IO config from HF config - io_config_info = resolve_io_config(hf_config) + # Step 10: Extract IO config (dynamically discovers attrs from OnnxConfig) + io_config_info = resolve_io_config( + hf_config, + model_id=model_id, + model_type=model_type, + task=task, + ) logger.debug( "IO Config: max_pos=%s, vocab=%s, img_size=%s", io_config_info.max_position_embeddings, @@ -215,5 +217,7 @@ def inspect_model( "SupportLevel", "TensorInfo", "WinMLInfo", + "build_tensor_infos_from_io_specs", + "get_known_tasks", "inspect_model", ] diff --git a/src/winml/modelkit/inspect/formatter.py 
b/src/winml/modelkit/inspect/formatter.py index e5bf0e4d2..e5b4855eb 100644 --- a/src/winml/modelkit/inspect/formatter.py +++ b/src/winml/modelkit/inspect/formatter.py @@ -63,14 +63,21 @@ def _output_processor_table(console: Console, result: InspectResult) -> None: processor_table.add_column("Field", style="cyan") processor_table.add_column("Value") + def _src_tag(source: str | None) -> str: + return f" [dim](via {source})[/dim]" if source else "" + if processor.processor_class: - processor_table.add_row("Processor", processor.processor_class) + src = _src_tag(processor.processor_source) + processor_table.add_row("Processor", f"{processor.processor_class}{src}") if processor.tokenizer_class: - processor_table.add_row("Tokenizer", processor.tokenizer_class) + src = _src_tag(processor.tokenizer_source) + processor_table.add_row("Tokenizer", f"{processor.tokenizer_class}{src}") if processor.image_processor_class: - processor_table.add_row("Image Processor", processor.image_processor_class) + src = _src_tag(processor.image_processor_source) + processor_table.add_row("Image Processor", f"{processor.image_processor_class}{src}") if processor.feature_extractor_class: - processor_table.add_row("Feature Extractor", processor.feature_extractor_class) + src = _src_tag(processor.feature_extractor_source) + processor_table.add_row("Feature Extractor", f"{processor.feature_extractor_class}{src}") # Only show panel if we have at least one processor class if any( @@ -134,6 +141,17 @@ def _output_io_config_table(console: Console, result: InspectResult) -> None: if io_config.hidden_size is not None: io_table.add_row("Hidden Size", str(io_config.hidden_size)) has_content = True + if io_config.hidden_sizes is not None: + sizes_str = " → ".join(str(s) for s in io_config.hidden_sizes) + io_table.add_row("Hidden Sizes", sizes_str) + has_content = True + + # Extra attrs discovered dynamically from OnnxConfig + if io_config.extra: + for key, val in sorted(io_config.extra.items()): + label 
= key.replace("_", " ").title() + io_table.add_row(label, str(val)) + has_content = True # Only show panel if we have content if has_content: @@ -244,7 +262,10 @@ def output_table(console: Console, result: InspectResult, verbose: bool = False) else: shape_str = "-" dtype_str = tensor.dtype or "-" - exporter_table.add_row(f" {tensor.name}", f"{dtype_str} {shape_str}") + extra = "" + if tensor.value_range is not None: + extra = f" [dim]range {tensor.value_range}[/dim]" + exporter_table.add_row(f" {tensor.name}", f"{dtype_str} {shape_str}{extra}") # Output tensors if result.exporter.output_tensors: @@ -395,6 +416,7 @@ def output_json(result: InspectResult, verbose: bool = False) -> str: "shape": list(t.shape) if t.shape else None, "shape_desc": t.shape_desc, "dynamic_axes": t.dynamic_axes, + "value_range": list(t.value_range) if t.value_range else None, } for t in result.exporter.input_tensors ], @@ -447,6 +469,10 @@ def output_json(result: InspectResult, verbose: bool = False) -> str: "tokenizer_class": result.processor.tokenizer_class, "image_processor_class": result.processor.image_processor_class, "feature_extractor_class": result.processor.feature_extractor_class, + "processor_source": result.processor.processor_source, + "tokenizer_source": result.processor.tokenizer_source, + "image_processor_source": result.processor.image_processor_source, + "feature_extractor_source": result.processor.feature_extractor_source, } else: data["processor"] = None @@ -466,6 +492,8 @@ def output_json(result: InspectResult, verbose: bool = False) -> str: "num_channels": io_config.num_channels, "sampling_rate": io_config.sampling_rate, "hidden_size": io_config.hidden_size, + "hidden_sizes": io_config.hidden_sizes, + "extra": io_config.extra, } else: data["io_config"] = None diff --git a/src/winml/modelkit/inspect/resolver.py b/src/winml/modelkit/inspect/resolver.py index 4002aab1f..27550ffe6 100644 --- a/src/winml/modelkit/inspect/resolver.py +++ 
b/src/winml/modelkit/inspect/resolver.py @@ -56,7 +56,7 @@ } -def _get_known_tasks() -> set[str]: +def get_known_tasks() -> set[str]: """Collect all known task strings from internal mappings and TasksManager. Returns: @@ -94,12 +94,10 @@ def validate_task(task: str) -> None: Raises: ValueError: If the task is not recognized. """ - known = _get_known_tasks() + known = get_known_tasks() if task not in known: sorted_tasks = sorted(known) - raise ValueError( - f"Unknown task '{task}'. Known tasks: {', '.join(sorted_tasks)}" - ) + raise ValueError(f"Unknown task '{task}'. Known tasks: {', '.join(sorted_tasks)}") def detect_task(config: PretrainedConfig) -> tuple[str, str]: @@ -176,18 +174,51 @@ def resolve_loader(model_type: str, task: str) -> LoaderInfo: ) -def _extract_tensor_specs_from_onnx_config( - onnx_config_cls, - hf_config: PretrainedConfig, +def _shape_to_desc(shape: tuple | list | None, dynamic_axes: dict[int, str]) -> str: + """Convert tensor shape to human-readable string with dynamic markers. + + Dynamic axes are shown as the concrete value from dummy inputs, + distinguishable from static dims by context (batch → "B"). + For non-batch dynamic dims (sequence, height, width), shows the + concrete value since that's what the model actually uses for export. + + Fixes D-3 from #247: uses axis names directly, no hardcoded abbreviations. 
+ """ + if shape is None: + parts = [] + for _idx, axis_name in sorted(dynamic_axes.items()): + if axis_name.lower() in ("batch", "batch_size"): + parts.append("B") + else: + parts.append(axis_name) + return f"[{', '.join(parts)}]" if parts else "[]" + + parts = [] + for i, dim in enumerate(shape): + if i in dynamic_axes: + axis_name = dynamic_axes[i] + if axis_name.lower() in ("batch", "batch_size"): + parts.append("B") + else: + # Show concrete value — this is the export shape from + # preprocessor_config or shape_config, not a placeholder + parts.append(str(dim)) + else: + parts.append(str(dim)) + return f"[{', '.join(parts)}]" + + +def build_tensor_infos_from_io_specs( + io_specs: dict, ) -> tuple[list[TensorInfo], list[TensorInfo]]: - """Extract tensor specifications from an ONNX config class. + """Convert resolve_io_specs() output to TensorInfo lists. - Uses the ONNX config's generate_dummy_inputs() to get actual tensor shapes, - and the inputs/outputs properties for dynamic axes information. + Single conversion point from config's I/O spec format to inspect's + TensorInfo dataclass. Eliminates the duplicated extraction logic + that previously lived in _extract_tensor_specs_from_onnx_config. 
Args: - onnx_config_cls: ONNX config constructor (may be functools.partial) - hf_config: HuggingFace PretrainedConfig for shape bounds + io_specs: Dict returned by export/io.py resolve_io_specs() Returns: Tuple of (input_tensors, output_tensors) @@ -195,88 +226,44 @@ def _extract_tensor_specs_from_onnx_config( input_tensors: list[TensorInfo] = [] output_tensors: list[TensorInfo] = [] - try: - # Instantiate ONNX config with HF config - onnx_config = onnx_config_cls(hf_config) - - # Generate dummy inputs to get actual shapes - dummy_inputs: dict = {} - try: - dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") - except Exception as e: - logger.debug("Failed to generate dummy inputs: %s", e) - - # Helper to convert shape to description with dynamic axis markers - def shape_to_desc( - shape: tuple | list | None, dynamic_axes: dict[int, str] - ) -> str: - """Convert tensor shape to human-readable string with dynamic markers.""" - if shape is None: - # Fallback: use dynamic axes only - parts = [] - for _idx, axis_name in sorted(dynamic_axes.items()): - if "batch" in axis_name.lower(): - parts.append("B") - else: - parts.append(axis_name) - return f"[{', '.join(parts)}]" if parts else "[]" - - parts = [] - for i, dim in enumerate(shape): - if i in dynamic_axes: - axis_name = dynamic_axes[i].lower() - if "batch" in axis_name: - parts.append("B") - elif "sequence" in axis_name: - parts.append("S") - elif "height" in axis_name or "width" in axis_name: - parts.append(str(dim)) # Use actual size - else: - parts.append(str(dim)) - else: - parts.append(str(dim)) - return f"[{', '.join(parts)}]" - - # Standard input dtypes based on tensor name patterns - def infer_dtype(name: str) -> str: - name_lower = name.lower() - if "ids" in name_lower or "label" in name_lower: - return "int64" - if "mask" in name_lower and "pixel" not in name_lower: - return "int64" - return "float32" - - # Process inputs - use actual shapes from dummy inputs - if hasattr(onnx_config, 
"inputs"): - for name, axes in onnx_config.inputs.items(): - shape = None - if name in dummy_inputs: - shape = tuple(dummy_inputs[name].shape) - shape_desc = shape_to_desc(shape, axes) - dtype = infer_dtype(name) - input_tensors.append( - TensorInfo( - name=name, - dtype=dtype, - shape_desc=shape_desc, - dynamic_axes=dict(axes), - ) - ) - - # Process outputs - we don't have actual shapes, use dynamic axes - if hasattr(onnx_config, "outputs"): - for name, axes in onnx_config.outputs.items(): - shape_desc = shape_to_desc(None, axes) - output_tensors.append( - TensorInfo( - name=name, - shape_desc=shape_desc, - dynamic_axes=dict(axes), - ) - ) + input_names = io_specs.get("input_names", []) + input_shapes = io_specs.get("input_shapes", []) + input_dtypes = io_specs.get("input_dtypes", []) + inputs_axes = io_specs.get("inputs", {}) + value_ranges = io_specs.get("value_ranges", {}) + + for i, name in enumerate(input_names): + shape = input_shapes[i] if i < len(input_shapes) else None + dtype = input_dtypes[i] if i < len(input_dtypes) else None + axes = inputs_axes.get(name, {}) + vr = value_ranges.get(name) + + shape_desc = _shape_to_desc(shape, axes) if shape else None + + input_tensors.append( + TensorInfo( + name=name, + dtype=dtype, + shape=shape, + shape_desc=shape_desc, + dynamic_axes=dict(axes) if axes else None, + value_range=vr, + ) + ) - except Exception as e: - logger.debug("Failed to extract tensor specs from ONNX config: %s", e) + output_names = io_specs.get("output_names", []) + outputs_axes = io_specs.get("outputs", {}) + + for name in output_names: + axes = outputs_axes.get(name, {}) + shape_desc = _shape_to_desc(None, axes) if axes else None + output_tensors.append( + TensorInfo( + name=name, + shape_desc=shape_desc, + dynamic_axes=dict(axes) if axes else None, + ) + ) return input_tensors, output_tensors @@ -285,15 +272,22 @@ def resolve_exporter( model_type: str, task: str, hf_config: PretrainedConfig | None = None, + *, + model_id: str | None = None, 
) -> ExporterInfo: """Resolve exporter configuration for a model. - Uses MODEL_BUILD_CONFIGS registry from models/__init__.py. + Uses MODEL_BUILD_CONFIGS registry, then falls back to + export/io.py resolve_io_specs() for I/O extraction. This ensures + inspect and config share the same battle-tested I/O extraction path, + including correct image sizes from preprocessor_config.json. Args: model_type: HuggingFace model type (e.g., "clip") task: Canonical task name hf_config: Optional HuggingFace config for extracting tensor shapes + model_id: Optional HuggingFace model ID for preprocessor_config.json + (needed for correct image sizes on models like ResNet) Returns: ExporterInfo with ONNX config, tensors, and support level @@ -321,8 +315,7 @@ def resolve_exporter( output_tensors: list[TensorInfo] = [] if export_config.output_tensors: output_tensors.extend( - TensorInfo(name=spec.name or "unknown") - for spec in export_config.output_tensors + TensorInfo(name=spec.name or "unknown") for spec in export_config.output_tensors ) return ExporterInfo( @@ -357,14 +350,23 @@ def resolve_exporter( else: config_name = onnx_config_cls.__name__ - # Extract tensor specs from ONNX config if HF config is available + # Extract tensor specs via resolve_io_specs (shared with config command) input_tensors: list[TensorInfo] = [] output_tensors: list[TensorInfo] = [] if hf_config is not None: - input_tensors, output_tensors = _extract_tensor_specs_from_onnx_config( - onnx_config_cls, hf_config - ) + try: + from ..export.io import resolve_io_specs + + io_specs = resolve_io_specs( + model_type=model_type, + task=task, + hf_config=hf_config, + model_id=model_id, + ) + input_tensors, output_tensors = build_tensor_infos_from_io_specs(io_specs) + except Exception as e: + logger.debug("resolve_io_specs failed for %s/%s: %s", model_type, task, e) return ExporterInfo( onnx_config_class=config_name, @@ -535,9 +537,7 @@ def resolve_cache(model_id: str) -> CacheInfo: filename = ms.get("filename") 
artifact = model_dir / filename if filename else None size_bytes = ( - artifact.stat().st_size - if artifact and artifact.exists() - else 0 + artifact.stat().st_size if artifact and artifact.exists() else 0 ) stage_info = CacheStageInfo( stage=stage, @@ -575,7 +575,7 @@ def resolve_cache(model_id: str) -> CacheInfo: stem = f.stem last_sep = stem.rfind("_") if last_sep > 0: - stage_name = stem[last_sep + 1:] + stage_name = stem[last_sep + 1 :] cached_files[stage_name] = f for stage in pipeline_stages: @@ -606,65 +606,200 @@ def resolve_cache(model_id: str) -> CacheInfo: ) -def resolve_io_config(config: PretrainedConfig) -> IOConfigInfo: +def _find_nested_configs(config: PretrainedConfig) -> list: + """Discover all nested PretrainedConfig objects dynamically. + + Walks config attributes to find nested configs without hardcoding + names like "text_config", "vision_config", etc. Fixes D-2 and D-5 + from #247. + + Args: + config: HuggingFace PretrainedConfig object + + Returns: + List of nested PretrainedConfig instances + """ + from transformers import PretrainedConfig + + nested = [] + for attr_name in vars(config): + if attr_name.startswith("_"): + continue + try: + val = getattr(config, attr_name) + if isinstance(val, PretrainedConfig): + nested.append(val) + except Exception: + continue + return nested + + +def _discover_io_attrs_from_onnx_config( + model_type: str, + task: str, + hf_config: PretrainedConfig, +) -> set[str]: + """Discover IO-relevant config attributes from OnnxConfig. + + Instead of hardcoding which config attributes to show, we read the + uppercase class attrs on NormalizedConfig subclasses. These define + the canonical attribute mapping for each model type, e.g.: + + NormalizedTextConfig.VOCAB_SIZE = "vocab_size" + NormalizedVisionConfig.IMAGE_SIZE = "image_size" + + We also scan DUMMY_INPUT_GENERATOR_CLASSES for additional attrs + referenced via normalized_config.xxx in generator __init__ code. 
+ + Returns: + Set of config attribute names relevant to I/O for this model. + """ + import inspect + import re + + attrs: set[str] = set() + try: + from ..export.io import _get_onnx_config + + onnx_config = _get_onnx_config(model_type, task, hf_config) + + # Primary: enumerate uppercase attrs on NormalizedConfig class. + # These ARE the canonical IO attribute mapping (e.g., VOCAB_SIZE="vocab_size"). + nc = getattr(onnx_config, "_normalized_config", None) + if nc is not None: + for attr_name in dir(type(nc)): + if attr_name.isupper() and not attr_name.startswith("_"): + # The value is the actual config attr name (e.g., "vocab_size") + val = getattr(type(nc), attr_name) + if isinstance(val, str): + # Handle dotted paths like "text_config.hidden_size" + leaf = val.split(".")[-1] + # Skip structural pointers (nested config references) + if not leaf.endswith("_config"): + attrs.add(leaf) + + # Secondary: scan generator __init__ for additional normalized_config refs + for gen_cls in getattr(onnx_config, "DUMMY_INPUT_GENERATOR_CLASSES", []): + try: + src = inspect.getsource(gen_cls.__init__) + except (TypeError, OSError): + continue + refs = re.findall(r"normalized_config\.(\w+)", src) + attrs.update(r for r in refs if r != "has_attribute") + except Exception as e: + logger.debug("Failed to discover IO attrs from OnnxConfig: %s", e) + + return attrs + + +def resolve_io_config( + config: PretrainedConfig, + *, + model_id: str | None = None, + model_type: str | None = None, + task: str | None = None, +) -> IOConfigInfo: """Extract IO configuration from HuggingFace config. - Extracts IO-related configuration values from a PretrainedConfig object. - For multimodal models (like CLIP), also checks nested configs (text_config, - vision_config) to gather all relevant settings. + Dynamically discovers which config attributes matter for I/O by + inspecting OnnxConfig's NormalizedConfig and input generators. 
+ Falls back to a universal set of well-known attrs if OnnxConfig + lookup fails. No hardcoded model-specific attribute names. Args: config: HuggingFace PretrainedConfig object + model_id: Optional HF model ID for preprocessor_config.json fallback + model_type: HF model type for OnnxConfig lookup + task: Task name for OnnxConfig lookup Returns: IOConfigInfo with extracted configuration values """ - # Helper to get attribute from config or nested configs + # Dynamically discover nested configs (fixes D-2: no hardcoded names) + nested_configs = _find_nested_configs(config) + def get_config_attr( attr_name: str, - nested_configs: list[str] | None = None, - ) -> int | tuple[int, int] | None: - """Get attribute from main config or nested configs. - - Args: - attr_name: Attribute name to look for - nested_configs: List of nested config names to check (e.g., ["text_config"]) - - Returns: - Attribute value or None if not found - """ - # First check the main config + ) -> int | tuple[int, int] | list | None: + """Get attribute from main config or any nested config.""" value = getattr(config, attr_name, None) if value is not None: return value - # Check nested configs if provided - if nested_configs: - for nested_name in nested_configs: - nested_config = getattr(config, nested_name, None) - if nested_config is not None: - value = getattr(nested_config, attr_name, None) - if value is not None: - return value + for nested in nested_configs: + value = getattr(nested, attr_name, None) + if value is not None: + return value return None - # Text-related attributes - check main and text_config - max_position_embeddings = get_config_attr( - "max_position_embeddings", ["text_config"] - ) - vocab_size = get_config_attr("vocab_size", ["text_config"]) - - # Vision-related attributes - check main and vision_config - image_size = get_config_attr("image_size", ["vision_config"]) - patch_size = get_config_attr("patch_size", ["vision_config"]) - num_channels = get_config_attr("num_channels", 
["vision_config"]) + # Step 1: Discover which attrs the OnnxConfig actually uses + io_attrs: set[str] = set() + if model_type and task: + io_attrs = _discover_io_attrs_from_onnx_config( + model_type, + task, + config, + ) - # Audio-related attributes - check main and audio_config - sampling_rate = get_config_attr("sampling_rate", ["audio_config"]) + # Step 2: Always include universal well-known IO attrs that Optimum's + # NormalizedConfig classes reference. These are framework conventions, + # not model-specific — they appear in NormalizedTextConfig, + # NormalizedVisionConfig, NormalizedSeq2SeqConfig, etc. + universal_io_attrs = { + "max_position_embeddings", + "vocab_size", + "image_size", + "patch_size", + "num_channels", + "input_size", + "sampling_rate", + "hidden_size", + "hidden_sizes", + } + io_attrs.update(universal_io_attrs) + + # Step 3: Look up each discovered attr + max_position_embeddings = get_config_attr("max_position_embeddings") + vocab_size = get_config_attr("vocab_size") + image_size = get_config_attr("image_size") + patch_size = get_config_attr("patch_size") + num_channels = get_config_attr("num_channels") + sampling_rate = get_config_attr("sampling_rate") + hidden_size = get_config_attr("hidden_size") + hidden_sizes = get_config_attr("hidden_sizes") + + # Step 4: Collect any extra attrs discovered from OnnxConfig + # that aren't in our dataclass fields + known_fields = { + "max_position_embeddings", + "vocab_size", + "image_size", + "patch_size", + "num_channels", + "sampling_rate", + "hidden_size", + "hidden_sizes", + } + extra: dict[str, int | str | list | None] = {} + for attr in io_attrs - known_fields: + val = get_config_attr(attr) + if val is not None: + extra[attr] = val + + # Step 5: Fallback — read image_size from preprocessor_config.json + # for models like ResNet where HF config lacks image_size + if image_size is None and model_id is not None: + try: + from ..export.io import _populate_image_size_from_preprocessor - # General 
attributes - check main config only - hidden_size = get_config_attr("hidden_size", ["text_config", "vision_config"]) + shape_kwargs: dict = {} + _populate_image_size_from_preprocessor(model_id, shape_kwargs) + if "height" in shape_kwargs: + h, w = shape_kwargs["height"], shape_kwargs["width"] + image_size = h if h == w else (h, w) + except Exception as e: + logger.debug("Failed to get image_size from preprocessor: %s", e) return IOConfigInfo( max_position_embeddings=max_position_embeddings, @@ -674,22 +809,27 @@ def get_config_attr( num_channels=num_channels, sampling_rate=sampling_rate, hidden_size=hidden_size, + hidden_sizes=hidden_sizes, + extra=extra if extra else None, ) -def resolve_processor(model_id: str) -> ProcessorInfo: +def resolve_processor( + model_id: str, + model_type: str | None = None, +) -> ProcessorInfo: """Resolve data processing classes for a HuggingFace model. Detects the processor/tokenizer/image_processor/feature_extractor classes associated with a model. Uses a multi-strategy approach: - 1. First tries to fetch config files from HuggingFace Hub without downloading - the full model (fast, no dependencies) - 2. Uses Auto classes to fill in any missing information that wasn't found - in the config files + 0. Check HF's IMAGE_PROCESSOR_MAPPING_NAMES for model_type-specific mapping + 1. Fetch config files from HuggingFace Hub (fast, no model download) + 2. 
Use Auto classes to fill in any remaining gaps Args: model_id: HuggingFace model identifier (e.g., "openai/clip-vit-base-patch32") + model_type: HuggingFace model type (e.g., "resnet") for registry lookup Returns: ProcessorInfo with detected class names for each processor type @@ -698,13 +838,47 @@ def resolve_processor(model_id: str) -> ProcessorInfo: tokenizer_class: str | None = None image_processor_class: str | None = None feature_extractor_class: str | None = None + # Source tracking + processor_source: str | None = None + tokenizer_source: str | None = None + image_processor_source: str | None = None + feature_extractor_source: str | None = None + + # Strategy 0: Check HF registry for the canonical image processor class + # for this model_type. This is authoritative — HF maps model types to + # their processor classes (e.g., resnet → ConvNextImageProcessor). + if model_type is not None: + try: + from transformers.models.auto.image_processing_auto import ( + IMAGE_PROCESSOR_MAPPING_NAMES, + ) + + mapping = IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type) + if mapping: + # mapping is (SlowProcessor, FastProcessor) or a string + image_processor_class = mapping[0] if isinstance(mapping, tuple) else mapping + image_processor_source = "hf_registry" + except Exception as e: + logger.debug("Registry lookup failed for %s: %s", model_type, e) # Strategy 1: Try to get class names from config files via HuggingFace Hub API # This is fast and doesn't require downloading/instantiating processors + # NOTE: These JSON keys (processor_class, image_processor_type, etc.) are + # standard HuggingFace config conventions, not model-specific hardcoding. 
try: - processor_class, tokenizer_class, image_processor_class, feature_extractor_class = ( - _resolve_processor_from_hub_configs(model_id) - ) + hub_proc, hub_tok, hub_img, hub_fe = _resolve_processor_from_hub_configs(model_id) + if hub_proc and processor_class is None: + processor_class = hub_proc + processor_source = "hub_config" + if hub_tok and tokenizer_class is None: + tokenizer_class = hub_tok + tokenizer_source = "hub_config" + if hub_img and image_processor_class is None: + image_processor_class = hub_img + image_processor_source = "hub_config" + if hub_fe and feature_extractor_class is None: + feature_extractor_class = hub_fe + feature_extractor_source = "hub_config" except Exception as e: logger.debug("Failed to resolve processors from hub configs: %s", e) @@ -719,14 +893,18 @@ def resolve_processor(model_id: str) -> ProcessorInfo: ) = _resolve_processor_from_auto_classes(model_id) # Fill in missing values from auto classes - if processor_class is None: + if processor_class is None and auto_processor: processor_class = auto_processor - if tokenizer_class is None: + processor_source = "auto_class" + if tokenizer_class is None and auto_tokenizer: tokenizer_class = auto_tokenizer - if image_processor_class is None: + tokenizer_source = "auto_class" + if image_processor_class is None and auto_image_processor: image_processor_class = auto_image_processor - if feature_extractor_class is None: + image_processor_source = "auto_class" + if feature_extractor_class is None and auto_feature_extractor: feature_extractor_class = auto_feature_extractor + feature_extractor_source = "auto_class" except Exception as e: logger.debug("Failed to resolve processors from auto classes: %s", e) @@ -735,6 +913,10 @@ def resolve_processor(model_id: str) -> ProcessorInfo: tokenizer_class=tokenizer_class, image_processor_class=image_processor_class, feature_extractor_class=feature_extractor_class, + processor_source=processor_source, + tokenizer_source=tokenizer_source, + 
image_processor_source=image_processor_source, + feature_extractor_source=feature_extractor_source, ) diff --git a/src/winml/modelkit/inspect/types.py b/src/winml/modelkit/inspect/types.py index 58b092fa1..60d1c156e 100644 --- a/src/winml/modelkit/inspect/types.py +++ b/src/winml/modelkit/inspect/types.py @@ -27,6 +27,7 @@ class TensorInfo: shape: tuple[int, ...] | None = None shape_desc: str | None = None # Human-readable shape like "[B, 3, 224, 224]" dynamic_axes: dict[int, str] | None = None # {0: "batch", 1: "sequence"} + value_range: tuple[float, float] | None = None # e.g., (0.0, 1.0) for pixel values @dataclass @@ -90,6 +91,11 @@ class ProcessorInfo: tokenizer_class: str | None = None # e.g., "CLIPTokenizerFast" image_processor_class: str | None = None # e.g., "CLIPImageProcessor" feature_extractor_class: str | None = None # e.g., "Wav2Vec2FeatureExtractor" + # Source tracking for transparency (e.g., ResNet -> ConvNextImageProcessorFast) + processor_source: str | None = None # "hf_registry" | "hub_config" | "auto_class" + image_processor_source: str | None = None + feature_extractor_source: str | None = None + tokenizer_source: str | None = None @dataclass @@ -110,6 +116,10 @@ class IOConfigInfo: # General hidden_size: int | None = None + hidden_sizes: list[int] | None = None # Per-stage hidden dims (e.g., ResNet) + + # Extra attrs discovered dynamically from OnnxConfig + extra: dict[str, Any] | None = None @dataclass diff --git a/src/winml/modelkit/loader/__init__.py b/src/winml/modelkit/loader/__init__.py index c65d5016b..efd6bc8c9 100644 --- a/src/winml/modelkit/loader/__init__.py +++ b/src/winml/modelkit/loader/__init__.py @@ -26,10 +26,6 @@ """ from .config import WinMLLoaderConfig, resolve_loader_config -from .hf import ( - load_hf_model, - resolve_hf_model_class, -) from .task import ( HF_TASK_DEFAULTS, get_supported_tasks, @@ -50,3 +46,26 @@ "resolve_loader_config", "resolve_task_and_model_class", ] + + +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { +
"load_hf_model": (".hf", "load_hf_model"), + "resolve_hf_model_class": (".hf", "resolve_hf_model_class"), +} + + +def __getattr__(name: str): + """Lazy-load heavy exports (hf.py imports transformers).""" + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, attr_name) + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return list(set(list(globals()) + __all__)) diff --git a/src/winml/modelkit/loader/task.py b/src/winml/modelkit/loader/task.py index b876c060f..1e6a2c4af 100644 --- a/src/winml/modelkit/loader/task.py +++ b/src/winml/modelkit/loader/task.py @@ -267,9 +267,10 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t try: model_class = TasksManager.get_model_class_for_task(task) - # Warn if TasksManager returns different class than architecture + # Informational: TasksManager may return a generic AutoModel* class + # that differs from config.architectures — this is expected behavior. if model_class.__name__ != arch_name: - logger.warning( + logger.info( "TasksManager returned %s, but config.architectures specifies %s. 
" "Honoring TasksManager's choice.", model_class.__name__, diff --git a/src/winml/modelkit/models/__init__.py b/src/winml/modelkit/models/__init__.py index 246207f01..2fa39bda4 100644 --- a/src/winml/modelkit/models/__init__.py +++ b/src/winml/modelkit/models/__init__.py @@ -55,15 +55,28 @@ # Lazy loading for modules that cause circular imports # WinMLAutoModel imports from loader/, which imports from models/ +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "WinMLAutoModel": (".auto", "WinMLAutoModel"), +} + + def __getattr__(name: str): """Lazy load modules that would cause circular imports.""" - if name == "WinMLAutoModel": - from .auto import WinMLAutoModel - - return WinMLAutoModel + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, attr_name) + globals()[name] = val + return val raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +def __dir__() -> list[str]: + return list(set(list(globals()) + __all__)) + + __all__ = [ "HF_MODEL_CLASS_MAPPING", "MODEL_BUILD_CONFIGS", diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 3195d0fee..ae21e1881 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -334,6 +334,8 @@ def from_pretrained( from ..build import build_hf_model + # Pass resolved EP so the static analyzer targets only this EP + resolved_ep = config.compile.ep_config.provider if config.compile is not None else None result = build_hf_model( config=config, output_dir=output_dir, @@ -342,7 +344,7 @@ def from_pretrained( rebuild=force_rebuild, trust_remote_code=trust_remote_code, cache_key=cache_key, - ep=kwargs.get("ep"), + ep=resolved_ep, device=device, ) onnx_path = result.final_onnx_path @@ -353,11 +355,13 @@ def from_pretrained( winml_class = get_winml_class(model_type, task) logger.info("Creating inference wrapper: %s", winml_class.__name__) - return 
winml_class( + model = winml_class( onnx_path=onnx_path, config=hf_config, # HF PretrainedConfig for pipeline compatibility device=device, # pass user's original device string; WinMLSession handles "auto" ) + model._build_config = config # resolved build config (task, quant, compile) + return model @classmethod def supported_tasks(cls) -> list[str]: diff --git a/src/winml/modelkit/models/hf/bert.py b/src/winml/modelkit/models/hf/bert.py index 4838d3e37..d537c8010 100644 --- a/src/winml/modelkit/models/hf/bert.py +++ b/src/winml/modelkit/models/hf/bert.py @@ -33,9 +33,6 @@ BERT_CONFIG = WinMLBuildConfig( optim=WinMLOptimizationConfig( - gelu_fusion=True, - layer_norm_fusion=True, - matmul_add_fusion=True, clamp_constant_values=True, ), ) diff --git a/src/winml/modelkit/models/winml/base.py b/src/winml/modelkit/models/winml/base.py index 4d4e892bd..7f5e55c70 100644 --- a/src/winml/modelkit/models/winml/base.py +++ b/src/winml/modelkit/models/winml/base.py @@ -76,6 +76,9 @@ def __init__( self.config = config self._device = device + # Set by WinMLAutoModel.from_pretrained() after construction + self._build_config: Any = None + # Create WinMLSession (delegates ORT operations) self._session = WinMLSession( onnx_path=self._onnx_path, @@ -200,8 +203,26 @@ def perf(self, warmup: int = 0) -> contextlib.AbstractContextManager: @property def device(self) -> str: - """Current device.""" - return self._device + """Current device (delegates to session, resolved after compile).""" + return self._session.device + + @property + def task(self) -> str | None: + """Resolved task from build config, or None if unavailable.""" + build_config = getattr(self, "_build_config", None) + if build_config is not None: + loader = getattr(build_config, "loader", None) + if loader: + return loader.task + return None + + @property + def precision(self) -> str | None: + """Resolved precision from build config, or None if unavailable. + + TODO: derive from _build_config.quant.weight_type when ready. 
+ """ + return None @property def dtype(self) -> torch.dtype: diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index 120d7af3e..521f9c0f8 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -13,7 +13,6 @@ from __future__ import annotations -from .detection import is_compiled_onnx, is_quantized_onnx from .domains import ONNXDomain from .dtypes import SupportedONNXType, remove_optional_from_type_annotation from .external_data import copy_onnx_model @@ -46,3 +45,18 @@ "restore_metadata", "save_onnx", ] + + +def __getattr__(name: str): + """Lazy-load detection module to avoid circular import with compiler.""" + if name in ("is_compiled_onnx", "is_quantized_onnx"): + from .detection import is_compiled_onnx, is_quantized_onnx + + globals()["is_compiled_onnx"] = is_compiled_onnx + globals()["is_quantized_onnx"] = is_quantized_onnx + return globals()[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return __all__ diff --git a/src/winml/modelkit/onnx/detection.py b/src/winml/modelkit/onnx/detection.py index 007ef9267..c82c2ce34 100644 --- a/src/winml/modelkit/onnx/detection.py +++ b/src/winml/modelkit/onnx/detection.py @@ -14,7 +14,6 @@ import logging from typing import TYPE_CHECKING -from ..compiler.utils import QDQ_OP_TYPES from .persistence import load_onnx @@ -43,6 +42,8 @@ def _load_model_lightweight(model_path: Path, operation: str) -> onnx.ModelProto def is_quantized_onnx(model_path: Path) -> bool: """Check if ONNX model is quantized (contains QuantizeLinear/DequantizeLinear nodes).""" model = _load_model_lightweight(model_path, "quantization check") + from ..compiler import QDQ_OP_TYPES + return any(n.op_type in QDQ_OP_TYPES for n in model.graph.node) diff --git a/src/winml/modelkit/optim/api.py b/src/winml/modelkit/optim/api.py index 8993d76f7..f403d8c4c 100644 --- a/src/winml/modelkit/optim/api.py +++ 
b/src/winml/modelkit/optim/api.py @@ -160,6 +160,19 @@ def _convert_to_kwargs(config: dict[str, Any], all_caps: dict[str, Any]) -> dict return result +def _hack_inject_quant_preprocess_metadata(model: onnx.ModelProto) -> None: + """Inject metadata that signals pre-processing was done. + + Suppresses the ORT quantization warning: + 'Please consider to run pre-processing before quantization.' + """ + metadata = {"onnx.quant.pre_process": "onnxruntime.quant"} + if model.metadata_props: + for prop in model.metadata_props: + metadata[prop.key] = prop.value + onnx.helper.set_model_props(model, metadata) + + def optimize_onnx( model: str | Path | onnx.ModelProto, output: str | Path | None = None, @@ -259,6 +272,9 @@ def optimize_onnx( optimized_model = optimizer.optimize(loaded_model, **optimizer_kwargs) optimized_model = optimizer.optimize(optimized_model, **optimizer_kwargs) + # Step 9.5: Inject quant pre-processing metadata to suppress ORT warning + _hack_inject_quant_preprocess_metadata(optimized_model) + # Step 10: Save if output path provided if output is not None: output_path = Path(output) diff --git a/src/winml/modelkit/optim/registry.py b/src/winml/modelkit/optim/registry.py index a3fc3591a..e5117a528 100644 --- a/src/winml/modelkit/optim/registry.py +++ b/src/winml/modelkit/optim/registry.py @@ -204,7 +204,9 @@ def validate(config: dict[str, Any], capabilities: dict[str, CapabilityDef]) -> errors = [] for key, value in config.items(): - cap = capabilities.get(key) + # Accept both snake_case and kebab-case (normalize to kebab-case) + normalized_key = key.replace("_", "-") + cap = capabilities.get(normalized_key) or capabilities.get(key) if cap is None: errors.append(f"Unknown capability '{key}'") continue diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index fe3770eae..2e6c2c279 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -17,7 +17,6 @@ """ from .config import 
QuantizeResult, WinMLQuantizationConfig -from .quantizer import quantize_onnx __all__ = [ @@ -25,3 +24,25 @@ "WinMLQuantizationConfig", "quantize_onnx", ] + + +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "quantize_onnx": (".quantizer", "quantize_onnx"), +} + + +def __getattr__(name: str): + """Lazy-load quantizer (imports onnxruntime.quantization).""" + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, attr_name) + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return list(set(list(globals()) + __all__)) diff --git a/src/winml/modelkit/session/session.py b/src/winml/modelkit/session/session.py index a8b7da920..c5d2b06f1 100644 --- a/src/winml/modelkit/session/session.py +++ b/src/winml/modelkit/session/session.py @@ -33,6 +33,28 @@ logger = logging.getLogger(__name__) +@contextmanager +def _suppress_native_output(log_path: str | Path | None = None): + """Redirect native stdout to a log file (or devnull). + + QNN SDK compiler writes progress to stdout via native C++ code that + Python logging/warnings cannot intercept. Only redirects stdout — + stderr is left untouched so Rich displays and Python logging work. 
+ """ + if log_path is not None: + fd = os.open(str(log_path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC) + else: + fd = os.open(os.devnull, os.O_WRONLY) + old_stdout = os.dup(1) + os.dup2(fd, 1) + os.close(fd) + try: + yield + finally: + os.dup2(old_stdout, 1) + os.close(old_stdout) + + class SessionState(Enum): """WinMLSession states.""" @@ -172,7 +194,8 @@ def __init__( if not self._onnx_path.exists(): raise FileNotFoundError(f"ONNX model not found: {onnx_path}") - self._device = device + # HF Pipeline may pass torch.device; coerce to string for downstream .lower() calls + self._device = str(device) if not isinstance(device, str) else device self._ep = ep.lower() if ep else None self._persist_jit = ep_config.enable_ep_context if ep_config else False self._embed_context = ep_config.embed_context if ep_config else False @@ -234,6 +257,10 @@ def compile(self) -> None: logger.info("Using cached EPContext: %s", ctx_path) # Compile if needed (persist_jit=True and no cache) + # Native QNN SDK compiler writes progress to stdout;
+ compile_log = self._onnx_path.parent / "compile.log" + if self._persist_jit and model_path == self._onnx_path: # Skip ModelCompiler if input model is already compiled (EPContext) if is_compiled_onnx(self._onnx_path): @@ -247,7 +274,8 @@ def compile(self) -> None: str(self._onnx_path), embed_compiled_data_into_model=self._embed_context, ) - model_compiler.compile_to_file(str(ctx_path)) + with _suppress_native_output(compile_log): + model_compiler.compile_to_file(str(ctx_path)) # Use compiled model if it was created if ctx_path.exists(): @@ -261,7 +289,8 @@ def compile(self) -> None: try: # Create InferenceSession sess_options = self._build_session_options(target_device) - session = ort.InferenceSession(str(model_path), sess_options=sess_options) + with _suppress_native_output(compile_log): + session = ort.InferenceSession(str(model_path), sess_options=sess_options) # Log which providers were selected by ORT (based on policy) actual_providers = session.get_providers() @@ -288,6 +317,15 @@ def compile(self) -> None: self._session = session self._state = SessionState.COMPILED + # Resolve device label from the primary provider ORT actually selected + if self._device == "auto" and actual_providers: + from ..sysinfo.device import get_ep_device_map + + ep_map = get_ep_device_map() + resolved = ep_map.get(actual_providers[0]) + if resolved and "/" not in resolved: + self._device = resolved + def run( self, inputs: dict[str, Any], diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index 07c1b7eb6..d06459c6f 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -22,6 +22,7 @@ def model_option(required=True): """ return click.option( "--model", + "-m", required=required, type=click.Path(exists=True, path_type=Path), help="Path to ONNX model file to analyze", @@ -78,7 +79,7 @@ def device_option(required=True, optional_message=None, default="NPU"): "--device", required=required, default=default if not required else 
None, - type=click.Choice(SUPPORTED_DEVICES, case_sensitive=False), + type=click.Choice(SUPPORTED_DEVICES, case_sensitive=True), help=help_text, ) @@ -86,8 +87,11 @@ def device_option(required=True, optional_message=None, default="NPU"): def verbosity_options(f): """Add verbose and quiet logging options to a Click command. - Adds --verbose/-v and --quiet/-q flags that control logging verbosity. - These options are automatically passed to the decorated function. + Adds --verbose/-v (stackable: -v, -vv, -vvv) and --quiet/-q flags. + The decorated function receives ``verbose`` (int, count of -v flags) + and ``quiet`` (bool). + + See :mod:`winml.modelkit.utils.logging` for the verbosity convention. Args: f: Click command function to decorate @@ -105,8 +109,7 @@ def verbosity_options(f): f = click.option( "--verbose", "-v", - is_flag=True, - default=False, - help="Enable verbose logging to stderr", + count=True, + help="Increase verbosity (-v=INFO, -vv=DEBUG)", )(f) return f # noqa: RET504 diff --git a/src/winml/modelkit/utils/console.py b/src/winml/modelkit/utils/console.py new file mode 100644 index 000000000..2377aadf9 --- /dev/null +++ b/src/winml/modelkit/utils/console.py @@ -0,0 +1,561 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Shared console output utilities for winml CLI commands. + +Provides consistent Rich-based formatting for: +- Config command: headers, I/O specs, resolution summary +- Build command: cascading StageLive, setup/stages sections, graph summary + +All output goes to stderr via Console(stderr=True) so stdout stays clean +for machine-readable output (JSON configs, build manifests). 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from rich.console import Console, Group, RenderableType +from rich.live import Live +from rich.text import Text + + +if TYPE_CHECKING: + from ..export.config import WinMLExportConfig + +logger = logging.getLogger(__name__) + +HEAVY_SEP = "\u2550" * 60 # ═ +LIGHT_SEP = "\u2500" * 60 # ─ +MAX_BAR_WIDTH = 36 + +# Stage status icons +ICON_RUNNING = "\u23f3" # ⏳ +ICON_DONE = "\u2705" # ✅ +ICON_SKIP = "\u23f8\ufe0f " # ⏸️ +ICON_ERROR = "\u274c" # ❌ + + +def get_console() -> Console: + """Return a Console that prints to stderr.""" + return Console(stderr=True) + + +# ══════════════════════════════════════════════════════════════════════════ +# SHARED FORMATTING +# ══════════════════════════════════════════════════════════════════════════ + + +def print_command_header( + console: Console, + title: str, + subtitle: str | None = None, +) -> None: + """Print a command header block (═══ separators).""" + console.print() + console.print(HEAVY_SEP) + label = f"[bold]{title}[/bold]" + if subtitle: + label += f" [dim]({subtitle})[/dim]" + console.print(label) + console.print(HEAVY_SEP) + + +def print_kv( + console: Console, + label: str, + value: str, + *, + note: str | None = None, + icon: str = "", +) -> None: + """Print a key-value line with optional note.""" + line = f" {icon} [bold]{label:<14}[/bold] [cyan]{value}[/cyan]" + if note: + line += f" [dim]({note})[/dim]" + console.print(line) + + +def print_success(console: Console, message: str) -> None: + """Print a green success line with check icon.""" + console.print(f" [green]{ICON_DONE} {message}[/green]") + + +def print_error( + console: Console, + message: str, + hint: str | None = None, +) -> None: + """Print a red error line with optional hint.""" + console.print(f" [red]{ICON_ERROR} {message}[/red]") + if hint: + console.print(f" [dim]\U0001f4a1 {hint}[/dim]") + + +# 
══════════════════════════════════════════════════════════════════════════ +# CONFIG COMMAND HELPERS +# ══════════════════════════════════════════════════════════════════════════ + + +def print_io_specs_detail( + console: Console, + export_config: WinMLExportConfig, +) -> None: + """Print resolved I/O specs — always full detail, aligned columns.""" + inputs = export_config.input_tensors or [] + outputs = export_config.output_tensors or [] + + for i, t in enumerate(inputs): + name = t.name or "(unnamed)" + shape_str = str(list(t.shape)) if getattr(t, "shape", None) else "dynamic" + dtype_str = getattr(t, "dtype", None) or "?" + label = "Input: " if i == 0 else " " + console.print(f" {label}[cyan]{name:<18}[/cyan] {shape_str:<14} [dim]{dtype_str}[/dim]") + for i, t in enumerate(outputs): + name = t.name or "(unnamed)" + # Fix #3: OutputTensorSpec only has name — show name only + label = "Output: " if i == 0 else " " + console.print(f" {label}[cyan]{name}[/cyan]") + + +def print_io_specs_na(console: Console, reason: str = "") -> None: + """Print I/O specs not-available line (e.g., ONNX mode).""" + msg = reason or "inferred from ONNX graph at build time" + console.print(f" \U0001f4d0 [bold]I/O specs:[/bold] [dim]N/A \u2014 {msg}[/dim]") + + +# ══════════════════════════════════════════════════════════════════════════ +# BUILD COMMAND — SETUP / STAGES SECTIONS +# ══════════════════════════════════════════════════════════════════════════ + + +def print_setup( + console: Console, + *, + model: str, + config: str, + output: str, + source: str = "HuggingFace", +) -> None: + """Print the 🔧 Setup section header.""" + console.print() + console.print(HEAVY_SEP) + console.print(f"[bold]\U0001f527 Setup \u2014 {source}[/bold]") + console.print(HEAVY_SEP) + console.print(f" \U0001f4e6 [bold]{'Model:':<10}[/bold] [cyan]{model}[/cyan]") + console.print(f" \U0001f4c1 [bold]{'Config:':<10}[/bold] [cyan]{config}[/cyan]") + console.print(f" \U0001f4c2 [bold]{'Output:':<10}[/bold] 
[cyan]{output}[/cyan]") + console.print() + + +def print_stages_header(console: Console) -> None: + """Print the 🎯 Stages section header.""" + console.print(HEAVY_SEP) + console.print("[bold]\U0001f3af Stages[/bold]") + console.print(HEAVY_SEP) + + +def print_final( + console: Console, + elapsed: float, + artifact: str, + stage_timings: list[tuple[str, float | None]] | None = None, +) -> None: + """Print the 📊 Summary section with stage timing breakdown. + + Args: + stage_timings: list of (stage_name, elapsed_seconds | None for skipped) + """ + console.print() + console.print(HEAVY_SEP) + console.print("[bold]\U0001f4ca Summary[/bold]") + console.print(HEAVY_SEP) + console.print(f"{ICON_DONE} [bold green]Build complete in {elapsed:.1f}s[/bold green]") + if stage_timings: + for name, t in stage_timings: + if t is not None: + console.print(f" {name:<12} [green]{t:.1f}s[/green]") + else: + console.print(f" {name:<12} [dim]skipped[/dim]") + console.print(f"\U0001f4e6 Final artifact: [bold]{artifact}[/bold]") + console.print() + + +def print_stage_skip( + console: Console, + name: str, + reason: str = "", +) -> None: + """Print a skipped stage as static text (no Live needed).""" + line = Text() + line.append(f"{ICON_SKIP} ") + line.append(name.capitalize(), style="dim") + if reason: + line.append(f" {reason}", style="dim italic") + console.print(line) + console.print() + + +def detect_model_source(model_id: str | None) -> str: + """Detect model source for Setup header.""" + if model_id is None: + return "HuggingFace" + p = Path(model_id) + if p.suffix == ".onnx": + return "ONNX" + if p.is_dir(): + return "Local" + return "HuggingFace" + + +def fmt_size(size_bytes: int | float) -> str: + """Format file size from bytes to human-readable string.""" + mb = size_bytes / (1024 * 1024) + if mb >= 1000: + return f"{mb / 1000:.1f} GB" + return f"{mb:.1f} MB" + + +def get_onnx_total_size(onnx_path: Path) -> int: + """Get total ONNX model size including external data files. 
+ + When ONNX models use external data storage, the main .onnx file + is just metadata (~1-2MB) while weights live in separate .data files. + This function sums all related files. + """ + total = onnx_path.stat().st_size + try: + import onnx + from onnx import external_data_helper as edh + + model = onnx.load(str(onnx_path), load_external_data=False) + seen: set[str] = set() + for init in model.graph.initializer: + if edh.uses_external_data(init): + ext_info = edh.ExternalDataInfo(init) + if ext_info.location and ext_info.location not in seen: + seen.add(ext_info.location) + ext_path = onnx_path.parent / ext_info.location + if ext_path.exists(): + total += ext_path.stat().st_size + except Exception: + logger.debug( + "Could not read external data for %s; reporting main file size only", + onnx_path, + exc_info=True, + ) + return total + + +# ══════════════════════════════════════════════════════════════════════════ +# BUILD COMMAND — STAGE LIVE (cascading Live per stage) +# ══════════════════════════════════════════════════════════════════════════ + + +class StageLive: + """Live region for a single build stage. + + Each stage gets its own Rich Live context. When the stage completes, + Live stops and the final frame persists as static text (transient=False). + The next stage starts a new Live below. + + Usage:: + + with StageLive("export", console) as sl: + sl.kv("Task:", "fill-mask [dim](auto-detected)[/dim]") + sl.io_input("input_ids", "[1, 128]", "int64") + # ... blocking work ... 
+ sl.set_done(12.3) + sl.artifact("output/export.onnx", 438_200_000) + """ + + def __init__(self, name: str, console: Console) -> None: + self._name = name + self._console = console + self._lines: list[RenderableType] = [] + self._live: Live | None = None + self._status_idx: int = 0 + + def __enter__(self) -> StageLive: + self._lines = [self._make_running_line()] + self._status_idx = 0 + self._live = Live( + self._render(), + console=self._console, + refresh_per_second=15, + transient=False, + ) + self._live.start() + return self + + def __exit__(self, *_: object) -> None: + if self._live: + self._live.update(self._render()) + self._live.stop() + self._live = None + + def _render(self) -> Group: + return Group(*self._lines) + + def _update(self) -> None: + if self._live: + self._live.update(self._render()) + + # ── Status line management ──────────────────────────────────── + + def _make_running_line(self, detail: str = "") -> Text: + line = Text() + line.append(f"{ICON_RUNNING} ") + line.append(self._name.capitalize(), style="bold yellow") + if detail: + line.append(f" {detail}", style="dim") + return line + + def set_status(self, detail: str) -> None: + """Update the running status text.""" + self._lines[self._status_idx] = self._make_running_line(detail) + self._update() + + def set_done(self, elapsed: float) -> None: + """Mark stage as done.""" + line = Text() + line.append(f"{ICON_DONE} ") + line.append(f"{self._name.capitalize():<48}", style="green") + line.append(f"{elapsed:.1f}s", style="green") + self._lines[self._status_idx] = line + self._update() + + def set_error(self, error: str = "") -> None: + """Mark stage as failed.""" + line = Text() + line.append(f"{ICON_ERROR} ") + line.append(self._name.capitalize(), style="bold red") + if error: + line.append(f" {error}", style="red") + self._lines[self._status_idx] = line + self._update() + + # ── Detail lines (indented under stage) ─────────────────────── + + def detail(self, markup: str) -> None: + """Add 
a Rich markup detail line.""" + self._lines.append(Text.from_markup(f" {markup}")) + self._update() + + def kv(self, label: str, value: str) -> None: + """Add a key-value detail line with aligned columns.""" + self._lines.append(Text.from_markup(f" {label:<14}{value}")) + self._update() + + def artifact(self, path: str, size_bytes: int | float) -> None: + """Add artifact line (always last in stage).""" + label = "\U0001f4e6 Artifact:" + self._lines.append( + Text.from_markup(f" {label:<14}[dim]{path}[/dim] ({fmt_size(size_bytes)})") + ) + self._update() + + def blank(self) -> None: + """Add a blank line.""" + self._lines.append(Text("")) + self._update() + + # ── I/O lines (aligned columns) ────────────────────────────── + + def io_input( + self, + name: str, + shape: str, + dtype: str, + *, + first: bool = True, + ) -> None: + """Add an input tensor line.""" + label = "Input: " if first else " " + self._lines.append( + Text.from_markup(f" {label}[cyan]{name:<18}[/cyan] {shape:<14} [dim]{dtype}[/dim]") + ) + self._update() + + def io_output( + self, + name: str, + shape: str, + dtype: str, + *, + first: bool = True, + ) -> None: + """Add an output tensor line.""" + label = "Output: " if first else " " + self._lines.append( + Text.from_markup(f" {label}[cyan]{name:<18}[/cyan] {shape:<14} [dim]{dtype}[/dim]") + ) + self._update() + + # ── EP analyzer bar lines (for optimize stage) ──────────────── + + def ep_bar_add(self, ep_name: str, total: int = 0) -> int: + """Add a placeholder EP bar line, return index.""" + idx = len(self._lines) + line = Text() + line.append(" - ") + line.append(f"{ep_name:<28}", style="dim") + if total: + line.append("\u2591" * MAX_BAR_WIDTH, style="dim") + self._lines.append(line) + self._update() + return idx + + def ep_bar_update( + self, + idx: int, + ep_name: str, + s: int, + p: int, + u: int, + total: int = 0, + ) -> None: + """Update an EP bar line by index with progress.""" + line = Text() + line.append(" - ") + 
line.append(f"{ep_name:<28}", style="cyan") + line.append_text(_spu_text(s, p, u)) + line.append(" ") + # Scale bar proportional to total (not analyzed count) + analyzed = s + p + u + anchor = max(total, analyzed, 1) + line.append_text(_build_bar_scaled(s, p, u, anchor)) + remaining = total - analyzed if total else 0 + if remaining > 0: + rem_w = max( + 1, + round(remaining / anchor * MAX_BAR_WIDTH), + ) + line.append("\u2591" * rem_w, style="dim") + self._lines[idx] = line + self._update() + + +# ══════════════════════════════════════════════════════════════════════════ +# EP ANALYZER BAR HELPERS +# ══════════════════════════════════════════════════════════════════════════ + + +def _build_bar(s: int, p: int, u: int) -> Text: + """Build a compact stacked bar for S/P/U counts.""" + total = s + p + u + if total == 0: + return Text() + return _build_bar_scaled(s, p, u, total) + + +def _build_bar_scaled(s: int, p: int, u: int, anchor: int) -> Text: + """Build a stacked bar scaled to an anchor total.""" + if anchor == 0: + return Text() + bar = Text() + s_w = max(1, round(s / anchor * MAX_BAR_WIDTH)) if s else 0 + p_w = max(1, round(p / anchor * MAX_BAR_WIDTH)) if p else 0 + u_w = max(1, round(u / anchor * MAX_BAR_WIDTH)) if u else 0 + # Clamp total to MAX_BAR_WIDTH + used = s_w + p_w + u_w + if used > MAX_BAR_WIDTH: + overflow = used - MAX_BAR_WIDTH + # Shrink from the largest segment + if s_w >= p_w and s_w >= u_w: + s_w = max(1, s_w - overflow) + elif p_w >= u_w: + p_w = max(1, p_w - overflow) + else: + u_w = max(1, u_w - overflow) + bar.append("\u2588" * s_w, style="green") + if p_w: + bar.append("\u2588" * p_w, style="yellow") + if u_w: + bar.append("\u2588" * u_w, style="red") + return bar + + +def _spu_text(s: int, p: int, u: int) -> Text: + """Build 'S/P/U' colored count text.""" + t = Text() + t.append(str(s), style="bold green") + t.append("/", style="dim") + t.append(str(p), style="bold yellow" if p > 0 else "dim") + t.append("/", style="dim") + 
t.append(str(u), style="bold red" if u > 0 else "dim") + return t + + +# ══════════════════════════════════════════════════════════════════════════ +# ONNX GRAPH SUMMARY (for compile stage) +# ══════════════════════════════════════════════════════════════════════════ + + +def get_onnx_graph_summary(model_path: Path | str) -> dict[str, Any]: + """Extract graph summary from ONNX model without loading weights. + + Returns dict with: + op_counts: dict[str, int] — node count per op_type (excl QDQ) + inputs: list[dict] — [{name, shape, dtype}, ...] + outputs: list[dict] — [{name, shape, dtype}, ...] + num_initializers: int + total_nodes: int + """ + import onnx + from onnx import TensorProto + + _dtype_map = { + TensorProto.FLOAT: "float32", + TensorProto.FLOAT16: "float16", + TensorProto.INT32: "int32", + TensorProto.INT64: "int64", + TensorProto.INT8: "int8", + TensorProto.UINT8: "uint8", + TensorProto.BOOL: "bool", + TensorProto.STRING: "string", + } + + model = onnx.load(str(model_path), load_external_data=False) + graph = model.graph + + # Op counts (exclude QDQ nodes from display) + qdq_ops = {"QuantizeLinear", "DequantizeLinear"} + op_counts: dict[str, int] = {} + for node in graph.node: + if node.op_type not in qdq_ops: + op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1 + + # Sort by count descending + op_counts = dict(sorted(op_counts.items(), key=lambda x: x[1], reverse=True)) + + # Inputs (exclude initializer names — they appear in graph.input too) + init_names = {init.name for init in graph.initializer} + + def _parse_io(value_info: Any) -> dict: + name = value_info.name + tt = value_info.type.tensor_type + dtype = _dtype_map.get(tt.elem_type, f"type({tt.elem_type})") + dims = [] + if tt.HasField("shape"): + for d in tt.shape.dim: + if d.dim_param: + dims.append(d.dim_param) + else: + dims.append(d.dim_value) + return {"name": name, "shape": dims, "dtype": dtype} + + inputs = [_parse_io(inp) for inp in graph.input if inp.name not in init_names] + 
outputs = [_parse_io(out) for out in graph.output] + + return { + "op_counts": op_counts, + "inputs": inputs, + "outputs": outputs, + "num_initializers": len(graph.initializer), + "total_nodes": len(graph.node), + } diff --git a/src/winml/modelkit/utils/logging.py b/src/winml/modelkit/utils/logging.py index 3cd16ab23..94a0b3ba9 100644 --- a/src/winml/modelkit/utils/logging.py +++ b/src/winml/modelkit/utils/logging.py @@ -2,27 +2,52 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Logging utilities for ModelKit.""" +"""Logging utilities for ModelKit. + +Verbosity Convention (adopted from pip, ansible, pytest): +========================================================= + + Flag Level Value Use case + ---- ----- ----- -------- + -q ERROR 40 Errors only (quiet / scripting) + (default) WARNING 30 Warnings + errors (production default) + -v INFO 20 Operational progress messages + -vv DEBUG 10 Developer-level tracing + --debug DEBUG 10 Alias for -vv (backward compat) + + Formula: level = WARNING - (verbosity * 10) -> 30, 20, 10 + Quiet: level = ERROR (40) + +All log output goes to stderr so stdout stays clean for structured data +(JSON, compact output, piped commands). +""" import logging import sys -def configure_logging(verbose: bool = False, quiet: bool = False) -> None: - """Configure logging level based on verbosity flags. +def configure_logging( + verbosity: int = 0, + quiet: bool = False, + *, + # Backward-compat: accept old bool signature + verbose: bool = False, +) -> None: + """Configure root logger based on verbosity level. Args: - verbose: Enable verbose logging (DEBUG level) - quiet: Enable quiet mode (ERROR level only) - - Default level is INFO when both flags are False. + verbosity: Number of ``-v`` flags (0=WARNING, 1=INFO, 2+=DEBUG). + quiet: If True, override to ERROR level regardless of verbosity. 
+ verbose: **Deprecated bool compat** — treated as verbosity=1 when + True and verbosity is 0. Existing callers that pass + ``verbose=True`` keep working without changes. """ - if quiet: - log_level = logging.ERROR - elif verbose: - log_level = logging.DEBUG - else: - log_level = logging.INFO + # Backward compat: bool verbose → int, also handles count passthrough + if verbose and verbosity == 0: + verbosity = int(verbose) + + # Clamp between DEBUG (10) and WARNING (30); quiet overrides to ERROR + log_level = logging.ERROR if quiet else max(logging.DEBUG, logging.WARNING - verbosity * 10) logging.basicConfig( level=log_level, diff --git a/tests/conftest.py b/tests/conftest.py index 61ed770f8..3307d76ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,14 +35,25 @@ def _skip_winml_ep_init(request: pytest.FixtureRequest, monkeypatch: pytest.Monk """Mock WinML EP initialization for non-e2e tests.""" if "e2e" in {m.name for m in request.node.iter_markers()}: return - monkeypatch.setattr( - "winml.modelkit.session.session.WinMLSession._init_winml_eps_once", - classmethod(lambda cls: None), - ) - monkeypatch.setattr( - "winml.modelkit.analyze.core.runtime_checker_query.RuntimeCheckerQuery._is_ep_available_locally", - lambda self: False, - ) + try: + monkeypatch.setattr( + "winml.modelkit.session.session.WinMLSession._init_winml_eps_once", + classmethod(lambda cls: None), + ) + except ImportError as e: + import warnings + + warnings.warn(f"Could not mock _init_winml_eps_once: {e}", stacklevel=2) + + try: + monkeypatch.setattr( + "winml.modelkit.analyze.core.runtime_checker_query.RuntimeCheckerQuery._is_ep_available_locally", + lambda self: False, + ) + except ImportError as e: + import warnings + + warnings.warn(f"Could not mock _is_ep_available_locally: {e}", stacklevel=2) # ============================================================================= diff --git a/tests/regression/test_design_gaps.py b/tests/regression/test_design_gaps.py index 
5c336b93f..55a9f0412 100644 --- a/tests/regression/test_design_gaps.py +++ b/tests/regression/test_design_gaps.py @@ -65,19 +65,19 @@ def test_optimize_list_rewrites_is_ascii_safe(self): # =========================================================================== -# M-1: --list-tasks NOT in inspect --help +# M-1: --list-tasks IS in inspect --help (implemented in MVP v2 port) # =========================================================================== -class TestM1ListTasksAbsent: - """Document that --list-tasks is not implemented in inspect.""" +class TestM1ListTasksPresent: + """Verify that --list-tasks is implemented in inspect.""" - def test_list_tasks_not_in_help(self): - """inspect --help should NOT contain --list-tasks option.""" + def test_list_tasks_in_help(self): + """inspect --help should contain --list-tasks option.""" runner = CliRunner() result = runner.invoke(inspect, ["--help"], obj={}) assert result.exit_code == 0 - assert "--list-tasks" not in result.output + assert "--list-tasks" in result.output # =========================================================================== diff --git a/tests/test_import_time.py b/tests/test_import_time.py new file mode 100644 index 000000000..939275051 --- /dev/null +++ b/tests/test_import_time.py @@ -0,0 +1,460 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Regression tests for lazy loading and import-time tracking. + +These tests ensure that importing ModelKit modules and running CLI commands +do not pull in heavy ML dependencies (torch, transformers, optimum, etc.) +unless the functionality actually requires them. + +Every test runs in a fresh subprocess so sys.modules starts clean. 
+ +Test Categories: + (A) Per-module isolation: verify each winml.modelkit.* package's import budget + (B) Per-command: verify each CLI command's import budget (--help and --model) +""" + +from __future__ import annotations + +import subprocess +import sys +import textwrap + +import pytest + + +# --------------------------------------------------------------------------- +# Discovery — dynamic lists from the actual codebase +# --------------------------------------------------------------------------- + + +# Discover commands by scanning the commands/ directory (same logic as cli.py) +def _discover_command_names() -> list[str]: + from pathlib import Path + + root = Path(__file__).resolve().parent.parent + commands_dir = root / "src" / "winml" / "modelkit" / "commands" + return sorted(f.stem for f in commands_dir.glob("*.py") if not f.name.startswith("_")) + + +_CLI_COMMANDS = _discover_command_names() + +HEAVY_PREFIXES = ("torch", "transformers", "optimum", "diffusers", "sklearn") + + +def _run_in_subprocess(code: str) -> subprocess.CompletedProcess[str]: + """Run dedented Python code in a fresh subprocess via ``python -c``.""" + return subprocess.run( # noqa: S603 + [sys.executable, "-c", textwrap.dedent(code)], + capture_output=True, + text=True, + timeout=120, + ) + + +def assert_no_heavy_imports( + setup_code: str, + *, + forbidden: tuple[str, ...] = HEAVY_PREFIXES, + allowed: tuple[str, ...] = (), +) -> None: + """Assert that running setup_code loads no forbidden modules. + + Args: + setup_code: Python code to execute (will be dedented). + forbidden: Module prefixes that must NOT appear in sys.modules. + allowed: Module prefixes to exclude from the forbidden check. 
+ """ + script = textwrap.dedent(f"""\ + import sys + {setup_code} + loaded = sorted(set( + m.split('.')[0] for m in sys.modules + if m.startswith({forbidden!r}) + )) + allowed = set({allowed!r}) + bad = [m for m in loaded if m not in allowed] + if bad: + print(f"FAIL: unexpected heavy modules: {{bad}}", file=sys.stderr) + print(f" allowed: {{allowed}}", file=sys.stderr) + sys.exit(1) + """) + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=120, + ) + assert result.returncode == 0, f"Import budget violated.\nstderr: {result.stderr.strip()}" + + +def assert_cli_no_heavy_imports( + cli_args: list[str], + *, + allowed: tuple[str, ...] = (), +) -> None: + """Assert that invoking ``main(cli_args)`` loads no forbidden modules. + + Uses try/except to catch SystemExit and Click errors gracefully. + """ + args_str = repr(cli_args) + script = textwrap.dedent(f"""\ + import sys + from winml.modelkit.cli import main + import click + try: + main({args_str}, standalone_mode=False) + except (SystemExit, click.exceptions.UsageError, Exception): + pass + loaded = sorted(set( + m.split('.')[0] for m in sys.modules + if m.startswith({HEAVY_PREFIXES!r}) + )) + allowed = set({allowed!r}) + bad = [m for m in loaded if m not in allowed] + if bad: + print(f"FAIL: unexpected heavy modules: {{bad}}", file=sys.stderr) + print(f" allowed: {{allowed}}", file=sys.stderr) + sys.exit(1) + """) + result = subprocess.run( # noqa: S603 + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=120, + ) + assert result.returncode == 0, ( + f"Import budget violated for args {cli_args}.\nstderr: {result.stderr.strip()}" + ) + + +# =========================================================================== +# (A) Per-Module Isolation Tests +# =========================================================================== + + +class TestModuleIsolation: + """Verify each winml.modelkit.* module's import budget.""" + + 
@pytest.mark.parametrize( + "module", + [ + "winml.modelkit", + "winml.modelkit.cli", + "winml.modelkit.cache", + "winml.modelkit.compiler", + "winml.modelkit.config", + "winml.modelkit.core", + "winml.modelkit.export", + "winml.modelkit.loader", + "winml.modelkit.onnx", + "winml.modelkit.optim", + "winml.modelkit.optracing", + "winml.modelkit.quant", + "winml.modelkit.session", + "winml.modelkit.analyze", + "winml.modelkit.pattern", + "winml.modelkit.sysinfo", + "winml.modelkit.utils", + ], + ) + def test_module_no_heavy_deps(self, module: str) -> None: + """Importing this module must not load torch/transformers/optimum.""" + assert_no_heavy_imports(f"import {module}") + + @pytest.mark.parametrize( + ("module", "allowed"), + [ + ("winml.modelkit.build", ("torch", "torchgen")), + ("winml.modelkit.data", ("torch", "torchgen", "torchvision")), + ( + "winml.modelkit.datasets", + ("torch", "torchgen", "torchvision", "transformers", "sklearn"), + ), + ( + "winml.modelkit.eval", + ("torch", "torchgen", "torchvision", "transformers", "sklearn"), + ), + ("winml.modelkit.inspect", (*HEAVY_PREFIXES, "torchgen", "torchvision")), + ("winml.modelkit.models", (*HEAVY_PREFIXES, "torchgen", "torchvision")), + ], + ) + def test_module_with_expected_deps(self, module: str, allowed: tuple[str, ...]) -> None: + """Modules that legitimately need heavy deps — verify nothing extra.""" + assert_no_heavy_imports(f"import {module}", allowed=allowed) + + def test_lazy_access_triggers_import(self) -> None: + """Accessing WinMLAutoModel must trigger the full import chain.""" + script = textwrap.dedent("""\ + import sys + from winml.modelkit import WinMLAutoModel + assert 'torch' in sys.modules, ( + 'torch should be loaded after accessing WinMLAutoModel' + ) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"Lazy access did not trigger torch.\nstderr: {result.stderr}" + ) + + # -- Gap 2: lazy-trigger tests for subpackage __getattr__ implementations -- + + def 
test_lazy_core_get_io_config(self) -> None: + """core.get_io_config must be lazily accessible and callable.""" + script = textwrap.dedent("""\ + import winml.modelkit.core + obj = winml.modelkit.core.get_io_config + assert obj is not None + assert callable(obj) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"core.get_io_config not lazily accessible.\nstderr: {result.stderr}" + ) + + def test_lazy_export_resolve_io_specs(self) -> None: + """export.resolve_io_specs must be lazily accessible and callable.""" + script = textwrap.dedent("""\ + import winml.modelkit.export + obj = winml.modelkit.export.resolve_io_specs + assert obj is not None + assert callable(obj) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"export.resolve_io_specs not lazily accessible.\nstderr: {result.stderr}" + ) + + def test_lazy_loader_load_hf_model(self) -> None: + """loader.load_hf_model must be lazily accessible and callable.""" + script = textwrap.dedent("""\ + import winml.modelkit.loader + obj = winml.modelkit.loader.load_hf_model + assert obj is not None + assert callable(obj) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"loader.load_hf_model not lazily accessible.\nstderr: {result.stderr}" + ) + + def test_lazy_quant_quantize_onnx(self) -> None: + """quant.quantize_onnx must be lazily accessible and callable.""" + script = textwrap.dedent("""\ + import winml.modelkit.quant + obj = winml.modelkit.quant.quantize_onnx + assert obj is not None + assert callable(obj) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"quant.quantize_onnx not lazily accessible.\nstderr: {result.stderr}" + ) + + # -- Gap 3: AttributeError negative test -- + + def test_nonexistent_attr_raises(self) -> None: + """Importing a nonexistent attribute must raise ImportError.""" + script = textwrap.dedent("""\ + try: + from winml.modelkit import nonexistent_xyz_12345 + except 
ImportError: + pass # expected + else: + raise AssertionError( + "Expected ImportError for nonexistent attribute" + ) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"Nonexistent attr did not raise ImportError.\nstderr: {result.stderr}" + ) + + # -- Gap 4: __dir__ correctness test -- + + def test_dir_includes_lazy_attrs(self) -> None: + """dir(winml.modelkit) must include lazy attrs without loading torch.""" + script = textwrap.dedent("""\ + import sys + import winml.modelkit + assert "WinMLAutoModel" in dir(winml.modelkit), ( + "WinMLAutoModel missing from dir()" + ) + loaded = sorted(set( + m.split('.')[0] for m in sys.modules + if m.startswith(('torch', 'transformers', 'optimum', 'diffusers', 'sklearn')) + )) + if loaded: + print(f"FAIL: dir() triggered heavy imports: {loaded}", file=sys.stderr) + sys.exit(1) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, f"dir() test failed.\nstderr: {result.stderr}" + + +# =========================================================================== +# (C) _LAZY_IMPORTS Dict Consistency Tests +# =========================================================================== + +_LAZY_MODULES = [ + "winml.modelkit", + "winml.modelkit.core", + "winml.modelkit.export", + "winml.modelkit.loader", + "winml.modelkit.quant", + "winml.modelkit.models", + "winml.modelkit.onnx", +] + + +class TestLazyImportsDict: + """Verify the standardized _LAZY_IMPORTS pattern across all modules.""" + + @pytest.mark.parametrize("module", _LAZY_MODULES) + def test_lazy_imports_dict_exists(self, module: str) -> None: + """Each module must define a non-empty _LAZY_IMPORTS dict.""" + script = textwrap.dedent(f"""\ + import {module} as mod + lazy = getattr(mod, '_LAZY_IMPORTS', None) + assert lazy is not None, '_LAZY_IMPORTS not found on {module}' + assert isinstance(lazy, dict), ( + f'_LAZY_IMPORTS is {{type(lazy).__name__}}, expected dict' + ) + assert len(lazy) > 0, '_LAZY_IMPORTS is empty' + 
""") + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"_LAZY_IMPORTS check failed for {module}.\nstderr: {result.stderr.strip()}" + ) + + @pytest.mark.parametrize("module", _LAZY_MODULES) + def test_lazy_imports_all_consistent(self, module: str) -> None: + """Every key in _LAZY_IMPORTS must also appear in __all__.""" + script = textwrap.dedent(f"""\ + import {module} as mod + lazy = set(mod._LAZY_IMPORTS.keys()) + all_ = set(mod.__all__) + missing = lazy - all_ + assert not missing, f'In _LAZY_IMPORTS but not __all__: {{missing}}' + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"_LAZY_IMPORTS/__all__ drift in {module}.\nstderr: {result.stderr.strip()}" + ) + + @pytest.mark.parametrize("module", _LAZY_MODULES) + def test_lazy_imports_all_resolvable(self, module: str) -> None: + """Every _LAZY_IMPORTS entry must resolve to a real attribute.""" + script = textwrap.dedent(f"""\ + import importlib + import {module} as mod + errors = [] + for attr_name, submodule_path in mod._LAZY_IMPORTS.items(): + try: + sub = importlib.import_module(submodule_path) + if not hasattr(sub, attr_name): + errors.append( + f'{{attr_name}}: {{submodule_path}} has no attribute {{attr_name}}' + ) + except ImportError as exc: + errors.append(f'{{attr_name}}: cannot import {{submodule_path}} ({{exc}})') + if errors: + raise AssertionError( + f'Unresolvable _LAZY_IMPORTS in {module}:\\n' + '\\n'.join(errors) + ) + """) + result = _run_in_subprocess(script) + assert result.returncode == 0, ( + f"Unresolvable _LAZY_IMPORTS in {module}.\nstderr: {result.stderr.strip()}" + ) + + +# =========================================================================== +# (B) Per-Command Tests -- --help (no heavy imports at command load time) +# =========================================================================== + + +class TestCommandHelp: + """Verify ``winml`` and ``winml --help`` do not load heavy deps.""" + + def test_winml_bare(self) -> 
None: + """Bare ``winml`` (no args) must not load heavy deps.""" + assert_cli_no_heavy_imports([]) + + def test_winml_help(self) -> None: + """``winml --help`` must not load heavy deps.""" + assert_cli_no_heavy_imports(["--help"]) + + @pytest.mark.parametrize("cmd", _CLI_COMMANDS) + def test_command_help_no_heavy_deps(self, cmd: str) -> None: + """``winml <cmd> --help`` must not load heavy deps.""" + assert_cli_no_heavy_imports([cmd, "--help"]) + + +# =========================================================================== +# (B) Per-Command Tests — with --model (actual command execution) +# =========================================================================== + +_FAKE_ONNX = "nonexistent_test_model.onnx" +_HF_MODEL = "microsoft/resnet-50" + + +class TestCommandWithModel: + """Verify import budgets when commands are invoked with --model. + + Commands that operate on ONNX files should NOT need torch/transformers. + Commands that operate on HF models legitimately need them. + + We use a fake model path so commands fail at file I/O, but the import + chain is already established by that point.
+ """ + + @pytest.mark.parametrize( + ("cmd_args", "allowed"), + [ + # ONNX-path commands — should NOT need torch/transformers + ( + ["compile", "--model", _FAKE_ONNX, "-o", "o.onnx", "--ep", "qnn"], + (), + ), + ( + ["quantize", "--model", _FAKE_ONNX, "-o", "o.onnx", "--ep", "qnn"], + (), + ), + ( + ["optimize", "--model", _FAKE_ONNX, "-o", "o.onnx"], + ("torch", "torchgen"), # ORT tools.__init__ pulls torch + ), + ( + ["perf", "--model", _FAKE_ONNX], + (), + ), + ( + ["static-analyzer", "check", "--model", _FAKE_ONNX, "--ep", "qnn"], + ("torch", "torchgen"), # ORT tools.__init__ pulls torch + ), + # HF model commands — legitimately need heavy deps + ( + ["inspect", "-m", _HF_MODEL], + (*HEAVY_PREFIXES, "torchgen", "torchvision"), + ), + ( + ["config", "-m", _HF_MODEL, "--device", "npu", "--precision", "int8"], + (*HEAVY_PREFIXES, "torchgen", "torchvision"), + ), + ], + ids=[ + "compile-onnx", + "quantize-onnx", + "optimize-onnx", + "perf-onnx", + "static-analyzer-onnx", + "inspect-hf", + "config-hf", + ], + ) + def test_command_import_budget(self, cmd_args: list[str], allowed: tuple[str, ...]) -> None: + """Verify each command's import budget with --model.""" + assert_cli_no_heavy_imports(cmd_args, allowed=allowed) diff --git a/tests/unit/analyze/test_static_analyzer_cli.py b/tests/unit/analyze/test_static_analyzer_cli.py index 12e6941cf..7752cef0b 100644 --- a/tests/unit/analyze/test_static_analyzer_cli.py +++ b/tests/unit/analyze/test_static_analyzer_cli.py @@ -37,10 +37,16 @@ def runner() -> CliRunner: @pytest.fixture def mock_analyzer_result() -> Mock: - """Create a mock AnalysisOutput result.""" + """Create a mock AnalysisResult (returned by ONNXStaticAnalyzer.analyze). + + The command accesses ``result.output.results`` (list of EPSupport) for + Rich live display, ``result.is_fully_supported()`` for exit code, and + ``result.to_json()`` for JSON output. 
+ """ mock_result = Mock() mock_result.is_fully_supported.return_value = True mock_result.get_unsupported_operators.return_value = [] + mock_result.output.results = [] # empty EP results list (iterable) mock_result.to_json.return_value = json.dumps( { "analyzer_version": "0.1.0", @@ -64,6 +70,7 @@ def mock_analyzer_partial_support() -> Mock: mock_result = Mock() mock_result.is_fully_supported.return_value = False mock_result.get_unsupported_operators.return_value = ["Conv", "Gemm", "Add"] + mock_result.output.results = [] # empty EP results list (iterable) mock_result.to_json.return_value = json.dumps( { "analyzer_version": "0.1.0", @@ -609,7 +616,7 @@ def test_analyze_called_with_correct_parameters( # Verify analyze was called with correct parameters mock_instance.analyze.assert_called_once() call_kwargs = mock_instance.analyze.call_args[1] - assert call_kwargs["model_path"] == model_file + assert call_kwargs["model_path"] == str(model_file) assert call_kwargs["ep"] == "OpenVINOExecutionProvider" assert call_kwargs["device"] == "GPU" assert call_kwargs["enable_information"] is True diff --git a/tests/unit/build/test_hf.py b/tests/unit/build/test_hf.py index e3ab1e56a..38ce5ec0b 100644 --- a/tests/unit/build/test_hf.py +++ b/tests/unit/build/test_hf.py @@ -635,8 +635,8 @@ def test_autoconf_converges_in_one_iteration( # Autoconf is part of optimize, not a separate stage assert "optimize" in result.stages_completed - # Single analyze call (no autoconf suggestions, no loop) - assert m.call_count == 1 + # Two analyze calls: one in loop (no autoconf), one final validation + assert m.call_count == 2 def test_autoconf_discovers_and_reoptimizes( self, tmp_path: Path, sample_config_no_quant_compile, mock_pipeline @@ -650,7 +650,7 @@ def test_autoconf_discovers_and_reoptimizes( with patch( "winml.modelkit.build.common.analyze_onnx", - side_effect=[result_with_gelu, result_converged], + side_effect=[result_with_gelu, result_converged, result_converged], ) as m_analyze: result 
= build_hf_model( config=sample_config_no_quant_compile, @@ -661,8 +661,8 @@ def test_autoconf_discovers_and_reoptimizes( ) assert "optimize" in result.stages_completed - # 2 analyze calls: initial (found gelu) + after re-optimize (converged) - assert m_analyze.call_count == 2 + # 3 analyze calls: initial (found gelu) + after re-optimize (converged) + final validation + assert m_analyze.call_count == 3 # optimize_onnx called: once initial + once re-optimize in autoconf assert mock_pipeline["optimize"].call_count == 2 @@ -739,7 +739,7 @@ def test_manifest_records_analyze_details( with patch( "winml.modelkit.build.common.analyze_onnx", - side_effect=[result_with_gelu, result_converged], + side_effect=[result_with_gelu, result_converged, result_converged], ): result = build_hf_model( config=sample_config_no_quant_compile, @@ -774,7 +774,7 @@ def test_autoconf_merges_config_for_downstream( with patch( "winml.modelkit.build.common.analyze_onnx", - side_effect=[result_with_flags, result_converged], + side_effect=[result_with_flags, result_converged, result_converged], ): build_hf_model( config=sample_config_no_quant_compile, @@ -847,7 +847,7 @@ def test_post_export_qdq_still_compiles( def test_post_export_qdq_runs_analyze_only( self, tmp_path: Path, sample_config, mock_pipeline ) -> None: - """Analyze is called (via run_analyze_only) but optimize is not.""" + """Pre-quantized path runs optimize but skips autoconf (no analyze).""" mock_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -856,7 +856,8 @@ def test_post_export_qdq_runs_analyze_only( output_dir=output_dir, pytorch_model=mock_pipeline["model"], ) - mock_pipeline["analyze"].assert_called() + # max_optim_iterations=0 means no analyze loop runs + mock_pipeline["analyze"].assert_not_called() mock_pipeline["optimize"].assert_called_once() def test_skip_optimize_kwarg(self, tmp_path: Path, sample_config, mock_pipeline) -> None: diff --git a/tests/unit/build/test_onnx.py 
b/tests/unit/build/test_onnx.py index 5d14f69be..a4e6b41d8 100644 --- a/tests/unit/build/test_onnx.py +++ b/tests/unit/build/test_onnx.py @@ -400,7 +400,7 @@ def test_pre_quantized_still_compiles( def test_pre_quantized_runs_analyze_only( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline ) -> None: - """QDQ model runs analyze (via run_analyze_only) but not optimize.""" + """Pre-quantized path runs optimize but skips autoconf (no analyze).""" mock_onnx_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -409,8 +409,8 @@ def test_pre_quantized_runs_analyze_only( config=sample_onnx_config, output_dir=output_dir, ) - # analyze_onnx should be called (via run_analyze_only) - mock_onnx_pipeline["analyze"].assert_called() + # max_optim_iterations=0 means no analyze loop runs + mock_onnx_pipeline["analyze"].assert_not_called() mock_onnx_pipeline["optimize"].assert_called_once() def test_skip_optimize_kwarg( diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index d88dec71a..12fdcaf77 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -5,20 +5,22 @@ """Tests for build CLI command — mock-based, no network, no actual builds. -Tests the CLI wrapper around build_hf_model() API. +Tests the CLI wrapper around _run_single_build() internal pipeline. NO WinMLAutoModel involvement. 
""" from __future__ import annotations import json -from pathlib import Path +from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import pytest from click.testing import CliRunner -from winml.modelkit.build.hf import BuildResult + +if TYPE_CHECKING: + from pathlib import Path @pytest.fixture(autouse=True) @@ -74,31 +76,15 @@ def sample_config_file(tmp_path: Path) -> Path: @pytest.fixture def mock_build_api(): - """Mock build_hf_model to avoid actual pipeline execution.""" - result = BuildResult( - output_dir=Path("/fake/output"), - final_onnx_path=Path("/fake/output/model.onnx"), - config_path=Path("/fake/output/winml_build_config.json"), - stages_completed=["export", "optimize", "quantize", "compile"], - stages_skipped=[], - stage_timings={"export": 1.0, "optimize": 0.5, "quantize": 2.0, "compile": 0.3}, - elapsed=3.8, - ) - with patch("winml.modelkit.build.build_hf_model", return_value=result) as mock: + """Mock _run_single_build to avoid actual pipeline execution.""" + with patch("winml.modelkit.commands.build._run_single_build", return_value=None) as mock: yield mock @pytest.fixture def mock_build_reused(): - """Mock build_hf_model returning a reused result.""" - result = BuildResult( - output_dir=Path("/fake/output"), - final_onnx_path=Path("/fake/output/model.onnx"), - config_path=Path("/fake/output/winml_build_config.json"), - reused=True, - elapsed=0.01, - ) - with patch("winml.modelkit.build.build_hf_model", return_value=result) as mock: + """Mock _run_single_build returning None (reuse is handled internally).""" + with patch("winml.modelkit.commands.build._run_single_build", return_value=None) as mock: yield mock @@ -197,7 +183,6 @@ def test_basic_build( ) assert result.exit_code == 0, f"Build failed: {result.output}" assert mock_build_api.called - assert "Build complete" in result.output def test_model_id_passed( self, @@ -214,16 +199,16 @@ def test_model_id_passed( obj={"debug": False}, ) call_kwargs = 
mock_build_api.call_args.kwargs - assert call_kwargs.get("model_id") == "microsoft/resnet-50" + assert call_kwargs["model_id"] == "microsoft/resnet-50" - def test_model_required( + def test_model_optional_for_random_weight( self, runner: CliRunner, sample_config_file: Path, mock_build_api: MagicMock, tmp_path: Path, ) -> None: - """Omitting -m/--model is rejected because it is now required.""" + """Omitting -m/--model is valid — triggers random-weight build.""" from winml.modelkit.commands.build import build result = runner.invoke( @@ -231,8 +216,9 @@ def test_model_required( ["-c", str(sample_config_file), "-o", str(tmp_path)], obj={"debug": False}, ) - assert result.exit_code != 0 - assert "model" in result.output.lower() + assert result.exit_code == 0 + call_kwargs = mock_build_api.call_args.kwargs + assert call_kwargs["model_id"] is None def test_rebuild_passed( self, @@ -249,7 +235,7 @@ def test_rebuild_passed( obj={"debug": False}, ) call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("rebuild") is True + assert call_kwargs["rebuild"] is True def test_default_rebuild_false( self, @@ -266,7 +252,7 @@ def test_default_rebuild_false( obj={"debug": False}, ) call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("rebuild") is False + assert call_kwargs["rebuild"] is False # ============================================================================= @@ -291,7 +277,7 @@ def test_no_quant_sets_none( ["-c", str(sample_config_file), "-m", "test", "-o", str(tmp_path), "--no-quant"], obj={"debug": False}, ) - config = mock_build_api.call_args.kwargs.get("config") + config = mock_build_api.call_args.kwargs["config"] assert config.quant is None def test_no_compile_sets_none( @@ -308,7 +294,7 @@ def test_no_compile_sets_none( ["-c", str(sample_config_file), "-m", "test", "-o", str(tmp_path), "--no-compile"], obj={"debug": False}, ) - config = mock_build_api.call_args.kwargs.get("config") + config = 
mock_build_api.call_args.kwargs["config"] assert config.compile is None def test_no_quant_no_compile_together( @@ -334,7 +320,7 @@ def test_no_quant_no_compile_together( ], obj={"debug": False}, ) - config = mock_build_api.call_args.kwargs.get("config") + config = mock_build_api.call_args.kwargs["config"] assert config.quant is None assert config.compile is None @@ -362,8 +348,8 @@ def test_reuse_message( obj={"debug": False}, ) assert result.exit_code == 0 - assert "Existing artifact" in result.output - assert "--rebuild" in result.output + # Reuse detection is handled inside _run_single_build; verify it was called + assert mock_build_reused.called # ============================================================================= @@ -407,7 +393,7 @@ def test_build_failure_reported( ) -> None: from winml.modelkit.commands.build import build - with patch("winml.modelkit.build.build_hf_model") as mock: + with patch("winml.modelkit.commands.build._run_single_build") as mock: mock.side_effect = RuntimeError("ONNX export failed") result = runner.invoke( @@ -423,7 +409,7 @@ def test_value_error_becomes_usage_error( ) -> None: from winml.modelkit.commands.build import build - with patch("winml.modelkit.build.build_hf_model") as mock: + with patch("winml.modelkit.commands.build._run_single_build") as mock: mock.side_effect = ValueError("Invalid config") result = runner.invoke( @@ -458,7 +444,7 @@ def test_ep_flag_passed( obj={"debug": False}, ) call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("ep") == "qnn" + assert call_kwargs["ep"] == "qnn" def test_device_flag_passed( self, @@ -475,7 +461,7 @@ def test_device_flag_passed( obj={"debug": False}, ) call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("device") == "NPU" + assert call_kwargs["device"] == "NPU" # ============================================================================= @@ -533,7 +519,7 @@ def test_build_auto_detect_onnx_file( sample_config_file: Path, tmp_path: Path, ) 
-> None: - """When -m points to an existing .onnx file, dispatches to build_onnx_model.""" + """When -m points to an existing .onnx file, dispatches to _build_onnx_pipeline.""" from winml.modelkit.commands.build import build # Create a fake .onnx file on disk @@ -542,16 +528,9 @@ def test_build_auto_detect_onnx_file( output_dir = tmp_path / "out" - onnx_result = BuildResult( - output_dir=output_dir, - final_onnx_path=output_dir / "model.onnx", - config_path=output_dir / "winml_build_config.json", - stages_completed=["optimize"], - stages_skipped=["quantize", "compile"], - stage_timings={"optimize": 0.5}, - elapsed=0.5, - ) - with patch("winml.modelkit.build.build_onnx_model", return_value=onnx_result) as mock_onnx: + with patch( + "winml.modelkit.commands.build._build_onnx_pipeline", return_value=[] + ) as mock_onnx: result = runner.invoke( build, ["-c", str(sample_config_file), "-m", str(onnx_file), "-o", str(output_dir)], @@ -569,7 +548,7 @@ def test_build_auto_detect_hf_model( mock_build_api: MagicMock, tmp_path: Path, ) -> None: - """When -m is a HF model ID (not .onnx), dispatches to build_hf_model.""" + """When -m is a HF model ID (not .onnx), dispatches to _run_single_build.""" from winml.modelkit.commands.build import build output_dir = tmp_path / "out" @@ -581,7 +560,7 @@ def test_build_auto_detect_hf_model( assert result.exit_code == 0, f"Build failed: {result.output}" assert mock_build_api.called call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("model_id") == "microsoft/resnet-50" + assert call_kwargs["model_id"] == "microsoft/resnet-50" def test_build_onnx_suffix_but_not_exists_uses_hf( self, @@ -604,7 +583,7 @@ def test_build_onnx_suffix_but_not_exists_uses_hf( assert result.exit_code == 0, f"Build failed: {result.output}" assert mock_build_api.called call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("model_id") == "nonexistent.onnx" + assert call_kwargs["model_id"] == "nonexistent.onnx" # 
============================================================================= @@ -641,8 +620,8 @@ def test_no_analyze_sets_zero_iterations( ["-c", str(sample_config_file), "-m", "test", "-o", str(tmp_path), "--no-analyze"], obj={"debug": False}, ) - call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("hack_max_optim_iterations") == 0 + extra = mock_build_api.call_args.kwargs["extra_kwargs"] + assert extra.get("hack_max_optim_iterations") == 0 def test_max_optim_iterations_passed( self, @@ -667,8 +646,8 @@ def test_max_optim_iterations_passed( ], obj={"debug": False}, ) - call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("hack_max_optim_iterations") == 5 + extra = mock_build_api.call_args.kwargs["extra_kwargs"] + assert extra.get("hack_max_optim_iterations") == 5 def test_no_analyze_takes_precedence_over_max_iterations( self, @@ -695,8 +674,8 @@ def test_no_analyze_takes_precedence_over_max_iterations( ], obj={"debug": False}, ) - call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("hack_max_optim_iterations") == 0 + extra = mock_build_api.call_args.kwargs["extra_kwargs"] + assert extra.get("hack_max_optim_iterations") == 0 def test_default_no_analyzer_kwargs( self, @@ -712,8 +691,8 @@ def test_default_no_analyzer_kwargs( ["-c", str(sample_config_file), "-m", "test", "-o", str(tmp_path)], obj={"debug": False}, ) - call_kwargs = mock_build_api.call_args.kwargs - assert "hack_max_optim_iterations" not in call_kwargs + extra = mock_build_api.call_args.kwargs["extra_kwargs"] + assert "hack_max_optim_iterations" not in extra # ============================================================================= @@ -736,24 +715,16 @@ def test_no_optimize_passed_to_onnx_build( sample_config_file: Path, tmp_path: Path, ) -> None: - """--no-optimize passes skip_optimize=True to build_onnx_model.""" + """--no-optimize passes skip_optimize=True via extra_kwargs.""" from winml.modelkit.commands.build import build # Create a 
fake .onnx file for ONNX path detection onnx_file = tmp_path / "model.onnx" onnx_file.write_text("fake") - result_obj = BuildResult( - output_dir=tmp_path / "out", - final_onnx_path=tmp_path / "out" / "model.onnx", - config_path=tmp_path / "out" / "config.json", - stages_completed=["compile"], - stages_skipped=["optimize", "quantize"], - stage_timings={"compile": 0.3}, - elapsed=1.0, - ) - - with patch("winml.modelkit.build.build_onnx_model", return_value=result_obj) as mock_build: + with patch( + "winml.modelkit.commands.build._run_single_build", return_value=None + ) as mock_build: result = runner.invoke( build, [ @@ -769,8 +740,8 @@ def test_no_optimize_passed_to_onnx_build( ) assert result.exit_code == 0, f"Failed: {result.output}" - call_kwargs = mock_build.call_args.kwargs - assert call_kwargs.get("skip_optimize") is True + extra = mock_build.call_args.kwargs["extra_kwargs"] + assert extra.get("skip_optimize") is True def test_no_optimize_passed_to_hf_build( self, @@ -779,7 +750,7 @@ def test_no_optimize_passed_to_hf_build( tmp_path: Path, mock_build_api: MagicMock, ) -> None: - """--no-optimize passes skip_optimize=True to build_hf_model.""" + """--no-optimize passes skip_optimize=True via extra_kwargs.""" from winml.modelkit.commands.build import build result = runner.invoke( @@ -797,8 +768,8 @@ def test_no_optimize_passed_to_hf_build( ) assert result.exit_code == 0, f"Failed: {result.output}" - call_kwargs = mock_build_api.call_args.kwargs - assert call_kwargs.get("skip_optimize") is True + extra = mock_build_api.call_args.kwargs["extra_kwargs"] + assert extra.get("skip_optimize") is True def test_no_optimize_default_not_present( self, @@ -807,7 +778,7 @@ def test_no_optimize_default_not_present( tmp_path: Path, mock_build_api: MagicMock, ) -> None: - """Without --no-optimize, skip_optimize is not in kwargs.""" + """Without --no-optimize, skip_optimize is not in extra_kwargs.""" from winml.modelkit.commands.build import build runner.invoke( @@ -816,5 
+787,5 @@ def test_no_optimize_default_not_present( obj={"debug": False}, ) - call_kwargs = mock_build_api.call_args.kwargs - assert "skip_optimize" not in call_kwargs + extra = mock_build_api.call_args.kwargs["extra_kwargs"] + assert "skip_optimize" not in extra diff --git a/tests/unit/commands/test_inspect_cli.py b/tests/unit/commands/test_inspect_cli.py index a5dca0f24..c8089026e 100644 --- a/tests/unit/commands/test_inspect_cli.py +++ b/tests/unit/commands/test_inspect_cli.py @@ -59,13 +59,10 @@ def mock_inspect_result() -> MagicMock: return result -# The inspect command uses deferred imports inside the function body: -# from ..inspect import inspect_model, InspectError, ... -# from ..inspect.formatter import output_json, output_table -# -# Since `from X import Y` resolves Y from X at import time, we must -# patch at the SOURCE modules so the deferred import picks up the mock. -_INSPECT_MODEL = "winml.modelkit.inspect.inspect_model" +# The inspect command calls _inspect_model_v2 (a module-level function in +# commands/inspect.py) then dispatches to output_json / output_table from +# the formatter module. We patch at their actual locations. 
+_INSPECT_MODEL = "winml.modelkit.commands.inspect._inspect_model_v2" _OUTPUT_JSON = "winml.modelkit.inspect.formatter.output_json" _OUTPUT_TABLE = "winml.modelkit.inspect.formatter.output_table" diff --git a/tests/unit/commands/test_perf_cli.py b/tests/unit/commands/test_perf_cli.py index 7cc2d4335..cdd62b9c8 100644 --- a/tests/unit/commands/test_perf_cli.py +++ b/tests/unit/commands/test_perf_cli.py @@ -270,15 +270,14 @@ def test_no_quantize_false_passes_no_override(self) -> None: override = mock_from_pretrained.call_args.kwargs["config"] assert override is None - def test_cli_onnx_goes_through_perfbenchmark(self, runner: CliRunner, tmp_path: Path) -> None: - """CLI with .onnx file should route through PerfBenchmark, not _run_onnx_benchmark.""" + def test_cli_onnx_goes_through_onnx_benchmark(self, runner: CliRunner, tmp_path: Path) -> None: + """CLI with .onnx file should route through _run_onnx_benchmark.""" onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake onnx") with ( - patch.object( - PerfBenchmark, - "run", + patch( + "winml.modelkit.commands.perf._run_onnx_benchmark", return_value=MagicMock(), ) as mock_run, patch( diff --git a/tests/unit/config/test_build.py b/tests/unit/config/test_build.py index 4e03eae7c..9849dd710 100644 --- a/tests/unit/config/test_build.py +++ b/tests/unit/config/test_build.py @@ -35,7 +35,6 @@ ) from winml.modelkit.export import ( InputTensorSpec, - ONNXConfigNotFoundError, OutputTensorSpec, WinMLExportConfig, resolve_io_specs, @@ -339,18 +338,18 @@ def test_merge_config_called_with_override( class TestRegistryShortCircuit: - """Tests for the registry export config merge in generate_build_config. + """Tests for the registry short-circuit path in generate_build_config. - Optimum is always tried first (may fail for unsupported models). - Registered export config is always merged on top (registry wins). 
+ When MODEL_BUILD_CONFIGS has a registered config with input_tensors, + the Optimum _resolve_export_config_from_specs() call is skipped. """ - def test_optimum_fails_registry_fills_in( + def test_registry_with_input_tensors_skips_optimum( self, mock_hf_config: MagicMock, mock_model_class: MagicMock, ) -> None: - """When Optimum fails, registered export config provides I/O specs.""" + """Registry config with input_tensors skips Optimum lookup.""" blip_like_export = WinMLExportConfig( input_tensors=[ InputTensorSpec(name="pixel_values", dtype="float32", shape=(1, 3, 384, 384)), @@ -373,38 +372,30 @@ def test_optimum_fails_registry_fills_in( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=ONNXConfigNotFoundError("blip not supported"), - ), + ) as mock_optimum, patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"blip": blip_like_config}), ): result = generate_build_config("Salesforce/blip-image-captioning-base") - # Registry fills in the I/O specs after Optimum failure + # Optimum should NOT have been called + mock_optimum.assert_not_called() + # Result should have the registered input_tensors assert result.export.input_tensors is not None assert len(result.export.input_tensors) == 2 assert result.export.input_tensors[0].name == "pixel_values" - def test_optimum_succeeds_registry_overrides( + def test_registry_without_export_falls_through_to_optimum( self, mock_hf_config: MagicMock, mock_model_class: MagicMock, mock_loader_config: WinMLLoaderConfig, + mock_export_config: WinMLExportConfig, ) -> None: - """When Optimum succeeds, registered export config overrides on top.""" - optimum_export = WinMLExportConfig( - input_tensors=[ - InputTensorSpec(name="input_ids", dtype="int64", shape=(1, 16)), - ], - output_tensors=[OutputTensorSpec(name="logits")], - ) - # Registry overrides with a different shape - registry_export = WinMLExportConfig( - input_tensors=[ - InputTensorSpec(name="input_ids", dtype="int64", shape=(1, 512)), - 
], - output_tensors=[OutputTensorSpec(name="logits")], + """Registry config without export falls through to Optimum.""" + # BERT_CONFIG has optim only, no export + bert_like_config = WinMLBuildConfig( + optim=WinMLOptimizationConfig(gelu_fusion=True), ) - registry_config = WinMLBuildConfig(export=registry_export) with ( patch( @@ -413,26 +404,25 @@ def test_optimum_succeeds_registry_overrides( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - return_value=optimum_export, - ), - patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"bert": registry_config}), + return_value=mock_export_config, + ) as mock_optimum, + patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"bert": bert_like_config}), ): - result = generate_build_config("bert-base-uncased") + generate_build_config("bert-base-uncased") - # Registry wins — shape is 512, not Optimum's 16 - assert result.export.input_tensors[0].shape == (1, 512) + # Optimum SHOULD have been called + mock_optimum.assert_called_once() - def test_registry_without_export_uses_optimum( + def test_registry_with_none_input_tensors_falls_through( self, mock_hf_config: MagicMock, mock_model_class: MagicMock, mock_loader_config: WinMLLoaderConfig, mock_export_config: WinMLExportConfig, ) -> None: - """Registry config without export uses Optimum result unchanged.""" - # BERT_CONFIG has optim only, no export - bert_like_config = WinMLBuildConfig( - optim=WinMLOptimizationConfig(gelu_fusion=True), + """Registry config with export but input_tensors=None falls through.""" + config_with_empty_export = WinMLBuildConfig( + export=WinMLExportConfig(), # input_tensors defaults to None ) with ( @@ -443,20 +433,23 @@ def test_registry_without_export_uses_optimum( patch( "winml.modelkit.config.build._resolve_export_config_from_specs", return_value=mock_export_config, + ) as mock_optimum, + patch( + "winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", + {"bert": config_with_empty_export}, ), - 
patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"bert": bert_like_config}), ): - result = generate_build_config("bert-base-uncased") + generate_build_config("bert-base-uncased") - # No registered export → Optimum result used as-is - assert result.export.input_tensors == mock_export_config.input_tensors + # Optimum SHOULD have been called (input_tensors is None) + mock_optimum.assert_called_once() - def test_registry_merge_does_not_mutate_singleton( + def test_registry_deepcopy_prevents_mutation( self, mock_hf_config: MagicMock, mock_model_class: MagicMock, ) -> None: - """merge_config produces a new object, not mutating the registry.""" + """Registry export config is deepcopied, preventing singleton mutation.""" original_export = WinMLExportConfig( input_tensors=[ InputTensorSpec(name="pixel_values", dtype="float32", shape=(1, 3, 224, 224)), @@ -478,19 +471,21 @@ def test_registry_merge_does_not_mutate_singleton( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=ONNXConfigNotFoundError("unsupported"), ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"some-vision": registry_config}), ): result = generate_build_config("some/vision-model") - # Result should NOT be the same object as registry export + # Result export should NOT be the same object as registry export assert result.export is not original_export - # Content should be preserved + assert result.export.input_tensors is not original_export.input_tensors + # Content should be preserved (deepcopy correctness) + assert len(result.export.input_tensors) == 1 assert result.export.input_tensors[0].name == "pixel_values" assert result.export.input_tensors[0].shape == (1, 3, 224, 224) + assert result.export.input_tensors[0].dtype == "float32" - def test_underscore_normalization( + def test_registry_underscore_normalization( self, mock_hf_config: MagicMock, mock_model_class: MagicMock, @@ -517,22 +512,27 @@ def test_underscore_normalization( ), patch( 
"winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=ONNXConfigNotFoundError("unsupported"), - ), + ) as mock_optimum, # Registry uses hyphens patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"clip-text-model": clip_config}), ): result = generate_build_config("openai/clip-vit-base-patch32") + # Underscore model_type should match hyphenated registry key + mock_optimum.assert_not_called() assert result.export.input_tensors[0].name == "input_ids" - def test_no_registry_no_optimum_returns_empty( + def test_registry_empty_list_input_tensors_skips_optimum( self, mock_hf_config: MagicMock, mock_model_class: MagicMock, mock_loader_config: WinMLLoaderConfig, ) -> None: - """No registry + Optimum fails → empty export config (no crash).""" + """Registry config with input_tensors=[] skips Optimum (is not None).""" + config_with_empty_list = WinMLBuildConfig( + export=WinMLExportConfig(input_tensors=[]), + ) + with ( patch( "winml.modelkit.config.build.resolve_loader_config", @@ -540,14 +540,38 @@ def test_no_registry_no_optimum_returns_empty( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=ONNXConfigNotFoundError("unsupported"), + ) as mock_optimum, + patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"bert": config_with_empty_list}), + ): + result = generate_build_config("bert-base-uncased") + + # [] is not None, so short-circuit fires + mock_optimum.assert_not_called() + assert result.export.input_tensors == [] + + def test_registry_miss_falls_through_to_optimum( + self, + mock_hf_config: MagicMock, + mock_model_class: MagicMock, + mock_loader_config: WinMLLoaderConfig, + mock_export_config: WinMLExportConfig, + ) -> None: + """Model not in registry at all falls through to Optimum.""" + with ( + patch( + "winml.modelkit.config.build.resolve_loader_config", + return_value=(mock_loader_config, mock_hf_config, mock_model_class), ), - patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), + patch( + 
"winml.modelkit.config.build._resolve_export_config_from_specs", + return_value=mock_export_config, + ) as mock_optimum, + patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), # empty registry ): result = generate_build_config("some/unknown-model") - # Empty export config — no crash, downstream will handle - assert result.export.input_tensors is None + mock_optimum.assert_called_once() + assert result.export is mock_export_config # ============================================================================= @@ -1867,8 +1891,12 @@ def test_auto_auto_is_noop(self) -> None: # Default compile provider is "qnn" (from WinMLCompileConfig -> EPConfig) assert result.compile.ep_config.provider == "qnn" - def test_auto_auto_skips_resolve_device(self) -> None: - """device='auto' + precision='auto' does NOT call resolve_device.""" + def test_auto_auto_still_calls_resolve_device(self) -> None: + """device='auto' + precision='auto' DOES call resolve_device (#412). + + Previously this was skipped, causing EPConfig to default to 'qnn' + on machines without an NPU. Now we always detect hardware. 
+ """ with ( patch( "winml.modelkit.config.build.resolve_loader_config", @@ -1894,7 +1922,7 @@ def test_auto_auto_skips_resolve_device(self) -> None: precision="auto", ) - mock_rd.assert_not_called() + mock_rd.assert_called_once_with(device="auto") def test_explicit_precision_triggers_resolve_device(self) -> None: """device='auto' + precision='int8' DOES call resolve_device.""" diff --git a/tests/unit/models/auto/test_automodel.py b/tests/unit/models/auto/test_automodel.py index 8ed62c863..52580db77 100644 --- a/tests/unit/models/auto/test_automodel.py +++ b/tests/unit/models/auto/test_automodel.py @@ -38,6 +38,7 @@ def _make_mock_model(num_labels: int = 1000): "output_names": ["logits"], } mock_session.is_compiled = True + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() diff --git a/tests/unit/models/auto/test_feature_extraction.py b/tests/unit/models/auto/test_feature_extraction.py index 95c42dce5..294d0cb15 100644 --- a/tests/unit/models/auto/test_feature_extraction.py +++ b/tests/unit/models/auto/test_feature_extraction.py @@ -31,6 +31,7 @@ def create_mock_model(): mock_session.run.return_value = { "last_hidden_state": np.random.randn(1, 8, 384).astype(np.float32), } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model._onnx_path = "mock.onnx" @@ -41,14 +42,17 @@ def create_mock_model(): class TestWinMLModelForFeatureExtractionBasic: def test_class_importable(self): from winml.modelkit.models.winml import WinMLModelForFeatureExtraction + assert WinMLModelForFeatureExtraction is not None def test_inherits_from_base(self): from winml.modelkit.models.winml import WinMLModelForFeatureExtraction, WinMLPreTrainedModel + assert issubclass(WinMLModelForFeatureExtraction, WinMLPreTrainedModel) def test_exported_from_winml_package(self): from winml.modelkit.models.winml import WinMLModelForFeatureExtraction + assert WinMLModelForFeatureExtraction is not None @@ -107,6 +111,7 @@ def 
test_sentence_embedding_unsqueezed(self): mock_session.run.return_value = { "sentence_embedding": np.zeros((1, 384), dtype=np.float32), } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model._onnx_path = "mock.onnx" @@ -130,6 +135,7 @@ def test_generic_2d_output_unsqueezed(self): mock_session.run.return_value = { "pooler_output": np.zeros((1, 768), dtype=np.float32), } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model._onnx_path = "mock.onnx" diff --git a/tests/unit/models/auto/test_image_classification.py b/tests/unit/models/auto/test_image_classification.py index 7dce67ad0..5cb20360f 100644 --- a/tests/unit/models/auto/test_image_classification.py +++ b/tests/unit/models/auto/test_image_classification.py @@ -41,6 +41,7 @@ def create_mock_model(num_labels: int = 1000): "input_names": ["pixel_values"], "output_names": ["logits"], } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model.config.num_labels = num_labels diff --git a/tests/unit/models/auto/test_image_segmentation.py b/tests/unit/models/auto/test_image_segmentation.py index f90784c12..f981dc38e 100644 --- a/tests/unit/models/auto/test_image_segmentation.py +++ b/tests/unit/models/auto/test_image_segmentation.py @@ -54,6 +54,7 @@ def create_mock_model( "input_names": ["pixel_values"], "output_names": ["logits", "pred_boxes", "pred_masks"], } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model.config.num_labels = num_classes @@ -185,6 +186,7 @@ def test_forward_missing_outputs_are_none(self): "input_names": ["pixel_values"], "output_names": ["logits"], } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model._onnx_path = "mock.onnx" @@ -292,6 +294,7 @@ def create_mock_semantic_model(num_labels: int = 150, output_h: int = 128, outpu "input_names": ["pixel_values"], "output_names": ["logits"], } + 
mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model.config.num_labels = num_labels diff --git a/tests/unit/models/auto/test_sequence_classification.py b/tests/unit/models/auto/test_sequence_classification.py index d7734876a..3133117db 100644 --- a/tests/unit/models/auto/test_sequence_classification.py +++ b/tests/unit/models/auto/test_sequence_classification.py @@ -39,6 +39,7 @@ def create_mock_model(num_labels: int = 2): "input_names": ["input_ids", "attention_mask", "token_type_ids"], "output_names": ["logits"], } + mock_session.device = "cpu" model._session = mock_session model.config = MagicMock() model.config.num_labels = num_labels diff --git a/tests/unit/session/test_ep_monitor.py b/tests/unit/session/test_ep_monitor.py index c5af9c74e..3f5174bff 100644 --- a/tests/unit/session/test_ep_monitor.py +++ b/tests/unit/session/test_ep_monitor.py @@ -797,7 +797,7 @@ class TestLiveMonitorDisplay: """Test LiveMonitorDisplay logic (non-visual).""" def test_render_status_warmup_phase(self): - from winml.modelkit.commands.live_chart import LiveMonitorDisplay + from winml.modelkit.commands._live_chart import LiveMonitorDisplay display = LiveMonitorDisplay(total_iterations=110, warmup=10, model_id="test", device="npu") status = display._render_status( @@ -813,7 +813,7 @@ def test_render_status_warmup_phase(self): assert "npu" in status.lower() or "Device" in status def test_render_status_benchmark_phase(self): - from winml.modelkit.commands.live_chart import LiveMonitorDisplay + from winml.modelkit.commands._live_chart import LiveMonitorDisplay display = LiveMonitorDisplay(total_iterations=110, warmup=10, model_id="test", device="npu") status = display._render_status( @@ -830,7 +830,7 @@ def test_render_status_benchmark_phase(self): assert "Latency" in status def test_render_status_zero_latency_no_crash(self): - from winml.modelkit.commands.live_chart import LiveMonitorDisplay + from winml.modelkit.commands._live_chart import 
LiveMonitorDisplay display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") # latency_ms=0 should not cause division by zero @@ -842,7 +842,7 @@ def test_render_status_zero_latency_no_crash(self): assert "Throughput" in status def test_render_status_empty_samples(self): - from winml.modelkit.commands.live_chart import LiveMonitorDisplay + from winml.modelkit.commands._live_chart import LiveMonitorDisplay display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") status = display._render_status( @@ -853,7 +853,7 @@ def test_render_status_empty_samples(self): assert "0.0%" in status # NPU should show 0.0% def test_update_noop_when_live_is_none(self): - from winml.modelkit.commands.live_chart import LiveMonitorDisplay + from winml.modelkit.commands._live_chart import LiveMonitorDisplay display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") # _live is None (not entered context) — should not crash @@ -864,7 +864,7 @@ def test_update_noop_when_live_is_none(self): ) def test_print_final_snapshot_is_noop(self): - from winml.modelkit.commands.live_chart import LiveMonitorDisplay + from winml.modelkit.commands._live_chart import LiveMonitorDisplay display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") # Should not crash or print anything
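Reviewer note: `test_render_status_zero_latency_no_crash` above implies that the status renderer guards its throughput computation against a zero latency. A minimal sketch of such a guard follows, assuming throughput is derived as `1000 / latency_ms`. The helper `throughput_per_second` is hypothetical and is not taken from the `_live_chart` module; the real implementation may differ.

```python
def throughput_per_second(latency_ms: float) -> float:
    """Return iterations per second for a given per-iteration latency.

    Hypothetical helper for illustration only; not part of ModelKit.
    Assumes throughput is simply the inverse of latency.
    """
    # A zero (or negative) latency would divide by zero or yield a
    # meaningless value, so report 0.0 instead of raising.
    if latency_ms <= 0:
        return 0.0
    return 1000.0 / latency_ms
```

Returning `0.0` rather than raising keeps a live display rendering even when the first sample arrives before any latency has been measured, which appears to be the behavior the test expects.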