Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/winml/modelkit/commands/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import click
from rich.console import Console

from ..config.precision import _DEVICE_TO_PROVIDER, VALID_EPS
from ..config.precision import _DEVICE_TO_PROVIDER, _EP_TO_DEVICE, VALID_EPS
from ..utils.logging import configure_logging


Expand Down Expand Up @@ -184,7 +184,7 @@ def compile(

# Show info
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Device:[/bold blue] {device}")
console.print(f"[bold blue]Device:[/bold blue] {_EP_TO_DEVICE.get(provider, device)}")
if ep:
console.print(f"[bold blue]EP:[/bold blue] {ep}")
console.print(f"[bold blue]Provider:[/bold blue] {provider}")
Expand Down
36 changes: 21 additions & 15 deletions src/winml/modelkit/commands/eval.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand the default-value changes here. Why remove the default values from the CLI arguments and defer them to the function call?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the earlier bug-bash report to the PR description:

Bug B — P2: wmk eval --samples N silently ignored when no --dataset is given
Command: wmk eval -m microsoft/resnet-50 --samples 20
Symptom: Output always shows 'samples': 100 regardless of --samples value.
Root cause: evaluate.py:149-155 — when config.dataset.path is None (no --dataset flag), the entire DatasetConfig is replaced wholesale with the hardcoded _DEFAULT_DATASETS entry (which has samples=100), discarding the user's value:

if config.dataset.path is None:
config.dataset = deepcopy(default) # overwrites samples, split, shuffle
Fix: Merge user-specified fields (samples, split, shuffle, seed) onto the default rather than replacing it entirely.
Location: src/winml/modelkit/eval/evaluate.py:155

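Defaulting the options to None lets us tell whether the user actually passed a value: an explicitly supplied value takes effect, and otherwise the value from the default dataset config is used.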

Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,19 @@
@click.option(
"--samples",
type=int,
default=100,
show_default=True,
help="Number of dataset samples.",
default=None,
Comment thread
zhenchaoni marked this conversation as resolved.
help="Number of dataset samples (default: from dataset config).",
)
@click.option(
"--split",
type=str,
default="validation",
show_default=True,
help="Dataset split.",
default=None,
help="Dataset split (default: from dataset config).",
)
@click.option(
"--shuffle/--no-shuffle",
default=True,
show_default=True,
help="Shuffle dataset before sampling.",
default=None,
help="Shuffle dataset before sampling (default: from dataset config).",
)
@click.option(
"--streaming",
Expand Down Expand Up @@ -125,9 +122,9 @@ def eval(
dataset_name: str | None,
task: str | None,
device: str,
samples: int,
split: str,
shuffle: bool,
samples: int | None,
split: str | None,
shuffle: bool | None,
streaming: bool,
column: tuple[str, ...],
label_mapping: Path | None,
Expand Down Expand Up @@ -213,15 +210,24 @@ def eval(
from ..datasets.config import DatasetConfig
from ..eval import WinMLEvaluationConfig, evaluate

explicit: set[str] = set()
if samples is not None:
explicit.add("samples")
if split is not None:
explicit.add("split")
if shuffle is not None:
explicit.add("shuffle")

ds_config = DatasetConfig(
path=dataset_path,
name=dataset_name,
split=split,
samples=samples,
shuffle=shuffle,
split=split if split is not None else "validation",
samples=samples if samples is not None else 100,
shuffle=shuffle if shuffle is not None else True,
columns_mapping=columns_mapping,
label_mapping=parsed_label_mapping,
streaming=streaming,
explicit_fields=frozenset(explicit),
)

config = WinMLEvaluationConfig(
Expand Down
16 changes: 9 additions & 7 deletions src/winml/modelkit/commands/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,15 @@ def perf(
console.print(f"[dim]Precision: {precision} (applied during model build)[/dim]")
console.print(f"[dim]Loading model:[/dim] {hf_model}")

# Op-tracing pre-flight: fail fast before running any benchmark iterations
if op_tracing:
from ..optracing import is_qnn_profiling_available

if not is_qnn_profiling_available():
console.print("[red]Error:[/red] Op-tracing requires onnxruntime-qnn")
console.print("Install with: [bold]pip install onnxruntime-qnn[/bold]")
raise SystemExit(1)

benchmark = PerfBenchmark(config)
result = benchmark.run()

Expand All @@ -1135,13 +1144,6 @@ def perf(
# Op-tracing (additive to existing benchmark)
# =================================================================
if op_tracing:
from ..optracing import is_qnn_profiling_available

if not is_qnn_profiling_available():
console.print("[red]Error:[/red] Op-tracing requires onnxruntime-qnn")
console.print("Install with: [bold]pip install onnxruntime-qnn[/bold]")
raise SystemExit(1)

from ..optracing.registry import get_tracer
from ..optracing.report import (
display_op_trace_report,
Expand Down
3 changes: 3 additions & 0 deletions src/winml/modelkit/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class DatasetConfig:
columns_mapping: dict[str, str] = field(default_factory=dict)
label_mapping: dict[str, int] | None = None
streaming: bool = False
# Tracks which fields were explicitly set by the caller (e.g. CLI).
# Not serialized; used by evaluate.py to merge user overrides onto defaults.
explicit_fields: frozenset[str] = field(default_factory=frozenset, repr=False, compare=False)

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
Expand Down
5 changes: 5 additions & 0 deletions src/winml/modelkit/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,12 @@ def evaluate(config: WinMLEvaluationConfig) -> EvalResult:
raise ValueError(
f"No dataset provided and no default for task '{config.task}'. Use --dataset."
)
user_dataset = config.dataset
config.dataset = deepcopy(default)
# Apply fields the caller explicitly set (tracked via explicit_fields sentinel).
for f in ("samples", "split", "shuffle", "seed"):
if f in user_dataset.explicit_fields:
setattr(config.dataset, f, getattr(user_dataset, f))
logger.info(
"Using default dataset for %s: %s",
config.task,
Expand Down
33 changes: 18 additions & 15 deletions src/winml/modelkit/export/htp/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,16 @@ def export(
export_time = time.time() - start_time
self._export_stats["export_time"] = export_time
self._export_stats["hierarchy_modules"] = len(self._hierarchy_data)
self._export_stats["onnx_nodes"] = len(onnx_model.graph.node)
self._export_stats["tagged_nodes"] = len(self._tagged_nodes)
total_nodes = len(onnx_model.graph.node)
self._export_stats["onnx_nodes"] = total_nodes

# Calculate empty tags (should be 0 with our implementation)
empty_tag_count = sum(
1 for tag in self._tagged_nodes.values() if not tag or not tag.strip()
)
self._export_stats["empty_tags"] = empty_tag_count

# Calculate coverage percentage
total_nodes = len(onnx_model.graph.node)
tagged_nodes = len(self._tagged_nodes)
coverage = (tagged_nodes / total_nodes * 100.0) if total_nodes > 0 else 0.0
self._export_stats["coverage_percentage"] = coverage
self._update_tag_stats(total_nodes)

# Update monitor with actual export time
monitor.data.export_time = export_time
Expand Down Expand Up @@ -493,6 +489,18 @@ def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any:
)
return contextlib.nullcontext()

def _update_tag_stats(self, total_nodes: int) -> None:
    """Record the tagged-node count and coverage percentage in the export stats.

    Both ``_apply_hierarchy_tags`` and the final stats block in ``export()``
    call this, so the two places report the same numbers.

    The tagged-node count is ``len(self._tagged_nodes)`` only when
    ``self.embed_hierarchy_attributes`` is truthy; otherwise it is 0.
    Coverage is reported as 0.0 when ``total_nodes`` is not positive, which
    avoids a division by zero.

    Args:
        total_nodes: Node count used as the denominator for coverage.
    """
    if self.embed_hierarchy_attributes:
        tagged = len(self._tagged_nodes)
    else:
        tagged = 0

    if total_nodes > 0:
        coverage = tagged / total_nodes * 100.0
    else:
        coverage = 0.0

    self._export_stats["tagged_nodes"] = tagged
    self._export_stats["coverage_percentage"] = coverage

def _initialize_node_tagger(self, enable_operation_fallback: bool) -> None:
"""Create node tagger internally."""
self._node_tagger = create_node_tagger_from_hierarchy(
Expand All @@ -510,21 +518,16 @@ def _apply_hierarchy_tags(self, onnx_model: onnx.ModelProto) -> None:
self._tagging_stats = stats

# Update export stats
self._export_stats["onnx_nodes"] = len(onnx_model.graph.node)
self._export_stats["tagged_nodes"] = len(self._tagged_nodes)
total_nodes = len(onnx_model.graph.node)
self._export_stats["onnx_nodes"] = total_nodes
self._update_tag_stats(total_nodes)

# Calculate empty tags (should be 0 with our implementation)
empty_tag_count = sum(
1 for tag in self._tagged_nodes.values() if not tag or not tag.strip()
)
self._export_stats["empty_tags"] = empty_tag_count

# Calculate coverage percentage
total_nodes = len(onnx_model.graph.node)
tagged_nodes = len(self._tagged_nodes)
coverage = (tagged_nodes / total_nodes * 100.0) if total_nodes > 0 else 0.0
self._export_stats["coverage_percentage"] = coverage

def _embed_graph_metadata(
self, onnx_model: onnx.ModelProto, export_config: WinMLExportConfig
) -> None:
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/commands/test_compile_quantize_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from winml.modelkit.commands.compile import _resolve_compile_provider
Expand Down Expand Up @@ -119,3 +121,35 @@ def test_unknown_precision_uses_defaults(self):
w, a = _resolve_quant_types("fp16", None, None)
assert w == "uint8"
assert a == "uint8"


# =============================================================================
# BUG C: compile summary shows wrong Device when --ep overrides device
# =============================================================================


class TestCompileDeviceDisplayLabel:
    """Bug C: the Device label in the compile summary follows the resolved EP."""

    def test_dml_ep_shows_gpu_device(self, tmp_path):
        """Running compile with ``--ep dml`` prints ``Device: gpu``, not ``Device: npu``."""
        from click.testing import CliRunner

        from winml.modelkit.commands.compile import compile as compile_cmd

        # Placeholder content; the compiler entry point is patched below.
        onnx_path = tmp_path / "model.onnx"
        onnx_path.write_bytes(b"fake")

        fake_compile_result = MagicMock(
            success=True,
            output_path=None,
            compile_time=None,
            total_time=None,
        )
        cli_args = ["-m", str(onnx_path), "--ep", "dml"]

        with (
            patch("winml.modelkit.compiler.compile_onnx", return_value=fake_compile_result),
            patch("winml.modelkit.compiler.WinMLCompileConfig"),
        ):
            outcome = CliRunner().invoke(compile_cmd, cli_args)

        printed = outcome.output
        assert "Device: gpu" in printed
        assert "Device: npu" not in printed
23 changes: 23 additions & 0 deletions tests/unit/commands/test_perf_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,26 @@ def test_onnx_load_model_passes_ep(self, tmp_path: Path) -> None:
benchmark._load_model()

assert mock_from_onnx.call_args.kwargs["ep"] == "qnn"


# =============================================================================
# BUG A: op-tracing preflight check
# =============================================================================


class TestOpTracingPreflight:
    """Bug A: the QNN profiling availability check runs before the benchmark."""

    def test_benchmark_does_not_run_when_qnn_unavailable(self, runner: CliRunner) -> None:
        """With QNN profiling unavailable, perf exits non-zero and never benchmarks."""
        cli_args = ["-m", "microsoft/resnet-50", "--op-tracing", "basic"]
        qnn_unavailable = patch(
            "winml.modelkit.optracing.is_qnn_profiling_available", return_value=False
        )
        benchmark_patch = patch("winml.modelkit.commands.perf.PerfBenchmark")

        with qnn_unavailable, benchmark_patch as benchmark_cls:
            outcome = runner.invoke(perf, cli_args, obj={})

        assert outcome.exit_code != 0
        # The pre-flight failure must stop the command before run() is called.
        benchmark_cls.return_value.run.assert_not_called()
32 changes: 32 additions & 0 deletions tests/unit/eval/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,3 +854,35 @@ def test_load_model_from_onnx(self):

mock_auto.from_onnx.assert_called_once()
assert result.config is mock_hf_config


# =============================================================================
# BUG B: --samples silently ignored when no --dataset
# =============================================================================


class TestEvaluateSamplesPreserved:
    """Bug B: an explicit ``samples`` value survives the default-dataset merge."""

    def test_user_samples_preserved_when_no_dataset_path(self) -> None:
        """``samples=20`` flagged as explicit is kept when no dataset path is set."""
        import importlib
        import sys

        # Look the submodule up by its dotted name, reusing an already-loaded
        # module object when one exists.
        module_name = "winml.modelkit.eval.evaluate"
        evaluate_module = sys.modules.get(module_name)
        if not evaluate_module:
            evaluate_module = importlib.import_module(module_name)

        user_dataset = DatasetConfig(samples=20, explicit_fields=frozenset({"samples"}))
        eval_config = WinMLEvaluationConfig(
            model_id="test/model",
            task="image-classification",
            dataset=user_dataset,
        )

        with (
            patch.object(evaluate_module, "_load_model"),
            patch.object(evaluate_module, "WinMLEvaluator") as evaluator_cls,
        ):
            evaluator_cls.return_value.compute.return_value = {}
            outcome = evaluate_module.evaluate(eval_config)

        assert outcome.config.dataset.samples == 20
66 changes: 66 additions & 0 deletions tests/unit/export/test_htp_exporter_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Tests for HTPExporter export statistics correctness."""

from __future__ import annotations

from unittest.mock import MagicMock

from winml.modelkit.export.htp import HTPExporter

Comment thread
xieofxie marked this conversation as resolved.

class TestHTPExporterTaggedNodesStats:
    """Bug D: tagged_nodes and coverage must be 0 when embed_hierarchy_attributes=False."""

    @staticmethod
    def _stats_after_tagging(*, embed: bool, tags: dict[str, str], node_count: int) -> dict:
        """Run ``_apply_hierarchy_tags`` against mocks and return the export stats.

        Args:
            embed: Value passed as ``embed_hierarchy_attributes`` to the exporter.
            tags: Mapping returned by the mocked tagger's ``tag_all_nodes``.
            node_count: Number of mock nodes placed in the model graph.
        """
        exporter = HTPExporter(embed_hierarchy_attributes=embed)
        exporter._node_tagger = MagicMock()
        tagger = exporter._node_tagger
        tagger.tag_all_nodes.return_value = tags
        tagger.get_tagging_statistics.return_value = {}

        onnx_model = MagicMock()
        onnx_model.graph.node = [MagicMock() for _ in range(node_count)]

        exporter._apply_hierarchy_tags(onnx_model)
        return exporter._export_stats

    def test_tagged_nodes_zero_when_hierarchy_disabled(self) -> None:
        """Three tags over five nodes still report zero tagged nodes when disabled."""
        stats = self._stats_after_tagging(
            embed=False,
            tags={
                "node1": "/Model/Layer1",
                "node2": "/Model/Layer2",
                "node3": "/Model/Layer3",
            },
            node_count=5,
        )

        assert stats["tagged_nodes"] == 0

    def test_coverage_zero_when_hierarchy_disabled(self) -> None:
        """Two tags over four nodes still report 0% coverage when disabled."""
        stats = self._stats_after_tagging(
            embed=False,
            tags={"n1": "/t1", "n2": "/t2"},
            node_count=4,
        )

        assert stats["coverage_percentage"] == 0.0

    def test_tagged_nodes_nonzero_when_hierarchy_enabled(self) -> None:
        """Control: stats are populated normally when embedding is enabled."""
        stats = self._stats_after_tagging(
            embed=True,
            tags={"n1": "/t1", "n2": "/t2"},
            node_count=4,
        )

        assert stats["tagged_nodes"] == 2
        assert stats["coverage_percentage"] == 50.0
Loading