From 54c274b4eda697485f6e6b21199f5dd99a3a456c Mon Sep 17 00:00:00 2001 From: Shiyi Zheng Date: Thu, 23 Apr 2026 15:19:03 +0800 Subject: [PATCH 01/12] Add eval_dataset support in config and eval command --- src/winml/modelkit/commands/eval.py | 51 +++++++++++++++++++++++++++++ src/winml/modelkit/config/build.py | 6 +++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index 24caaf87b..94806281a 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -171,6 +171,57 @@ def eval( samples = build_cfg.quant.samples if not cli_utils.is_cli_provided(ctx, "dataset_name"): dataset_name = build_cfg.quant.dataset_name + # Apply eval_dataset from config (CLI options take precedence) + if build_cfg.eval_dataset: + ed = build_cfg.eval_dataset + # Run build_script if needed to generate local dataset + build_script = ed.get("build_script") + if build_script: + import subprocess + import sys + + script_path = Path(build_script) + # Use the path from config as cache dir (expand ~), + # fallback to ~/.cache/winml/eval_datasets// + raw_path = ed.get("path", "") + if raw_path: + cache_dir = Path(raw_path).expanduser() + else: + cache_dir = ( + Path.home() / ".cache" / "winml" / "eval_datasets" / script_path.stem + ) + if not (cache_dir / "dataset_info.json").exists(): + if script_path.exists(): + logger.info("Building dataset via %s ...", script_path.name) + result = subprocess.run( # noqa: S603 + [sys.executable, str(script_path), "--output", str(cache_dir)], + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode != 0: + logger.warning( + "Dataset build failed: %s", result.stderr.strip()[-200:] + ) + else: + logger.warning("Build script not found: %s", script_path) + # Use the built local dataset path + if not cli_utils.is_cli_provided(ctx, "dataset_path"): + dataset_path = str(cache_dir) + elif not cli_utils.is_cli_provided(ctx, "dataset_path") and ed.get("path"): + dataset_path = ed["path"] + if not cli_utils.is_cli_provided(ctx, "dataset_name") and ed.get("name"): + dataset_name = ed["name"] + if not cli_utils.is_cli_provided(ctx, "split") and ed.get("split"): + split = ed["split"] + if not cli_utils.is_cli_provided(ctx, "samples") and ed.get("samples"): + samples = ed["samples"] + if not column and ed.get("columns_mapping"): + column = tuple(f"{k}={v}" for k, v in ed["columns_mapping"].items()) + if label_mapping is None and ed.get("label_mapping_file"): + lm_path = Path(ed["label_mapping_file"]) + if lm_path.exists(): + label_mapping = lm_path if show_schema: from ..eval import WinMLEvaluator diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 3f334f287..68833dd2c 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -126,6 +126,7 @@ class WinMLBuildConfig: optim: WinMLOptimizationConfig = field(default_factory=WinMLOptimizationConfig) quant: WinMLQuantizationConfig | None = field(default_factory=WinMLQuantizationConfig) compile: WinMLCompileConfig | None = field(default_factory=WinMLCompileConfig) + eval_dataset: dict | None = None @classmethod def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: @@ -134,6 +135,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: export_data = config_dict.get("export", {}) quant_data = config_dict.get("quant") compile_data = config_dict.get("compile") + eval_dataset_data = config_dict.get("eval_dataset") return cls( loader=WinMLLoaderConfig.from_dict(loader_data), export=(WinMLExportConfig.from_dict(export_data) if export_data is not None else None), @@ -144,6 +146,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: compile=( WinMLCompileConfig.from_dict(compile_data) if compile_data is not None else None ), + eval_dataset=eval_dataset_data, ) def to_dict(self) -> dict: @@ -158,6 +161,8 @@ def to_dict(self) -> dict: loader_dict = self.loader.to_dict() if loader_dict: result["loader"] = loader_dict + if self.eval_dataset is not None: + result["eval_dataset"] = self.eval_dataset return result def validate(self) -> None: @@ -874,7 +879,6 @@ def _merge_export_config( dynamic_axes=( override.dynamic_axes if override.dynamic_axes is not None else base.dynamic_axes ), - dynamo=override.dynamo if override.dynamo else base.dynamo, ) From c63e6f34ae7d06854c4de240b5385896c01e045d Mon Sep 17 00:00:00 2001 From: Shiyi Zheng Date: Mon, 27 Apr 2026 21:23:44 +0800 Subject: [PATCH 02/12] refactor: address PR review comments - Rename eval_dataset to eval_option with WinMLEvaluationConfig type - Dynamic load WinMLEvaluationConfig in build config (lazy import) - Rename build_script to dataset_script - Add --dataset-script and --trust-remote-code CLI options - Decouple config defaults from script execution logic - Simplify: dataset script prints path to stdout, no cache_dir logic - Config section only provides default values, no file existence checks --- src/winml/modelkit/commands/eval.py | 110 +++++++++++++++------------- src/winml/modelkit/config/build.py | 15 ++-- src/winml/modelkit/eval/config.py | 8 ++ 3 files changed, 77 insertions(+), 56 deletions(-) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index 94806281a..427ca7a6b 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -111,6 +111,18 @@ default=False, help="Enable verbose output.", ) +@click.option( + "--dataset-script", + type=str, + default=None, + help="Path to a Python script that builds the evaluation dataset.", +) +@click.option( + "--trust-remote-code", + is_flag=True, + default=False, + help="Allow execution of dataset scripts. Required when --dataset-script is used.", +) @click.option( "--schema", "show_schema", @@ -136,6 +148,8 @@ def eval( label_mapping: Path | None, output: Path | None, verbose: bool, + dataset_script: str | None, + trust_remote_code: bool, show_schema: bool, config_file: Path | None, ) -> None: @@ -171,57 +185,51 @@ def eval( samples = build_cfg.quant.samples if not cli_utils.is_cli_provided(ctx, "dataset_name"): dataset_name = build_cfg.quant.dataset_name - # Apply eval_dataset from config (CLI options take precedence) - if build_cfg.eval_dataset: - ed = build_cfg.eval_dataset - # Run build_script if needed to generate local dataset - build_script = ed.get("build_script") - if build_script: - import subprocess - import sys - - script_path = Path(build_script) - # Use the path from config as cache dir (expand ~), - # fallback to ~/.cache/winml/eval_datasets// - raw_path = ed.get("path", "") - if raw_path: - cache_dir = Path(raw_path).expanduser() - else: - cache_dir = ( - Path.home() / ".cache" / "winml" / "eval_datasets" / script_path.stem - ) - if not (cache_dir / "dataset_info.json").exists(): - if script_path.exists(): - logger.info("Building dataset via %s ...", script_path.name) - result = subprocess.run( # noqa: S603 - [sys.executable, str(script_path), "--output", str(cache_dir)], - capture_output=True, - text=True, - timeout=300, - ) - if result.returncode != 0: - logger.warning( - "Dataset build failed: %s", result.stderr.strip()[-200:] - ) - else: - logger.warning("Build script not found: %s", script_path) - # Use the built local dataset path - if not cli_utils.is_cli_provided(ctx, "dataset_path"): - dataset_path = str(cache_dir) - elif not cli_utils.is_cli_provided(ctx, "dataset_path") and ed.get("path"): - dataset_path = ed["path"] - if not cli_utils.is_cli_provided(ctx, "dataset_name") and ed.get("name"): - dataset_name = ed["name"] - if not cli_utils.is_cli_provided(ctx, "split") and ed.get("split"): - split = ed["split"] - if not cli_utils.is_cli_provided(ctx, "samples") and ed.get("samples"): - samples = ed["samples"] - if not column and ed.get("columns_mapping"): - column = tuple(f"{k}={v}" for k, v in ed["columns_mapping"].items()) - if label_mapping is None and ed.get("label_mapping_file"): - lm_path = Path(ed["label_mapping_file"]) - if lm_path.exists(): - label_mapping = lm_path + # Apply eval_option from config (CLI options take precedence) + if build_cfg.eval_option: + eo = build_cfg.eval_option + if not cli_utils.is_cli_provided(ctx, "dataset_path") and eo.dataset.path: + dataset_path = eo.dataset.path + if not cli_utils.is_cli_provided(ctx, "dataset_name") and eo.dataset.name: + dataset_name = eo.dataset.name + if not cli_utils.is_cli_provided(ctx, "split"): + split = eo.dataset.split + if not cli_utils.is_cli_provided(ctx, "samples"): + samples = eo.dataset.samples + if not column and eo.dataset.columns_mapping: + column = tuple(f"{k}={v}" for k, v in eo.dataset.columns_mapping.items()) + if label_mapping is None and eo.label_mapping_file: + label_mapping = Path(eo.label_mapping_file) + if dataset_script is None and eo.dataset_script: + dataset_script = eo.dataset_script + + # Run dataset_script if provided (requires --trust-remote-code) + if dataset_script: + if not trust_remote_code: + raise click.UsageError( + "--trust-remote-code is required to execute a dataset script." + ) + import subprocess + import sys + + script_path = Path(dataset_script) + if not script_path.exists(): + raise click.BadParameter(f"Dataset script not found: {script_path}") + logger.info("Building dataset via %s ...", script_path.name) + result = subprocess.run( # noqa: S603 + [sys.executable, str(script_path)], + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode != 0: + raise click.ClickException( + f"Dataset script failed: {result.stderr.strip()[-200:]}" + ) + # Use script output (stdout) as dataset path if not already set + script_output = result.stdout.strip() + if script_output and not cli_utils.is_cli_provided(ctx, "dataset_path"): + dataset_path = script_output if show_schema: from ..eval import WinMLEvaluator diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 68833dd2c..4736f43d4 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -58,6 +58,7 @@ ) from ..loader.config import WinMLLoaderConfig, resolve_loader_config from ..optim.config import WinMLOptimizationConfig +from ..eval.config import WinMLEvaluationConfig from ..quant.config import WinMLQuantizationConfig from ..utils.config_utils import merge_config @@ -126,7 +127,7 @@ class WinMLBuildConfig: optim: WinMLOptimizationConfig = field(default_factory=WinMLOptimizationConfig) quant: WinMLQuantizationConfig | None = field(default_factory=WinMLQuantizationConfig) compile: WinMLCompileConfig | None = field(default_factory=WinMLCompileConfig) - eval_dataset: dict | None = None + eval_option: WinMLEvaluationConfig | None = None @classmethod def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: @@ -135,7 +136,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: export_data = config_dict.get("export", {}) quant_data = config_dict.get("quant") compile_data = config_dict.get("compile") - eval_dataset_data = config_dict.get("eval_dataset") + eval_option_data = config_dict.get("eval_option") return cls( loader=WinMLLoaderConfig.from_dict(loader_data), export=(WinMLExportConfig.from_dict(export_data) if export_data is not None else None), @@ -146,7 +147,11 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: compile=( WinMLCompileConfig.from_dict(compile_data) if compile_data is not None else None ), - eval_dataset=eval_dataset_data, + eval_option=( + WinMLEvaluationConfig.from_dict(eval_option_data) + if eval_option_data is not None + else None + ), ) def to_dict(self) -> dict: @@ -161,8 +166,8 @@ def to_dict(self) -> dict: loader_dict = self.loader.to_dict() if loader_dict: result["loader"] = loader_dict - if self.eval_dataset is not None: - result["eval_dataset"] = self.eval_dataset + if self.eval_option is not None: + result["eval_option"] = self.eval_option.to_dict() return result def validate(self) -> None: diff --git a/src/winml/modelkit/eval/config.py b/src/winml/modelkit/eval/config.py index 6decd9087..360d59b50 100644 --- a/src/winml/modelkit/eval/config.py +++ b/src/winml/modelkit/eval/config.py @@ -59,6 +59,8 @@ class WinMLEvaluationConfig: device: str = "cpu" dataset: DatasetConfig = field(default_factory=DatasetConfig) output_path: Path | None = None + dataset_script: str | None = None + label_mapping_file: str | None = None def to_dict(self) -> dict: """Convert to dictionary for serialization.""" @@ -73,6 +75,10 @@ def to_dict(self) -> dict: result["dataset"] = self.dataset.to_dict() if self.output_path is not None: result["output_path"] = str(self.output_path) + if self.dataset_script is not None: + result["dataset_script"] = self.dataset_script + if self.label_mapping_file is not None: + result["label_mapping_file"] = self.label_mapping_file return result @classmethod @@ -96,4 +102,6 @@ def from_dict(cls, data: dict) -> WinMLEvaluationConfig: device=data.get("device", "cpu"), dataset=dataset, output_path=(Path(data["output_path"]) if data.get("output_path") else None), + dataset_script=data.get("dataset_script"), + label_mapping_file=data.get("label_mapping_file"), ) From 1e8f748f01eb6b86657b7d8aa3057638950c4045 Mon Sep 17 00:00:00 2001 From: Shiyi Zheng Date: Mon, 27 Apr 2026 22:17:34 +0800 Subject: [PATCH 03/12] fix comments --- scripts/e2e_eval/datasets/build_ai4privacy.py | 9 +++++++-- scripts/e2e_eval/datasets/build_fairface.py | 9 +++++++-- scripts/e2e_eval/datasets/build_indonlu_posp.py | 9 +++++++-- scripts/e2e_eval/datasets/build_pubtables1m_detection.py | 9 +++++++-- scripts/e2e_eval/datasets/build_pubtables1m_structure.py | 9 +++++++-- src/winml/modelkit/commands/eval.py | 7 ++++--- src/winml/modelkit/config/build.py | 3 ++- 7 files changed, 41 insertions(+), 14 deletions(-) diff --git a/scripts/e2e_eval/datasets/build_ai4privacy.py b/scripts/e2e_eval/datasets/build_ai4privacy.py index a56c4f8df..52724e96d 100644 --- a/scripts/e2e_eval/datasets/build_ai4privacy.py +++ b/scripts/e2e_eval/datasets/build_ai4privacy.py @@ -59,11 +59,16 @@ def build_dataset(output_dir: Path) -> None: print("Done.") +_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_ai4privacy" + + def main() -> None: parser = argparse.ArgumentParser(description="Build ai4privacy PII dataset") - parser.add_argument("--output", type=Path, required=True, help="Output directory") + parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_ai4privacy)") args = parser.parse_args() - build_dataset(args.output) + output_dir = args.output or _DEFAULT_CACHE_DIR + build_dataset(output_dir) + print(output_dir) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_fairface.py b/scripts/e2e_eval/datasets/build_fairface.py index 82b3641b0..22a77f489 100644 --- a/scripts/e2e_eval/datasets/build_fairface.py +++ b/scripts/e2e_eval/datasets/build_fairface.py @@ -95,11 +95,16 @@ def build_dataset(output_dir: Path) -> None: print("Done.") +_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_fairface" + + def main() -> None: parser = argparse.ArgumentParser(description="Build fairface validation dataset") - parser.add_argument("--output", type=Path, required=True, help="Output directory") + parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_fairface)") args = parser.parse_args() - build_dataset(args.output) + output_dir = args.output or _DEFAULT_CACHE_DIR + build_dataset(output_dir) + print(output_dir) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_indonlu_posp.py b/scripts/e2e_eval/datasets/build_indonlu_posp.py index 724bef63d..52e276d48 100644 --- a/scripts/e2e_eval/datasets/build_indonlu_posp.py +++ b/scripts/e2e_eval/datasets/build_indonlu_posp.py @@ -44,11 +44,16 @@ def build_dataset(output_dir: Path) -> None: print("Done.") +_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_indonlu_posp" + + def main() -> None: parser = argparse.ArgumentParser(description="Build indonlu posp dataset") - parser.add_argument("--output", type=Path, required=True, help="Output directory") + parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_indonlu_posp)") args = parser.parse_args() - build_dataset(args.output) + output_dir = args.output or _DEFAULT_CACHE_DIR + build_dataset(output_dir) + print(output_dir) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_pubtables1m_detection.py b/scripts/e2e_eval/datasets/build_pubtables1m_detection.py index 043f5b483..687caecd7 100644 --- a/scripts/e2e_eval/datasets/build_pubtables1m_detection.py +++ b/scripts/e2e_eval/datasets/build_pubtables1m_detection.py @@ -198,11 +198,16 @@ def build_dataset(output_dir: Path) -> None: print("Done.") +_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_pubtables1m_detection" + + def main() -> None: parser = argparse.ArgumentParser(description="Build PubTables-1M detection dataset") - parser.add_argument("--output", type=Path, required=True, help="Output directory") + parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_pubtables1m_detection)") args = parser.parse_args() - build_dataset(args.output) + output_dir = args.output or _DEFAULT_CACHE_DIR + build_dataset(output_dir) + print(output_dir) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_pubtables1m_structure.py b/scripts/e2e_eval/datasets/build_pubtables1m_structure.py index 9a640717b..775495e0a 100644 --- a/scripts/e2e_eval/datasets/build_pubtables1m_structure.py +++ b/scripts/e2e_eval/datasets/build_pubtables1m_structure.py @@ -206,11 +206,16 @@ def build_dataset(output_dir: Path) -> None: print("Done.") +_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_pubtables1m_structure" + + def main() -> None: parser = argparse.ArgumentParser(description="Build PubTables-1M structure recognition dataset") - parser.add_argument("--output", type=Path, required=True, help="Output directory") + parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_pubtables1m_structure)") args = parser.parse_args() - build_dataset(args.output) + output_dir = args.output or _DEFAULT_CACHE_DIR + build_dataset(output_dir) + print(output_dir) if __name__ == "__main__": diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index 427ca7a6b..b46780c19 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -226,9 +226,10 @@ def eval( raise click.ClickException( f"Dataset script failed: {result.stderr.strip()[-200:]}" ) - # Use script output (stdout) as dataset path if not already set - script_output = result.stdout.strip() - if script_output and not cli_utils.is_cli_provided(ctx, "dataset_path"): + # Use the last line of stdout as dataset path (earlier lines are log messages) + stdout = result.stdout.strip() + script_output = stdout.splitlines()[-1].strip() if stdout else "" + if script_output: dataset_path = script_output if show_schema: diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 4736f43d4..60b14319b 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -50,6 +50,7 @@ from typing import TYPE_CHECKING, Any, overload from ..compiler.configs import WinMLCompileConfig +from ..eval.config import WinMLEvaluationConfig from ..export.config import ( InputTensorSpec, OutputTensorSpec, @@ -58,7 +59,6 @@ ) from ..loader.config import WinMLLoaderConfig, resolve_loader_config from ..optim.config import WinMLOptimizationConfig -from ..eval.config import WinMLEvaluationConfig from ..quant.config import WinMLQuantizationConfig from ..utils.config_utils import merge_config @@ -884,6 +884,7 @@ def _merge_export_config( dynamic_axes=( override.dynamic_axes if override.dynamic_axes is not None else base.dynamic_axes ), + dynamo=override.dynamo if override.dynamo else base.dynamo, ) From 1da338413ce0244887c8d77c18d4f4d8ef1f9cad Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Wed, 29 Apr 2026 11:24:04 +0800 Subject: [PATCH 04/12] refactor eval command: config-centric flow, rename eval_option to eval, move dataset fields --- src/winml/modelkit/commands/eval.py | 328 ++++++++++++++++---------- src/winml/modelkit/config/build.py | 18 +- src/winml/modelkit/datasets/config.py | 11 + src/winml/modelkit/eval/config.py | 10 +- 4 files changed, 230 insertions(+), 137 deletions(-) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index b46780c19..b8c16bb83 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -175,159 +175,245 @@ def eval( if verbose or (ctx.obj and ctx.obj.get("debug")): logging.getLogger("winml.modelkit").setLevel(logging.DEBUG) - # Apply build config defaults (CLI explicit options take precedence) + from ..eval import WinMLEvaluationConfig, evaluate + + # ── 1. Build config: CLI > config file eval > quant/loader > defaults ── + cfg = WinMLEvaluationConfig() if config_file is not None: - build_cfg = cli_utils.load_build_config(config_file) - if build_cfg.loader and not cli_utils.is_cli_provided(ctx, "task"): - task = build_cfg.loader.task - if build_cfg.quant: - if not cli_utils.is_cli_provided(ctx, "samples"): - samples = build_cfg.quant.samples - if not cli_utils.is_cli_provided(ctx, "dataset_name"): - dataset_name = build_cfg.quant.dataset_name - # Apply eval_option from config (CLI options take precedence) - if build_cfg.eval_option: - eo = build_cfg.eval_option - if not cli_utils.is_cli_provided(ctx, "dataset_path") and eo.dataset.path: - dataset_path = eo.dataset.path - if not cli_utils.is_cli_provided(ctx, "dataset_name") and eo.dataset.name: - dataset_name = eo.dataset.name - if not cli_utils.is_cli_provided(ctx, "split"): - split = eo.dataset.split - if not cli_utils.is_cli_provided(ctx, "samples"): - samples = eo.dataset.samples - if not column and eo.dataset.columns_mapping: - column = tuple(f"{k}={v}" for k, v in eo.dataset.columns_mapping.items()) - if label_mapping is None and eo.label_mapping_file: - label_mapping = Path(eo.label_mapping_file) - if dataset_script is None and eo.dataset_script: - dataset_script = eo.dataset_script - - # Run dataset_script if provided (requires --trust-remote-code) - if dataset_script: - if not trust_remote_code: - raise click.UsageError( - "--trust-remote-code is required to execute a dataset script." - ) - import subprocess - import sys - - script_path = Path(dataset_script) - if not script_path.exists(): - raise click.BadParameter(f"Dataset script not found: {script_path}") - logger.info("Building dataset via %s ...", script_path.name) - result = subprocess.run( # noqa: S603 - [sys.executable, str(script_path)], - capture_output=True, - text=True, - timeout=300, - ) - if result.returncode != 0: - raise click.ClickException( - f"Dataset script failed: {result.stderr.strip()[-200:]}" - ) - # Use the last line of stdout as dataset path (earlier lines are log messages) - stdout = result.stdout.strip() - script_output = stdout.splitlines()[-1].strip() if stdout else "" - if script_output: - dataset_path = script_output + cfg = _merge_from_config_file(cfg, config_file) + cfg = _merge_cli_overrides(cfg, ctx, column, label_mapping) if show_schema: from ..eval import WinMLEvaluator from ..eval.evaluate import _EVALUATOR_REGISTRY - if task is None: + if cfg.task is None: raise click.UsageError( "--schema requires --task. Example: winml eval --schema --task object-detection" ) - cls = _EVALUATOR_REGISTRY.get(task, WinMLEvaluator) - _print_schema(task, cls.schema_info()) + cls = _EVALUATOR_REGISTRY.get(cfg.task, WinMLEvaluator) + _print_schema(cfg.task, cls.schema_info()) return - if model is None and model_id is None: - raise click.UsageError( - "A model is required. Provide -m with a HuggingFace model ID or path to an .onnx file." - ) + # ── 2. Resolve in place ── + _resolve_model(cfg, model) + _resolve_device(cfg) + _resolve_label_mapping(cfg) + _run_dataset_script(cfg, trust_remote_code) + + logger.debug("Effective eval config: %s", cfg.to_dict()) + + # ── 3. Evaluate ── + try: + result = evaluate(cfg) + _write_and_display(result, cfg.output_path) + except Exception as e: + logger.exception("Evaluation failed") + raise click.ClickException(f"Evaluation failed: {e}") from e + + +# ── CLI-to-config field mappings ── +# Adding a new option only requires one entry here; both config-file and CLI +# paths pick it up automatically via merge_config. +_CLI_TO_CONFIG: dict[str, str] = { + "model_id": "model_id", + "task": "task", + "device": "device", +} + +_CLI_TO_DATASET: dict[str, str] = { + "dataset_path": "path", + "dataset_name": "name", + "split": "split", + "samples": "samples", + "shuffle": "shuffle", + "streaming": "streaming", + "dataset_script": "build_script", +} + + +def _merge_from_config_file( + cfg: object, + config_file: Path, +) -> object: + """Merge eval-relevant values from a build config file into *cfg*. + + Reads the raw JSON so that only fields explicitly present in the file + are applied (avoids overriding with dataclass defaults). + + Precedence (lowest → highest): quant/loader fields → eval section. + """ + from ..utils.config_utils import merge_config + + raw = json.loads(config_file.read_text()) + + # Loader task as lowest-priority fallback + loader_data = raw.get("loader", {}) + if "task" in loader_data: + cfg.task = loader_data["task"] + + # Quant fields as fallback (only explicitly-set values) + quant_data = raw.get("quant", {}) + quant_overrides: dict = {} + if "samples" in quant_data: + quant_overrides.setdefault("dataset", {})["samples"] = quant_data["samples"] + if "dataset_name" in quant_data: + quant_overrides.setdefault("dataset", {})["name"] = quant_data["dataset_name"] + if quant_overrides: + cfg = merge_config(cfg, quant_overrides) + + # Eval section overrides quant/loader + eval_data = raw.get("eval") + if eval_data: + cfg = merge_config(cfg, eval_data) + + return cfg + + +def _merge_cli_overrides( + cfg: object, + ctx: click.Context, + column: tuple[str, ...], + label_mapping: Path | None, +) -> object: + """Overlay CLI-provided values onto *cfg* (highest priority). + + Only values the user explicitly passed on the command line are applied. + Full precedence chain (lowest → highest): + dataclass defaults → quant/loader → config file eval → CLI + """ + from ..utils.config_utils import merge_config + + overrides: dict = {} + ds_overrides: dict = {} + + for cli_name, cfg_name in _CLI_TO_CONFIG.items(): + if cli_utils.is_cli_provided(ctx, cli_name): + overrides[cfg_name] = ctx.params[cli_name] + + for cli_name, cfg_name in _CLI_TO_DATASET.items(): + if cli_utils.is_cli_provided(ctx, cli_name): + ds_overrides[cfg_name] = ctx.params[cli_name] + + # --column is multiple=True; non-empty tuple means user provided it + if column: + columns_mapping: dict[str, str] = {} + for c in column: + if "=" not in c: + raise click.BadParameter( + f"Invalid column format: '{c}'. Use key=value.", + param_hint="--column", + ) + k, v = c.split("=", 1) + columns_mapping[k] = v + ds_overrides["columns_mapping"] = columns_mapping + + if label_mapping is not None: + ds_overrides["label_mapping_file"] = str(label_mapping) + + if cli_utils.is_cli_provided(ctx, "output"): + overrides["output_path"] = ctx.params["output"] + + if ds_overrides: + overrides["dataset"] = ds_overrides - # Detect: -m as HF model ID (not an ONNX file) -> treat as model_id - model_path = None - if model is not None: - p = Path(model) + if overrides: + cfg = merge_config(cfg, overrides) + + return cfg + + +def _resolve_model(cfg: object, model_arg: str | None) -> None: + """Resolve ``-m`` into ``model_id`` / ``model_path`` on *cfg* in place.""" + if model_arg is not None: + p = Path(model_arg) if p.suffix.lower() == ".onnx": if not p.exists(): raise click.BadParameter( - f"ONNX file not found: {model}", + f"ONNX file not found: {model_arg}", param_hint="-m/--model", ) - model_path = model - if model_id is None: + cfg.model_path = model_arg + if cfg.model_id is None: raise click.UsageError( "When using an ONNX file, --model-id is required " "for preprocessor and config resolution." ) else: - model_id = model_id or model - - # Parse column mappings from --column key=value pairs - columns_mapping: dict[str, str] = {} - for c in column: - if "=" not in c: - raise click.BadParameter( - f"Invalid column format: '{c}'. Use key=value.", - param_hint="--column", - ) - k, v = c.split("=", 1) - columns_mapping[k] = v + cfg.model_id = cfg.model_id or model_arg - # Parse label mapping from JSON file - parsed_label_mapping = None - if label_mapping: - with Path(label_mapping).open() as f: - parsed_label_mapping = json.load(f) + if cfg.model_id is None and cfg.model_path is None: + raise click.UsageError( + "A model is required. Provide -m with a HuggingFace model ID or path to an .onnx file." + ) - from ..datasets import DatasetConfig - from ..eval import WinMLEvaluationConfig, evaluate + +def _resolve_device(cfg: object) -> None: + """Resolve ``'auto'`` → concrete device string on *cfg* in place.""" from ..sysinfo import resolve_device - resolved_device, _ = resolve_device(device) - - ds_config = DatasetConfig( - path=dataset_path, - name=dataset_name, - split=split, - samples=samples, - shuffle=shuffle, - columns_mapping=columns_mapping, - label_mapping=parsed_label_mapping, - streaming=streaming, - ) + resolved, _ = resolve_device(cfg.device) + cfg.device = resolved - config = WinMLEvaluationConfig( - model_path=model_path, - model_id=model_id, - task=task, - device=resolved_device, - dataset=ds_config, - output_path=output, - ) - try: - result = evaluate(config) +def _resolve_label_mapping(cfg: object) -> None: + """Load label-mapping JSON file (if any) into ``cfg.dataset.label_mapping``.""" + if cfg.dataset.label_mapping_file: + with Path(cfg.dataset.label_mapping_file).open() as f: + cfg.dataset.label_mapping = json.load(f) - from rich.console import Console - console = Console() - display_eval_report(result, console) +def _run_dataset_script(cfg: object, trust_remote_code: bool) -> None: + """Run the dataset build script referenced by *cfg*, if any. - if output is not None: - output.parent.mkdir(parents=True, exist_ok=True) - with output.open("w") as f: - json.dump(result.to_dict(), f, indent=2, default=_json_default) - console.print(f"[green]Results saved to:[/green] {output}") + The script is invoked with ``--output `` so the built + dataset lands at the path already configured in the config file. + """ + if not cfg.dataset.build_script: + return - except Exception as e: - logger.exception("Evaluation failed") - raise click.ClickException(f"Evaluation failed: {e}") from e + if not trust_remote_code: + raise click.UsageError( + "--trust-remote-code is required to execute a dataset script." + ) + + import subprocess + import sys + + script_path = Path(cfg.dataset.build_script) + if not script_path.exists(): + raise click.BadParameter(f"Dataset script not found: {script_path}") + + cmd = [sys.executable, str(script_path)] + if cfg.dataset.path: + cmd += ["--output", str(Path(cfg.dataset.path).expanduser())] + + logger.info("Building dataset via %s ...", script_path.name) + result = subprocess.run( # noqa: S603 + cmd, + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode != 0: + raise click.ClickException( + f"Dataset script failed (exit {result.returncode}): " + f"{result.stderr.strip()[-200:] or '(no stderr)'}" + ) + + +def _write_and_display(result: object, output_path: Path | None) -> None: + """Display evaluation results and optionally save to JSON.""" + from rich.console import Console + + console = Console() + display_eval_report(result, console) + + if output_path is not None: + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump(result.to_dict(), f, indent=2, default=_json_default) + console.print(f"[green]Results saved to:[/green] {output_path}") def _json_default(obj: object) -> object: diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 60b14319b..e18937513 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -16,7 +16,8 @@ ├── export: WinMLExportConfig # from modelkit/export/config.py ├── optim: WinMLOptimizationConfig # from modelkit/optim/config.py ├── quant: WinMLQuantizationConfig # from modelkit/quant/config.py - └── compile: WinMLCompileConfig # from modelkit/compiler/configs.py + ├── compile: WinMLCompileConfig # from modelkit/compiler/configs.py + └── eval: WinMLEvaluationConfig # from modelkit/eval/config.py Design Principles (P1 FUNDAMENTAL): - CALLS existing APIs from loader/, export/, models/hf/ @@ -97,6 +98,7 @@ class WinMLBuildConfig: optim: Optimization configuration quant: Quantization configuration compile: Compilation configuration + eval: Evaluation configuration Example: from winml.modelkit.config import WinMLBuildConfig @@ -127,7 +129,7 @@ class WinMLBuildConfig: optim: WinMLOptimizationConfig = field(default_factory=WinMLOptimizationConfig) quant: WinMLQuantizationConfig | None = field(default_factory=WinMLQuantizationConfig) compile: WinMLCompileConfig | None = field(default_factory=WinMLCompileConfig) - eval_option: WinMLEvaluationConfig | None = None + eval: WinMLEvaluationConfig | None = None @classmethod def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: @@ -136,7 +138,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: export_data = config_dict.get("export", {}) quant_data = config_dict.get("quant") compile_data = config_dict.get("compile") - eval_option_data = config_dict.get("eval_option") + eval_data = config_dict.get("eval") return cls( loader=WinMLLoaderConfig.from_dict(loader_data), export=(WinMLExportConfig.from_dict(export_data) if export_data is not None else None), @@ -147,9 +149,9 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: compile=( WinMLCompileConfig.from_dict(compile_data) if compile_data is not None else None ), - eval_option=( - WinMLEvaluationConfig.from_dict(eval_option_data) - if eval_option_data is not None + eval=( + WinMLEvaluationConfig.from_dict(eval_data) + if eval_data is not None else None ), ) @@ -166,8 +168,8 @@ def to_dict(self) -> dict: loader_dict = self.loader.to_dict() if loader_dict: result["loader"] = loader_dict - if self.eval_option is not None: - result["eval_option"] = self.eval_option.to_dict() + if self.eval is not None: + result["eval"] = self.eval.to_dict() return result def validate(self) -> None: diff --git a/src/winml/modelkit/datasets/config.py b/src/winml/modelkit/datasets/config.py index a431eef82..fd1331b7a 100644 --- a/src/winml/modelkit/datasets/config.py +++ b/src/winml/modelkit/datasets/config.py @@ -25,6 +25,11 @@ class DatasetConfig: columns_mapping: Column name overrides as key=value pairs. If empty, consumer uses its own defaults. streaming: Whether to stream dataset (avoids full download). + build_script: Path to a Python script that builds the dataset locally. + When set alongside ``path``, the script is invoked with + ``--output `` before the dataset is loaded. + label_mapping_file: Path to a JSON file with label mapping. + Resolved into ``label_mapping`` at eval time. """ path: str | None = None @@ -36,6 +41,8 @@ class DatasetConfig: columns_mapping: dict[str, str] = field(default_factory=dict) label_mapping: dict[str, int] | None = None streaming: bool = False + build_script: str | None = None + label_mapping_file: str | None = None def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" @@ -55,4 +62,8 @@ def to_dict(self) -> dict[str, Any]: result["label_mapping"] = self.label_mapping if self.streaming: result["streaming"] = self.streaming + if self.build_script is not None: + result["build_script"] = self.build_script + if self.label_mapping_file is not None: + result["label_mapping_file"] = self.label_mapping_file return result diff --git a/src/winml/modelkit/eval/config.py b/src/winml/modelkit/eval/config.py index 360d59b50..082b78244 100644 --- a/src/winml/modelkit/eval/config.py +++ b/src/winml/modelkit/eval/config.py @@ -59,8 +59,6 @@ class WinMLEvaluationConfig: device: str = "cpu" dataset: DatasetConfig = field(default_factory=DatasetConfig) output_path: Path | None = None - dataset_script: str | None = None - label_mapping_file: str | None = None def to_dict(self) -> dict: """Convert to dictionary for serialization.""" @@ -75,10 +73,6 @@ def to_dict(self) -> dict: result["dataset"] = self.dataset.to_dict() if self.output_path is not None: result["output_path"] = str(self.output_path) - if self.dataset_script is not None: - result["dataset_script"] = self.dataset_script - if self.label_mapping_file is not None: - result["label_mapping_file"] = self.label_mapping_file return result @classmethod @@ -94,6 +88,8 @@ def from_dict(cls, data: dict) -> WinMLEvaluationConfig: seed=ds_data.get("seed", 42), columns_mapping=ds_data.get("columns_mapping", {}), streaming=ds_data.get("streaming", False), + build_script=ds_data.get("build_script"), + label_mapping_file=ds_data.get("label_mapping_file"), ) return cls( model_id=data.get("model_id"), @@ -102,6 +98,4 @@ def from_dict(cls, data: dict) -> WinMLEvaluationConfig: device=data.get("device", "cpu"), dataset=dataset, output_path=(Path(data["output_path"]) if data.get("output_path") else None), - dataset_script=data.get("dataset_script"), - label_mapping_file=data.get("label_mapping_file"), ) From 75f931b4165802f035fa2621cea7ddf18fa00579 Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Wed, 29 Apr 2026 11:26:30 +0800 Subject: [PATCH 05/12] revert: restore build scripts to original state --- scripts/e2e_eval/datasets/build_ai4privacy.py | 9 ++------- scripts/e2e_eval/datasets/build_fairface.py | 9 ++------- scripts/e2e_eval/datasets/build_indonlu_posp.py | 9 ++------- scripts/e2e_eval/datasets/build_pubtables1m_detection.py | 9 ++------- scripts/e2e_eval/datasets/build_pubtables1m_structure.py | 9 ++------- 5 files changed, 10 insertions(+), 35 deletions(-) diff --git a/scripts/e2e_eval/datasets/build_ai4privacy.py b/scripts/e2e_eval/datasets/build_ai4privacy.py index 52724e96d..a56c4f8df 100644 --- a/scripts/e2e_eval/datasets/build_ai4privacy.py +++ b/scripts/e2e_eval/datasets/build_ai4privacy.py @@ -59,16 +59,11 @@ def build_dataset(output_dir: Path) -> None: print("Done.") -_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_ai4privacy" - - def main() -> None: parser = argparse.ArgumentParser(description="Build ai4privacy PII dataset") - parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_ai4privacy)") + parser.add_argument("--output", type=Path, required=True, help="Output directory") args = parser.parse_args() - output_dir = args.output or _DEFAULT_CACHE_DIR - build_dataset(output_dir) - print(output_dir) + build_dataset(args.output) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_fairface.py b/scripts/e2e_eval/datasets/build_fairface.py index 22a77f489..82b3641b0 100644 --- a/scripts/e2e_eval/datasets/build_fairface.py +++ b/scripts/e2e_eval/datasets/build_fairface.py @@ -95,16 +95,11 @@ def build_dataset(output_dir: Path) -> None: print("Done.") -_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_fairface" - - def main() -> None: parser = argparse.ArgumentParser(description="Build fairface validation dataset") - parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_fairface)") + parser.add_argument("--output", type=Path, required=True, help="Output directory") args = parser.parse_args() - output_dir = args.output or _DEFAULT_CACHE_DIR - build_dataset(output_dir) - print(output_dir) + build_dataset(args.output) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_indonlu_posp.py b/scripts/e2e_eval/datasets/build_indonlu_posp.py index 52e276d48..724bef63d 100644 --- a/scripts/e2e_eval/datasets/build_indonlu_posp.py +++ b/scripts/e2e_eval/datasets/build_indonlu_posp.py @@ -44,16 +44,11 @@ def build_dataset(output_dir: Path) -> None: print("Done.") -_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_indonlu_posp" - - def main() -> None: parser = argparse.ArgumentParser(description="Build indonlu posp dataset") - parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_indonlu_posp)") + parser.add_argument("--output", type=Path, required=True, help="Output directory") args = parser.parse_args() - output_dir = args.output or _DEFAULT_CACHE_DIR - build_dataset(output_dir) - print(output_dir) + build_dataset(args.output) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_pubtables1m_detection.py b/scripts/e2e_eval/datasets/build_pubtables1m_detection.py index 687caecd7..043f5b483 100644 --- a/scripts/e2e_eval/datasets/build_pubtables1m_detection.py +++ b/scripts/e2e_eval/datasets/build_pubtables1m_detection.py @@ -198,16 +198,11 @@ def build_dataset(output_dir: Path) -> None: print("Done.") -_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_pubtables1m_detection" - - def main() -> None: parser = argparse.ArgumentParser(description="Build PubTables-1M detection dataset") - parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_pubtables1m_detection)") + parser.add_argument("--output", type=Path, required=True, help="Output directory") args = parser.parse_args() - output_dir = args.output or _DEFAULT_CACHE_DIR - build_dataset(output_dir) - print(output_dir) + build_dataset(args.output) if __name__ == "__main__": diff --git a/scripts/e2e_eval/datasets/build_pubtables1m_structure.py b/scripts/e2e_eval/datasets/build_pubtables1m_structure.py index 775495e0a..9a640717b 100644 --- a/scripts/e2e_eval/datasets/build_pubtables1m_structure.py +++ b/scripts/e2e_eval/datasets/build_pubtables1m_structure.py @@ -206,16 +206,11 @@ def build_dataset(output_dir: Path) -> None: print("Done.") -_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "winml" / "eval_datasets" / "build_pubtables1m_structure" - - def main() -> None: parser = argparse.ArgumentParser(description="Build PubTables-1M structure recognition dataset") - parser.add_argument("--output", type=Path, default=None, help="Output directory (default: ~/.cache/winml/eval_datasets/build_pubtables1m_structure)") + parser.add_argument("--output", type=Path, required=True, help="Output directory") args = parser.parse_args() - output_dir = args.output or _DEFAULT_CACHE_DIR - build_dataset(output_dir) - print(output_dir) + build_dataset(args.output) if __name__ == "__main__": From 151e7a13632f8390957fe83e952433289ef720ed Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Wed, 29 Apr 2026 11:28:25 +0800 Subject: [PATCH 06/12] fix: require dataset.path when build_script is set, always pass --output --- src/winml/modelkit/commands/eval.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index b8c16bb83..dd5de6a30 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -372,6 +372,12 @@ def _run_dataset_script(cfg: object, trust_remote_code: bool) -> None: if not cfg.dataset.build_script: return + if not cfg.dataset.path: + raise click.UsageError( + "dataset.path is required when dataset.build_script is set. " + "The path tells the script where to write the built dataset." + ) + if not trust_remote_code: raise click.UsageError( "--trust-remote-code is required to execute a dataset script." @@ -384,9 +390,8 @@ def _run_dataset_script(cfg: object, trust_remote_code: bool) -> None: if not script_path.exists(): raise click.BadParameter(f"Dataset script not found: {script_path}") - cmd = [sys.executable, str(script_path)] - if cfg.dataset.path: - cmd += ["--output", str(Path(cfg.dataset.path).expanduser())] + cmd = [sys.executable, str(script_path), + "--output", str(Path(cfg.dataset.path).expanduser())] logger.info("Building dataset via %s ...", script_path.name) result = subprocess.run( # noqa: S603 From 7ffb2456d400df7e25d88da7be98d5d3310c69f6 Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Thu, 7 May 2026 10:58:39 +0800 Subject: [PATCH 07/12] refactor: single _build_eval_config, collect_cli_overrides utility, use main's _resolve_model_path --- src/winml/modelkit/commands/eval.py | 245 +++++++++++++++++++------- src/winml/modelkit/datasets/config.py | 6 +- src/winml/modelkit/utils/cli.py | 35 ++++ 3 files changed, 217 insertions(+), 69 deletions(-) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index e2c6a4b70..b700ee505 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -118,6 +118,18 @@ default=False, help="Enable verbose output.", ) +@click.option( + "--dataset-script", + type=str, + default=None, + help="Path to a Python script that builds the evaluation dataset.", +) +@click.option( + "--trust-remote-code", + is_flag=True, + default=False, + help="Allow execution of dataset scripts. Required when --dataset-script is used.", +) @click.option( "--schema", "show_schema", @@ -143,6 +155,8 @@ def eval( label_mapping: Path | None, output: Path | None, verbose: bool, + dataset_script: str | None, + trust_remote_code: bool, show_schema: bool, config_file: Path | None, ) -> None: @@ -168,94 +182,193 @@ def eval( if verbose or (ctx.obj and ctx.obj.get("debug")): logging.getLogger("winml.modelkit").setLevel(logging.DEBUG) - # Apply build config defaults (CLI explicit options take precedence) - if config_file is not None: - build_cfg = cli_utils.load_build_config(config_file) - if build_cfg.loader and not cli_utils.is_cli_provided(ctx, "task"): - task = build_cfg.loader.task - if build_cfg.quant: - if not cli_utils.is_cli_provided(ctx, "samples"): - samples = build_cfg.quant.samples - if not cli_utils.is_cli_provided(ctx, "dataset_name"): - dataset_name = build_cfg.quant.dataset_name + from ..eval import evaluate + + # ── 1. Build config: defaults ← config file ← CLI ── + cfg = _build_eval_config(ctx, config_file, column, label_mapping) if show_schema: from ..eval import WinMLEvaluator from ..eval.evaluate import _EVALUATOR_REGISTRY - if task is None: + if cfg.task is None: raise click.UsageError( "--schema requires --task. Example: winml eval --schema --task object-detection" ) - cls = _EVALUATOR_REGISTRY.get(task, WinMLEvaluator) - _print_schema(task, cls.schema_info()) + cls = _EVALUATOR_REGISTRY.get(cfg.task, WinMLEvaluator) + _print_schema(cfg.task, cls.schema_info()) return - model_path, model_id = _resolve_model_path( - model=model, - model_id=model_id, - ) + # ── 2. Resolve in place ── + _resolve_model(cfg, model, model_id) + _resolve_device(cfg) + _resolve_label_mapping(cfg) + _run_dataset_script(cfg, trust_remote_code) - # Parse column mappings from --column key=value pairs - columns_mapping: dict[str, str] = {} - for c in column: - if "=" not in c: - raise click.BadParameter( - f"Invalid column format: '{c}'. Use key=value.", - param_hint="--column", - ) - k, v = c.split("=", 1) - columns_mapping[k] = v + logger.debug("Effective eval config: %s", cfg.to_dict()) + + # ── 3. Evaluate ── + try: + result = evaluate(cfg) + _write_and_display(result, cfg.output_path) + except Exception as e: + logger.exception("Evaluation failed") + raise click.ClickException(f"Evaluation failed: {e}") from e - # Parse label mapping from JSON file - parsed_label_mapping = None - if label_mapping: - with Path(label_mapping).open() as f: - parsed_label_mapping = json.load(f) +def _build_eval_config( + ctx: click.Context, + config_file: Path | None, + column: tuple[str, ...], + label_mapping: Path | None, +) -> object: + """Build a WinMLEvaluationConfig with precedence: defaults ← config file ← CLI. + + Reads raw JSON for config-file values so only explicitly-present keys + are applied (avoids overriding with dataclass defaults). + Uses ``collect_cli_overrides`` for automatic CLI-to-field mapping. + """ from ..datasets import DatasetConfig - from ..eval import WinMLEvaluationConfig, evaluate + from ..eval import WinMLEvaluationConfig + from ..utils.config_utils import merge_config + + cfg = WinMLEvaluationConfig() + + # ── Config file layer (only explicitly-present keys) ── + if config_file is not None: + raw = json.loads(config_file.read_text()) + + # Loader task as lowest-priority fallback + loader_data = raw.get("loader", {}) + if "task" in loader_data: + cfg.task = loader_data["task"] + + # Quant fields as fallback (only explicitly-set values) + quant_data = raw.get("quant", {}) + quant_overrides: dict = {} + if "samples" in quant_data: + quant_overrides.setdefault("dataset", {})["samples"] = quant_data["samples"] + if "dataset_name" in quant_data: + quant_overrides.setdefault("dataset", {})["name"] = quant_data["dataset_name"] + if quant_overrides: + cfg = merge_config(cfg, quant_overrides) + + # Eval section overrides quant/loader + eval_data = raw.get("eval") + if eval_data: + cfg = merge_config(cfg, eval_data) + + # ── CLI layer (highest priority, auto-mapped via metadata) ── + overrides = cli_utils.collect_cli_overrides(ctx, type(cfg)) + ds_overrides = cli_utils.collect_cli_overrides(ctx, DatasetConfig) + + # --column is multiple=True; non-empty tuple means user provided it + if column: + columns_mapping: dict[str, str] = {} + for c in column: + if "=" not in c: + raise click.BadParameter( + f"Invalid column format: '{c}'. Use key=value.", + param_hint="--column", + ) + k, v = c.split("=", 1) + columns_mapping[k] = v + ds_overrides["columns_mapping"] = columns_mapping + + if label_mapping is not None: + ds_overrides["label_mapping_file"] = str(label_mapping) + + if ds_overrides: + overrides["dataset"] = ds_overrides + + if overrides: + cfg = merge_config(cfg, overrides) + + return cfg + + +def _resolve_model( + cfg: object, + model: tuple[str, ...], + model_id: str | None, +) -> None: + """Resolve ``-m`` / ``--model-id`` into ``cfg.model_path`` / ``cfg.model_id``.""" + model_path, resolved_id = _resolve_model_path(model=model, model_id=model_id) + cfg.model_path = model_path + cfg.model_id = resolved_id + + +def _resolve_device(cfg: object) -> None: + """Resolve ``'auto'`` → concrete device string on *cfg* in place.""" from ..sysinfo import resolve_device - resolved_device, _ = resolve_device(device) - - ds_config = DatasetConfig( - path=dataset_path, - name=dataset_name, - split=split, - samples=samples, - shuffle=shuffle, - columns_mapping=columns_mapping, - label_mapping=parsed_label_mapping, - streaming=streaming, - ) + resolved, _ = resolve_device(cfg.device) + cfg.device = resolved - config = WinMLEvaluationConfig( - model_path=model_path, - model_id=model_id, - task=task, - device=resolved_device, - dataset=ds_config, - output_path=output, - ) - try: - result = evaluate(config) +def _resolve_label_mapping(cfg: object) -> None: + """Load label-mapping JSON file (if any) into ``cfg.dataset.label_mapping``.""" + if cfg.dataset.label_mapping_file: + with Path(cfg.dataset.label_mapping_file).open() as f: + cfg.dataset.label_mapping = json.load(f) - from rich.console import Console - console = Console() - display_eval_report(result, console) +def _run_dataset_script(cfg: object, trust_remote_code: bool) -> None: + """Run the dataset build script referenced by *cfg*, if any. - if output is not None: - output.parent.mkdir(parents=True, exist_ok=True) - with output.open("w") as f: - json.dump(result.to_dict(), f, indent=2, default=_json_default) - console.print(f"[green]Results saved to:[/green] {output}") + The script is invoked with ``--output `` so the built + dataset lands at the path already configured in the config file. + """ + if not cfg.dataset.build_script: + return - except Exception as e: - logger.exception("Evaluation failed") - raise click.ClickException(f"Evaluation failed: {e}") from e + if not cfg.dataset.path: + raise click.UsageError( + "dataset.path is required when dataset.build_script is set. " + "The path tells the script where to write the built dataset." + ) + + if not trust_remote_code: + raise click.UsageError( + "--trust-remote-code is required to execute a dataset script." + ) + + import subprocess + import sys + + script_path = Path(cfg.dataset.build_script) + if not script_path.exists(): + raise click.BadParameter(f"Dataset script not found: {script_path}") + + cmd = [sys.executable, str(script_path), + "--output", str(Path(cfg.dataset.path).expanduser())] + + logger.info("Building dataset via %s ...", script_path.name) + result = subprocess.run( # noqa: S603 + cmd, + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode != 0: + raise click.ClickException( + f"Dataset script failed (exit {result.returncode}): " + f"{result.stderr.strip()[-200:] or '(no stderr)'}" + ) + + +def _write_and_display(result: object, output_path: Path | None) -> None: + """Display evaluation results and optionally save to JSON.""" + from rich.console import Console + + console = Console() + display_eval_report(result, console) + + if output_path is not None: + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump(result.to_dict(), f, indent=2, default=_json_default) + console.print(f"[green]Results saved to:[/green] {output_path}") def _resolve_model_path( diff --git a/src/winml/modelkit/datasets/config.py b/src/winml/modelkit/datasets/config.py index fd1331b7a..8b30684fd 100644 --- a/src/winml/modelkit/datasets/config.py +++ b/src/winml/modelkit/datasets/config.py @@ -32,8 +32,8 @@ class DatasetConfig: Resolved into ``label_mapping`` at eval time. """ - path: str | None = None - name: str | None = None + path: str | None = field(default=None, metadata={"cli_name": "dataset_path"}) + name: str | None = field(default=None, metadata={"cli_name": "dataset_name"}) split: str = "validation" samples: int = 100 shuffle: bool = True @@ -41,7 +41,7 @@ class DatasetConfig: columns_mapping: dict[str, str] = field(default_factory=dict) label_mapping: dict[str, int] | None = None streaming: bool = False - build_script: str | None = None + build_script: str | None = field(default=None, metadata={"cli_name": "dataset_script"}) label_mapping_file: str | None = None def to_dict(self) -> dict[str, Any]: diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index 8d58989f7..3c9b36c06 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -203,3 +203,38 @@ def is_cli_provided(ctx: click.Context, param_name: str) -> bool: """ source = ctx.get_parameter_source(param_name) return source == click.core.ParameterSource.COMMANDLINE + + +def collect_cli_overrides(ctx: click.Context, cls: type) -> dict[str, object]: + """Collect CLI-provided values that match fields on a dataclass. + + Iterates ``ctx.params`` and returns ``{field_name: value}`` for every + CLI param that was explicitly provided AND maps to a field on *cls*. + + Name mapping uses ``field(metadata={"cli_name": ...})`` on the + dataclass. Fields without ``cli_name`` metadata match by name. + + Args: + ctx: Click context. + cls: Target dataclass whose fields define the valid key set. + + Returns: + Dict of ``{field_name: value}`` for CLI-provided params. + """ + import dataclasses + + # Build reverse map: cli_name -> field_name + rename: dict[str, str] = {} + valid_fields: set[str] = set() + for f in dataclasses.fields(cls): + valid_fields.add(f.name) + cli_name = f.metadata.get("cli_name") + if cli_name: + rename[cli_name] = f.name + + overrides: dict[str, object] = {} + for cli_name, value in ctx.params.items(): + field_name = rename.get(cli_name, cli_name) + if field_name in valid_fields and is_cli_provided(ctx, cli_name): + overrides[field_name] = value + return overrides From 5c6f74a2eaf7838202a7d104c779ab0f3cb16f49 Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Thu, 7 May 2026 16:30:24 +0800 Subject: [PATCH 08/12] fix: lazy import WinMLEvaluationConfig to avoid heavy dep import at config load time --- src/winml/modelkit/config/build.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index e18937513..140f830a9 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -51,7 +51,6 @@ from typing import TYPE_CHECKING, Any, overload from ..compiler.configs import WinMLCompileConfig -from ..eval.config import WinMLEvaluationConfig from ..export.config import ( InputTensorSpec, OutputTensorSpec, @@ -64,6 +63,8 @@ from ..utils.config_utils import merge_config +# NOTE: WinMLEvaluationConfig is imported lazily to avoid pulling +# eval/__init__.py which imports heavy deps (torch, sklearn, etc.). # NOTE: MODEL_BUILD_CONFIGS is imported lazily inside generate_build_config() # to avoid circular import: config -> models.hf -> config @@ -72,6 +73,8 @@ import torch from torch import nn + from ..eval.config import WinMLEvaluationConfig # noqa: TC004 + __all__ = [ "WinMLBuildConfig", "generate_build_config", @@ -131,14 +134,26 @@ class WinMLBuildConfig: compile: WinMLCompileConfig | None = field(default_factory=WinMLCompileConfig) eval: WinMLEvaluationConfig | None = None + def __post_init__(self) -> None: + # Lazy import: inject into module globals so typing.get_type_hints() + # can resolve the eval field annotation (used by merge_config). + from ..eval.config import WinMLEvaluationConfig + + globals().setdefault("WinMLEvaluationConfig", WinMLEvaluationConfig) + @classmethod def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: """Create config from nested dictionary.""" + from ..eval.config import WinMLEvaluationConfig + loader_data = config_dict.get("loader", {}) export_data = config_dict.get("export", {}) quant_data = config_dict.get("quant") compile_data = config_dict.get("compile") eval_data = config_dict.get("eval") + eval_cfg = None + if eval_data is not None: + eval_cfg = WinMLEvaluationConfig.from_dict(eval_data) return cls( loader=WinMLLoaderConfig.from_dict(loader_data), export=(WinMLExportConfig.from_dict(export_data) if export_data is not None else None), @@ -149,11 +164,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: compile=( WinMLCompileConfig.from_dict(compile_data) if compile_data is not None else None ), - eval=( - WinMLEvaluationConfig.from_dict(eval_data) - if eval_data is not None - else None - ), + eval=eval_cfg, ) def to_dict(self) -> dict: From 5517257c8f01d7a6852883db13ecf5237da580ef Mon Sep 17 00:00:00 2001 From: Shiyi Zheng Date: Thu, 7 May 2026 23:32:21 +0800 Subject: [PATCH 09/12] fix(eval): restore --output override mapping and share trust-remote-code option --- src/winml/modelkit/commands/build.py | 8 +++----- src/winml/modelkit/commands/config.py | 8 ++------ src/winml/modelkit/commands/eval.py | 7 ++----- src/winml/modelkit/eval/config.py | 2 +- src/winml/modelkit/utils/cli.py | 24 ++++++++++++++++++++++++ 5 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 3c1f14c23..66607d3bd 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -27,6 +27,7 @@ import click from rich.logging import RichHandler +from ..utils import cli as cli_utils from ..utils.console import ( detect_model_source, get_console, @@ -295,11 +296,8 @@ def _build_modules( default=None, help="Maximum autoconf re-optimization rounds (default: 3). --no-analyze sets this to 0.", ) -@click.option( - "--trust-remote-code", - is_flag=True, - default=False, - help="Trust remote code for custom model architectures (e.g., Mu2).", +@cli_utils.trust_remote_code_option( + optional_message="Trust remote code for custom model architectures (e.g., Mu2)." ) @click.option( "-v", diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index 9e1553202..260e4878a 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -29,6 +29,7 @@ import click +from ..utils import cli as cli_utils from ..utils.console import ( get_console, print_command_header, @@ -169,12 +170,7 @@ def _is_onnx_file(model_input: str) -> bool: default=True, help="Exclude compilation from generated config (sets compile=None). Default: exclude.", ) -@click.option( - "--trust-remote-code", - is_flag=True, - default=False, - help="Allow running custom code from model repository", -) +@cli_utils.trust_remote_code_option() def config( hf_model: str | None, task: str | None, diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index b700ee505..2dc9ccf40 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -124,11 +124,8 @@ default=None, help="Path to a Python script that builds the evaluation dataset.", ) -@click.option( - "--trust-remote-code", - is_flag=True, - default=False, - help="Allow execution of dataset scripts. Required when --dataset-script is used.", +@cli_utils.trust_remote_code_option( + optional_message="Required when --dataset-script is used." ) @click.option( "--schema", diff --git a/src/winml/modelkit/eval/config.py b/src/winml/modelkit/eval/config.py index a4addcb11..e1ccd3b47 100644 --- a/src/winml/modelkit/eval/config.py +++ b/src/winml/modelkit/eval/config.py @@ -60,7 +60,7 @@ class WinMLEvaluationConfig: task: str | None = None device: str = "cpu" dataset: DatasetConfig = field(default_factory=DatasetConfig) - output_path: Path | None = None + output_path: Path | None = field(default=None, metadata={"cli_name": "output"}) def to_dict(self) -> dict: """Convert to dictionary for serialization.""" diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index 3c9b36c06..5f3c95e8e 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -163,6 +163,30 @@ def build_config_option(func): )(func) +def trust_remote_code_option(optional_message: str | None = None): + """Add shared --trust-remote-code option to a Click command. + + Args: + optional_message: Extra command-specific guidance appended to help text. + + Returns: + Decorator function. + """ + help_text = ( + "Allow executing custom code from model repositories or dataset scripts. " + "Use only with trusted sources." + ) + if optional_message: + help_text = f"{help_text} {optional_message}" + + return click.option( + "--trust-remote-code", + is_flag=True, + default=False, + help=help_text, + ) + + def load_build_config(config_path: Path) -> WinMLBuildConfig: """Load a WinMLBuildConfig from a JSON file. From 316c939bd9334cce16bae0c55c75a3682942f292 Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Fri, 8 May 2026 11:42:43 +0800 Subject: [PATCH 10/12] merge: resolve conflicts with main, integrate ep parameter --- src/winml/modelkit/commands/eval.py | 41 ++++++++++++++++++----------- src/winml/modelkit/eval/config.py | 6 +++++ 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index 2dc9ccf40..706b0654a 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -67,6 +67,7 @@ show_default=True, help="Device to run on. 'auto' detects the best available device.", ) +@cli_utils.ep_option(required=False) @click.option( "--samples", type=int, @@ -144,6 +145,7 @@ def eval( dataset_name: str | None, task: str | None, device: str, + ep: str | None, samples: int, split: str, shuffle: bool, @@ -233,24 +235,31 @@ def _build_eval_config( # ── Config file layer (only explicitly-present keys) ── if config_file is not None: - raw = json.loads(config_file.read_text()) + build_cfg = cli_utils.load_build_config(config_file) # Loader task as lowest-priority fallback - loader_data = raw.get("loader", {}) - if "task" in loader_data: - cfg.task = loader_data["task"] - - # Quant fields as fallback (only explicitly-set values) - quant_data = raw.get("quant", {}) - quant_overrides: dict = {} - if "samples" in quant_data: - quant_overrides.setdefault("dataset", {})["samples"] = quant_data["samples"] - if "dataset_name" in quant_data: - quant_overrides.setdefault("dataset", {})["name"] = quant_data["dataset_name"] - if quant_overrides: - cfg = merge_config(cfg, quant_overrides) - - # Eval section overrides quant/loader + if build_cfg.loader and build_cfg.loader.task: + cfg.task = build_cfg.loader.task + + # Compile EP as fallback for --ep + if build_cfg.compile and build_cfg.compile.ep_config: + cfg.ep = build_cfg.compile.ep_config.provider + + # Quant fields as fallback + if build_cfg.quant: + quant_overrides: dict = {} + if build_cfg.quant.samples != 100: # non-default + quant_overrides.setdefault("dataset", {})["samples"] = build_cfg.quant.samples + if build_cfg.quant.dataset_name: + quant_overrides.setdefault("dataset", {})["name"] = build_cfg.quant.dataset_name + if quant_overrides: + cfg = merge_config(cfg, quant_overrides) + + # Eval section overrides quant/loader (read raw JSON for this) + try: + raw = json.loads(config_file.read_text()) + except (json.JSONDecodeError, OSError): + raw = {} eval_data = raw.get("eval") if eval_data: cfg = merge_config(cfg, eval_data) diff --git a/src/winml/modelkit/eval/config.py b/src/winml/modelkit/eval/config.py index e1ccd3b47..4c80dd12d 100644 --- a/src/winml/modelkit/eval/config.py +++ b/src/winml/modelkit/eval/config.py @@ -45,6 +45,8 @@ class WinMLEvaluationConfig: None = build from model_id. task: HF pipeline task. Auto-detected from model_id if omitted. device: Target device for inference. + ep: Explicit execution provider (e.g., "qnn", "dml"). Overrides + device-to-provider mapping when provided. dataset: Dataset configuration. output_path: Path to write JSON results. @@ -59,6 +61,7 @@ class WinMLEvaluationConfig: model_path: str | dict[str, str] | None = None task: str | None = None device: str = "cpu" + ep: str | None = None dataset: DatasetConfig = field(default_factory=DatasetConfig) output_path: Path | None = field(default=None, metadata={"cli_name": "output"}) @@ -72,6 +75,8 @@ def to_dict(self) -> dict: if self.task is not None: result["task"] = self.task result["device"] = self.device + if self.ep is not None: + result["ep"] = self.ep result["dataset"] = self.dataset.to_dict() if self.output_path is not None: result["output_path"] = str(self.output_path) @@ -98,6 +103,7 @@ def from_dict(cls, data: dict) -> WinMLEvaluationConfig: model_path=data.get("model_path"), task=data.get("task"), device=data.get("device", "cpu"), + ep=data.get("ep"), dataset=dataset, output_path=(Path(data["output_path"]) if data.get("output_path") else None), ) From ecab53bc492d2af402ac8da180822270ccce4d2c Mon Sep 17 00:00:00 2001 From: "Shiyi Zheng (from Dev Box)" Date: Fri, 8 May 2026 11:45:48 +0800 Subject: [PATCH 11/12] merge: include remaining main changes from previous merge --- .pipelines/Modelkit E2E Test.yml | 8 + .pipelines/templates/e2e-eval-jobs.yml | 259 ++++++++++++------ pyproject.toml | 1 + scripts/e2e_eval/cache/baseline_cache.json | 80 ++++++ scripts/e2e_eval/run_eval.py | 5 +- .../e2e_eval/testsets/models_with_acc.json | 86 +++--- scripts/e2e_eval/utils/reporter.py | 20 +- src/winml/modelkit/analyze/__init__.py | 2 + .../analyze/core/runtime_checker_query.py | 4 +- src/winml/modelkit/commands/analyze.py | 3 +- src/winml/modelkit/commands/perf.py | 24 +- src/winml/modelkit/eval/evaluate.py | 4 + src/winml/modelkit/session/session.py | 83 +++--- src/winml/modelkit/sysinfo/device.py | 5 + src/winml/modelkit/telemetry/consent.py | 6 +- .../modelkit/telemetry/library/exporter.py | 121 +++++++- .../telemetry/library/serialization.py | 27 +- .../unit/analyze/test_static_analyzer_cli.py | 99 +++++++ tests/unit/commands/test_perf_module.py | 89 ++++++ tests/unit/eval/test_eval.py | 195 +++++++++++++ tests/unit/session/test_winml_session.py | 93 +++++++ tests/unit/telemetry/conftest.py | 4 +- tests/unit/telemetry/library/test_exporter.py | 255 ++++++++++++++++- tests/unit/telemetry/library/test_factory.py | 4 +- .../telemetry/library/test_serialization.py | 38 ++- .../unit/telemetry/test_cache_integration.py | 8 +- tests/unit/telemetry/test_telemetry_init.py | 8 +- 27 files changed, 1326 insertions(+), 205 deletions(-) diff --git a/.pipelines/Modelkit E2E Test.yml b/.pipelines/Modelkit E2E Test.yml index e8a5c24e4..ea39dada0 100644 --- a/.pipelines/Modelkit E2E Test.yml +++ b/.pipelines/Modelkit E2E Test.yml @@ -17,6 +17,13 @@ parameters: displayName: 'Skip already-evaluated models (--continue)' type: boolean default: true + - name: ovDevices + displayName: 'NPU-OV: devices to evaluate' + type: object + default: + - cpu + - gpu + - npu stages: - stage: NPU_QNN @@ -39,6 +46,7 @@ stages: agentSuffix: ov evalDate: ${{ parameters.evalDate }} continueRun: ${{ parameters.continueRun }} + devices: ${{ parameters.ovDevices }} - stage: NPU_AMD displayName: 'E2E Eval — NPU-AMD' diff --git a/.pipelines/templates/e2e-eval-jobs.yml b/.pipelines/templates/e2e-eval-jobs.yml index 1c90999df..ebc486728 100644 --- a/.pipelines/templates/e2e-eval-jobs.yml +++ b/.pipelines/templates/e2e-eval-jobs.yml @@ -15,6 +15,10 @@ parameters: - name: modelTimeout type: number default: 3600 + - name: devices + type: object + default: + - npu jobs: - job: Prepare_${{ parameters.agentSuffix }} @@ -95,54 +99,90 @@ jobs: if (-not $evalDate -or $evalDate -eq 'auto') { $evalDate = Get-Date -Format 'yyyy-MM-dd' } $dir = "${{ parameters.evalOutputBase }}/$evalDate/${{ parameters.agentSuffix }}" Write-Host "##vso[task.setvariable variable=EVAL_DIR;isOutput=true]$dir" - Write-Host "Eval output directory: $dir" + Write-Host "Eval output directory (base): $dir" name: set_output_dir displayName: 'Set eval output directory' + - ${{ each device in parameters.devices }}: + - powershell: | + $args = @( + "run", "python", "scripts/e2e_eval/run_eval.py", + "--list-json", "temp/model_list_${{ device }}.json", + "--device", "${{ device }}" + ) + if ('${{ parameters.continueRun }}' -eq 'True') { + $args += @("--continue", "--output-dir", "$(set_output_dir.EVAL_DIR)/${{ device }}") + } + & uv @args + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Generate model list (${{ device }})' + - powershell: | - $args = @( - "run", "python", "scripts/e2e_eval/run_eval.py", - "--list-json", "temp/model_list.json", - "--device", "npu" - ) - if ('${{ parameters.continueRun }}' -eq 'True') { - $args += @("--continue", "--output-dir", "$(set_output_dir.EVAL_DIR)") + $devices = '${{ join(',', parameters.devices) }}' -split ',' + # rowsByKey: matrix-key -> @{ entry = base-fields; pendingDevices = ordered list } + $rowsByKey = [ordered]@{} + $usedKeys = @{} + + foreach ($device in $devices) { + $listPath = "$(Build.SourcesDirectory)/temp/model_list_$device.json" + if (-not (Test-Path $listPath)) { + Write-Warning "Model list missing for device '$device' at $listPath" + continue + } + $models = Get-Content $listPath | ConvertFrom-Json + if ($null -eq $models) { continue } + # Force collection even if a single object was returned + $models = @($models) + Write-Host "Device '$device': $($models.Count) pending models" + + foreach ($m in $models) { + $slug = (($m.hf_id + '_' + $m.task) -replace '[^A-Za-z0-9]', '_') + if ($usedKeys.ContainsKey($slug)) { + $key = $usedKeys[$slug] + } else { + $key = $slug + $suffix = 2 + while ($rowsByKey.Contains($key)) { + $key = "${slug}_${suffix}" + $suffix++ + } + $usedKeys[$slug] = $key + } + + if (-not $rowsByKey.Contains($key)) { + $rowsByKey[$key] = @{ + entry = @{ + hf_id = [string]$m.hf_id + hf_task = [string]$m.task + priority = [string]$m.priority + model_type = [string]$m.model_type + model_group = [string]$m.group + } + pendingDevices = New-Object System.Collections.Generic.List[string] + } + } + $rowsByKey[$key].pendingDevices.Add($device) + } } - & uv @args - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Generate model list' - - powershell: | - $models = Get-Content "$(Build.SourcesDirectory)/temp/model_list.json" | ConvertFrom-Json - $total = $models.Count + $total = $rowsByKey.Count if ($total -eq 0) { - Write-Host "All models already evaluated — nothing to run" + Write-Host "All models already evaluated for all selected devices — nothing to run" Write-Host "##vso[task.setvariable variable=modelMatrix;isOutput=true]{}" Write-Host "##vso[task.setvariable variable=skipEval;isOutput=true]true" return } $matrix = @{} - for ($i = 0; $i -lt $total; $i++) { - $m = $models[$i] - $slug = (($m.hf_id + '_' + $m.task) -replace '[^A-Za-z0-9]', '_') - $key = $slug - $suffix = 2 - while ($matrix.ContainsKey($key)) { - $key = "${slug}_${suffix}" - $suffix++ - } - $matrix[$key] = @{ - hf_id = [string]$m.hf_id - hf_task = [string]$m.task - priority = [string]$m.priority - model_type = [string]$m.model_type - model_group = [string]$m.group - } + foreach ($key in $rowsByKey.Keys) { + $row = $rowsByKey[$key] + $entry = $row.entry.Clone() + $entry['devices_csv'] = ($row.pendingDevices -join ',') + $matrix[$key] = $entry } $json = $matrix | ConvertTo-Json -Compress -Depth 5 - Write-Host "Prepared matrix for $total models" + Write-Host "Prepared matrix for $total models across $($devices.Count) device(s)" Write-Host "##vso[task.setvariable variable=modelMatrix;isOutput=true]$json" name: set_matrix displayName: 'Create matrix variables' @@ -187,72 +227,109 @@ jobs: Write-Host "Build.SourcesDirectory: $(Build.SourcesDirectory)" Write-Host "Model: $(hf_id) / $(hf_task)" Write-Host "Priority: $(priority)" - Write-Host "Output: $(EVAL_DIR)" - - $uvArgs = @( - "run", "--no-sync", "python", "scripts/e2e_eval/run_eval.py", - "--hf-model", "$(hf_id)", - "--output-dir", "$(EVAL_DIR)", - "--device", "npu", - "--continue", - "--verbose", - "--timeout", "${{ parameters.modelTimeout }}", - "--no-report", - "--clean-cache" - ) - if ("$(hf_task)") { - $uvArgs += @("--task", "$(hf_task)") + Write-Host "Output base: $(EVAL_DIR)" + Write-Host "Pending devices: $(devices_csv)" + + $devices = '$(devices_csv)' -split ',' + $devices = $devices | Where-Object { $_ -and $_.Trim() -ne '' } + if (-not $devices -or $devices.Count -eq 0) { + Write-Warning "No pending devices for $(hf_id) / $(hf_task) — skipping" + exit 0 } - & uv @uvArgs - $evalExit = $LASTEXITCODE - if ($evalExit -ne 0) { - Write-Warning "Model eval exited with code $evalExit for $(hf_id) / $(hf_task) (model failure — non-blocking)" + for ($i = 0; $i -lt $devices.Count; $i++) { + $device = $devices[$i].Trim() + $isLast = ($i -eq ($devices.Count - 1)) + Write-Host "============================================================" + Write-Host "[$($i + 1)/$($devices.Count)] device='$device' (last=$isLast) for $(hf_id)" + Write-Host "============================================================" + + $deviceOutput = "$(EVAL_DIR)/$device" + + $uvArgs = @( + "run", "--no-sync", "python", "scripts/e2e_eval/run_eval.py", + "--hf-model", "$(hf_id)", + "--output-dir", $deviceOutput, + "--device", $device, + "--continue", + "--verbose", + "--timeout", "${{ parameters.modelTimeout }}", + "--no-report" + ) + if ($isLast) { + # Only clean cache after the LAST device so HF download and + # any reusable build artifacts can be shared across devices + # for this model. + $uvArgs += "--clean-cache" + } + if ("$(hf_task)") { + $uvArgs += @("--task", "$(hf_task)") + } + + & uv @uvArgs + $evalExit = $LASTEXITCODE + if ($evalExit -ne 0) { + Write-Warning "Model eval exited with code $evalExit for $(hf_id) / $(hf_task) on device '$device' (model failure — non-blocking)" + } } exit 0 workingDirectory: $(Build.SourcesDirectory) - displayName: 'Run eval for current model' - - - job: Report_${{ parameters.agentSuffix }} - displayName: 'Generate Eval Report (${{ parameters.agentSuffix }})' - dependsOn: - - Prepare_${{ parameters.agentSuffix }} - - EvalModel_${{ parameters.agentSuffix }} - condition: always() - pool: - name: modelkit-selfhost-pool - demands: - - Agent.Name -equals ${{ parameters.agentName }} - variables: - EVAL_DIR: $[ dependencies.Prepare_${{ parameters.agentSuffix }}.outputs['set_output_dir.EVAL_DIR'] ] + displayName: 'Run eval for current model (all pending devices)' - steps: - - checkout: self - clean: false - fetchDepth: 1 - path: s + - ${{ each device in parameters.devices }}: + - job: Report_${{ parameters.agentSuffix }}_${{ device }} + displayName: 'Generate Eval Report (${{ parameters.agentSuffix }} / ${{ device }})' + dependsOn: + - Prepare_${{ parameters.agentSuffix }} + - EvalModel_${{ parameters.agentSuffix }} + condition: always() + pool: + name: modelkit-selfhost-pool + demands: + - Agent.Name -equals ${{ parameters.agentName }} + variables: + EVAL_DIR_BASE: $[ dependencies.Prepare_${{ parameters.agentSuffix }}.outputs['set_output_dir.EVAL_DIR'] ] - - checkout: ModelKitArtifacts - clean: false - fetchDepth: 1 - lfs: true - path: artifacts + steps: + - checkout: self + clean: false + fetchDepth: 1 + path: s - - powershell: | - $uvBin = "$env:USERPROFILE\.local\bin" - $venvDir = "$(Build.SourcesDirectory)\.venv\Scripts" - Write-Host "##vso[task.prependpath]$uvBin" - Write-Host "##vso[task.prependpath]$venvDir" - displayName: 'Activate Python environment' + - checkout: ModelKitArtifacts + clean: false + fetchDepth: 1 + lfs: true + path: artifacts - - script: > - uv run --no-sync python scripts/e2e_eval/generate_report.py - --input-dir $(EVAL_DIR) - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Generate evaluation report' + - powershell: | + $uvBin = "$env:USERPROFILE\.local\bin" + $venvDir = "$(Build.SourcesDirectory)\.venv\Scripts" + Write-Host "##vso[task.prependpath]$uvBin" + Write-Host "##vso[task.prependpath]$venvDir" + displayName: 'Activate Python environment' - - task: PublishPipelineArtifact@1 - inputs: - targetPath: $(EVAL_DIR) - artifactName: EvalReport_${{ parameters.agentSuffix }} - displayName: 'Publish eval results as artifact' + - powershell: | + $deviceDir = "$(EVAL_DIR_BASE)/${{ device }}" + if (-not (Test-Path $deviceDir)) { + Write-Host "No results found at $deviceDir — skipping report" + Write-Host "##vso[task.setvariable variable=HAS_RESULTS]false" + return + } + Write-Host "##vso[task.setvariable variable=HAS_RESULTS]true" + Write-Host "##vso[task.setvariable variable=DEVICE_DIR]$deviceDir" + displayName: 'Resolve device output directory (${{ device }})' + + - script: > + uv run --no-sync python scripts/e2e_eval/generate_report.py + --input-dir $(DEVICE_DIR) + workingDirectory: $(Build.SourcesDirectory) + condition: and(succeeded(), eq(variables['HAS_RESULTS'], 'true')) + displayName: 'Generate evaluation report (${{ device }})' + + - task: PublishPipelineArtifact@1 + inputs: + targetPath: $(DEVICE_DIR) + artifactName: EvalReport_${{ parameters.agentSuffix }}_${{ device }} + condition: and(succeeded(), eq(variables['HAS_RESULTS'], 'true')) + displayName: 'Publish eval results as artifact (${{ device }})' diff --git a/pyproject.toml b/pyproject.toml index 459eb74ec..e114eee5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "diffusers>=0.36", "evaluate>=0.4.6", "fastapi>=0.135.3", + "hf_xet>=1.1.10", "httpx>=0.24.0", "jsonschema>=4.23", "mcp>=0.1.0", diff --git a/scripts/e2e_eval/cache/baseline_cache.json b/scripts/e2e_eval/cache/baseline_cache.json index 806073bb3..7bf0d5614 100644 --- a/scripts/e2e_eval/cache/baseline_cache.json +++ b/scripts/e2e_eval/cache/baseline_cache.json @@ -828,5 +828,85 @@ }, "elapsed": 222.8, "command": "python.exe run_pytorch_baseline.py --model fashion-clip --task zero-shot-image-classification --device cpu --num-samples 1000 --dataset fashion_mnist --split test --columns-mapping {\"input_column\": \"image\"}" + }, + "cross-encoder/nli-deberta-v3-small|zero-shot-classification|fancyzhx/ag_news||test|200": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.635, + "num_samples": 200 + }, + "elapsed": 117.9, + "command": "python.exe run_pytorch_baseline.py --model nli-deberta-v3-small --task zero-shot-classification --device cpu --num-samples 200 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "joeddav/xlm-roberta-large-xnli|zero-shot-classification|fancyzhx/ag_news||test|200": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.395, + "num_samples": 200 + }, + "elapsed": 320.9, + "command": "python.exe run_pytorch_baseline.py --model xlm-roberta-large-xnli --task zero-shot-classification --device cpu --num-samples 200 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli|zero-shot-classification|fancyzhx/ag_news||test|200": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.71, + "num_samples": 200 + }, + "elapsed": 496.6, + "command": "python.exe run_pytorch_baseline.py --model DeBERTa-v3-large-mnli-fever-anli-ling-wanli --task zero-shot-classification --device cpu --num-samples 200 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "MoritzLaurer/deberta-v3-large-zeroshot-v2.0|zero-shot-classification|fancyzhx/ag_news||test|200": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.865, + "num_samples": 200 + }, + "elapsed": 495.6, + "command": "python.exe run_pytorch_baseline.py --model deberta-v3-large-zeroshot-v2.0 --task zero-shot-classification --device cpu --num-samples 200 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli|zero-shot-classification|fancyzhx/ag_news||test|200": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.675, + "num_samples": 200 + }, + "elapsed": 178.6, + "command": "python.exe run_pytorch_baseline.py --model mDeBERTa-v3-base-mnli-xnli --task zero-shot-classification --device cpu --num-samples 200 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7|zero-shot-classification|fancyzhx/ag_news||test|200": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.62, + "num_samples": 200 + }, + "elapsed": 179.7, + "command": "python.exe run_pytorch_baseline.py --model mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 --task zero-shot-classification --device cpu --num-samples 200 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli|zero-shot-classification|fancyzhx/ag_news||test|100": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.77, + "num_samples": 100 + }, + "elapsed": 273.3, + "command": "python.exe run_pytorch_baseline.py --model DeBERTa-v3-large-mnli-fever-anli-ling-wanli --task zero-shot-classification --device cpu --num-samples 100 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" + }, + "MoritzLaurer/deberta-v3-large-zeroshot-v2.0|zero-shot-classification|fancyzhx/ag_news||test|100": { + "status": "PASS", + "metric": { + "metric": "accuracy", + "value": 0.86, + "num_samples": 100 + }, + "elapsed": 284.0, + "command": "python.exe run_pytorch_baseline.py --model deberta-v3-large-zeroshot-v2.0 --task zero-shot-classification --device cpu --num-samples 100 --dataset ag_news --split test --columns-mapping {\"input_column\": \"text\", \"label_column\": \"label\", \"candidate_labels\": Tech\"}" } } diff --git a/scripts/e2e_eval/run_eval.py b/scripts/e2e_eval/run_eval.py index 18d533d0c..810d8d1b8 100644 --- a/scripts/e2e_eval/run_eval.py +++ b/scripts/e2e_eval/run_eval.py @@ -1265,6 +1265,7 @@ def main() -> None: device=args.device, eval_types_run=[args.eval_type], accuracy_result=None, + ep=args.ep, ) write_result_json(timeout_result, result_path) results.append(timeout_result) @@ -1360,7 +1361,9 @@ def main() -> None: interrupted = True break - result = build_eval_result(entry, perf_proc, args.device, eval_types_run, accuracy_result) + result = build_eval_result( + entry, perf_proc, args.device, eval_types_run, accuracy_result, ep=args.ep + ) results.append(result) # Write eval_result.json immediately (crash-safe, facts only) diff --git a/scripts/e2e_eval/testsets/models_with_acc.json b/scripts/e2e_eval/testsets/models_with_acc.json index 2d3067b22..a8a3d0adf 100644 --- a/scripts/e2e_eval/testsets/models_with_acc.json +++ b/scripts/e2e_eval/testsets/models_with_acc.json @@ -1278,13 +1278,14 @@ "group": "Top200", "priority": "P1", "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", + "path": "fancyzhx/ag_news", + "split": "test", + "samples": 200, "metric": "accuracy", "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" + "input_column": "text", + "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech" } } }, @@ -1313,13 +1314,15 @@ "group": "Top200", "priority": "P1", "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", + "path": "fancyzhx/ag_news", + "split": "test", + "samples": 200, "metric": "accuracy", "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" + "input_column": "text", + "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech", + "hypothesis_template": "This text is about {}." } } }, @@ -1341,23 +1344,6 @@ } } }, - { - "hf_id": "lxyuan/distilbert-base-multilingual-cased-sentiments-student", - "task": "zero-shot-classification", - "model_type": "distilbert", - "group": "Top200", - "priority": "P1", - "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", - "metric": "accuracy", - "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" - } - } - }, { "hf_id": "openai/clip-vit-large-patch14-336", "task": "zero-shot-image-classification", @@ -1383,13 +1369,14 @@ "group": "Top200", "priority": "P1", "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", + "path": "fancyzhx/ag_news", + "split": "test", + "samples": 100, "metric": "accuracy", "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" + "input_column": "text", + "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech" } } }, @@ -1418,13 +1405,14 @@ "group": "Top200", "priority": "P1", "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", + "path": "fancyzhx/ag_news", + "split": "test", + "samples": 100, "metric": "accuracy", "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" + "input_column": "text", + "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech" } } }, @@ -1453,13 +1441,14 @@ "group": "Top200", "priority": "P1", "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", + "path": "fancyzhx/ag_news", + "split": "test", + "samples": 200, "metric": "accuracy", "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" + "input_column": "text", + "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech" } } }, @@ -1487,13 +1476,14 @@ "group": "Top200", "priority": "P1", "dataset_config": { - "path": "nyu-mll/multi_nli", - "split": "validation_matched", + "path": "fancyzhx/ag_news", + "split": "test", + "samples": 200, "metric": "accuracy", "columns_mapping": { - "input_column": "premise", - "label_column": "genre", - "candidate_labels": "fiction,government,slate,telephone,travel" + "input_column": "text", + "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech" } } }, @@ -1551,4 +1541,4 @@ } } } -] \ No newline at end of file +] diff --git a/scripts/e2e_eval/utils/reporter.py b/scripts/e2e_eval/utils/reporter.py index a78db8ff0..a97fef69b 100644 --- a/scripts/e2e_eval/utils/reporter.py +++ b/scripts/e2e_eval/utils/reporter.py @@ -37,12 +37,15 @@ def build_eval_result( device: str, eval_types_run: list[str], accuracy_result: dict | None = None, + ep: str | None = None, ) -> dict: """Build a unified eval_result dict (facts only, no derived fields). perf_proc is the raw subprocess result from run_model(), or None when eval_types_run is ["accuracy"] (accuracy-only mode, perf phase skipped). accuracy_result is the accuracy sub-section dict (or None if not run). + ep is the explicit execution provider (e.g., "qnn", "dml"), or None when + not specified (device-to-provider mapping was used). """ perf_section: dict | None = None if perf_proc is not None: @@ -58,7 +61,7 @@ def build_eval_result( "error": perf_proc.get("error_summary", ""), } - return { + result = { "model": entry.hf_id, "task": entry.task, "device": device, @@ -72,6 +75,10 @@ def build_eval_result( "perf": perf_section, "accuracy": accuracy_result, } + # Optional fields: only include when explicitly provided by the user. + if ep is not None: + result["ep"] = ep + return result # --------------------------------------------------------------------------- @@ -340,6 +347,8 @@ def generate_html_report( { "hf_id": hf_id, "task": task, + "device": r.get("device", ""), + "ep": r.get("ep"), "model_type": r.get("model_type", ""), "group": r.get("group", ""), "priority": r.get("priority", ""), @@ -360,6 +369,15 @@ def generate_html_report( "delta_display": ( format_delta(acc) if acc and not acc.get("skipped") else "" ), + "metric": ( + { + "name": (acc.get("winml_metric") or {}).get("metric"), + "baseline": (acc.get("pytorch_baseline_metric") or {}).get("value"), + "winml": (acc.get("winml_metric") or {}).get("value"), + } + if acc and not acc.get("skipped") + else None + ), } ) diff --git a/src/winml/modelkit/analyze/__init__.py b/src/winml/modelkit/analyze/__init__.py index 80cc1835a..b0a9cddd7 100644 --- a/src/winml/modelkit/analyze/__init__.py +++ b/src/winml/modelkit/analyze/__init__.py @@ -24,6 +24,7 @@ from .core.output_aggregator import OutputAggregator from .core.pattern_extractor import PatternExtractor from .core.runtime_checker import RuntimeChecker +from .core.runtime_checker_query import QDQ_SUFFIX from .models import ( Action, ActionItem, @@ -46,6 +47,7 @@ __all__ = [ + "QDQ_SUFFIX", "Action", "ActionItem", "ActionLevel", diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index 4f85b0e7f..b1a01c05a 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -77,6 +77,8 @@ SNAPSHOT_CHANGED_KEY = "__changed__" SNAPSHOT_DELETED_KEY = "__deleted__" +QDQ_SUFFIX = " (QDQ)" + class _PseudoNode: """Lightweight stand-in for onnx.NodeProto used only for logging in _check_negative_rules.""" @@ -2206,7 +2208,7 @@ def run_for_node( def get_pattern_id(is_qdq): return ( - pattern_match.pattern.pattern_id + " (QDQ)" + pattern_match.pattern.pattern_id + QDQ_SUFFIX if is_qdq else pattern_match.pattern.pattern_id ) diff --git a/src/winml/modelkit/commands/analyze.py b/src/winml/modelkit/commands/analyze.py index 976cea399..708dc7a61 100644 --- a/src/winml/modelkit/commands/analyze.py +++ b/src/winml/modelkit/commands/analyze.py @@ -24,6 +24,7 @@ from rich.table import Table from rich.text import Text +from ..analyze import QDQ_SUFFIX from ..utils import cli as cli_utils from ..utils.constants import normalize_ep_name from ..utils.logging import configure_logging @@ -643,7 +644,7 @@ def on_ep_start(ep_name, operator_counts): def on_node_result(pattern_runtime): """Callback invoked per-node during analysis.""" - op = _display_name(pattern_runtime.pattern_id) + op = _display_name(pattern_runtime.pattern_id).removesuffix(QDQ_SUFFIX) level = pattern_runtime.result.classification.value op_counts = instance_counts.setdefault(op, {}) op_counts[level] = op_counts.get(level, 0) + 1 diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index 841080fe1..600403524 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -504,6 +504,9 @@ def _perf_modules( verbose: bool, console: Console, monitor: bool = False, + device: str = "auto", + ep: str | None = None, + precision: str = "auto", ) -> None: """Run per-module build and benchmark for matching submodules. @@ -523,14 +526,21 @@ def _perf_modules( verbose: If True, log exceptions at DEBUG level. console: Rich console for output. monitor: If True, wrap each per-module benchmark with HWMonitor. + device: Target device policy ("auto", "cpu", "gpu", "npu"). + ep: Explicit execution provider (e.g., "qnn", "dml"). Overrides + device-to-provider mapping when set. + precision: Precision mode passed through to the build stage. """ import json as json_mod import tempfile from ..build import build_hf_model from ..config import generate_hf_build_config + from ..sysinfo import resolve_device from .build import _instantiate_parent_model + resolved_device, _ = resolve_device(device=device) + console.print(f"[dim]Generating module configs for {module_class}...[/dim]") try: @@ -538,6 +548,9 @@ def _perf_modules( model_id=hf_model, task=task, module=module_class, + device=resolved_device, + precision=precision, + ep=ep, ) except Exception as e: console.print(f"[red]Error generating module configs: {e}[/red]") @@ -590,12 +603,18 @@ def _perf_modules( config=cfg, output_dir=Path(tmpdir), pytorch_model=submodule, + ep=ep, + device=resolved_device, ) # Benchmark using WinMLSession from ..session import WinMLSession - session = WinMLSession(str(build_result.final_onnx_path)) + session = WinMLSession( + str(build_result.final_onnx_path), + device=resolved_device, + ep=ep, + ) io_cfg = session.io_config inputs = generate_random_inputs(io_cfg, batch_size=batch_size) @@ -1250,6 +1269,9 @@ def perf( verbose=verbose, console=console, monitor=monitor, + device=device.lower(), + ep=ep.lower() if ep else None, + precision=precision.lower(), ) return diff --git a/src/winml/modelkit/eval/evaluate.py b/src/winml/modelkit/eval/evaluate.py index 1fa949cee..b4444bd32 100644 --- a/src/winml/modelkit/eval/evaluate.py +++ b/src/winml/modelkit/eval/evaluate.py @@ -137,6 +137,8 @@ columns_mapping={ "input_column": "text", "label_column": "label", + "candidate_labels": "World,Sports,Business,Sci/Tech", + "hypothesis_template": "This text is about {}.", }, ), "zero-shot-image-classification": DatasetConfig( @@ -182,6 +184,7 @@ def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel: onnx_path=config.model_path, task=config.task, device=config.device, + ep=config.ep, skip_build=True, hf_config=hf_config, ) @@ -192,6 +195,7 @@ def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel: config.model_id, task=config.task, device=config.device, + ep=config.ep, ) diff --git a/src/winml/modelkit/session/session.py b/src/winml/modelkit/session/session.py index 6349cac62..9c12b7c96 100644 --- a/src/winml/modelkit/session/session.py +++ b/src/winml/modelkit/session/session.py @@ -301,7 +301,7 @@ def compile(self) -> None: # Log which providers were selected by ORT (based on policy) actual_providers = session.get_providers() logger.info( - "Session created with policy %s, providers: %s", + "Session created with device %s, providers: %s", target_device, actual_providers, ) @@ -432,27 +432,40 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: Note: Returns a **fresh** SessionOptions when using explicit EP to avoid "already registered" errors from repeated calls. """ - # Explicit EP targeting: create fresh opts to avoid double-registration - # Don't filter by device type — trust the user's --ep choice - # (e.g., QNN reports as NPU in get_ep_devices but can target GPU) + # Explicit EP targeting: create fresh opts to avoid double-registration. + # When device is also specified (non-"auto"), narrow by both EP name + # and device type so e.g. `--ep qnn --device cpu` finds QNN-on-CPU + # instead of the first QNN ep_device (which may report as NPU). if self._ep and self._ep != "cpu": target_name = self._EP_NAME_MAP.get(self._ep) if target_name: - matched = self._find_ep_device(target_name) + matched = self._find_ep_device(ep_name=target_name, device=device) if matched: + from ..utils.constants import DEVICE_TYPE_TO_DEVICE + opts = ort.SessionOptions() opts.add_provider_for_devices([matched], self._provider_options) - logger.info("Explicit EP: %s (%s)", self._ep, target_name) + resolved = DEVICE_TYPE_TO_DEVICE.get( + matched.device.type, str(matched.device.type) + ) + logger.info( + "Explicit EP: %s (%s) device=%s -> %s", + self._ep, + target_name, + device, + resolved, + ) return opts logger.warning( - "EP '%s' (%s) not found in available devices", + "EP '%s' (%s) not found for device '%s'", self._ep, target_name, + device, ) # No explicit EP — discover available EP for this device type if not self._ep and device.lower() != "cpu": - matched = self._find_ep_for_device(device) + matched = self._find_ep_device(device=device) if matched: opts = ort.SessionOptions() opts.add_provider_for_devices([matched], self._provider_options) @@ -465,36 +478,34 @@ def _build_session_options(self, device: str) -> ort.SessionOptions: device.lower(), ort.OrtExecutionProviderDevicePolicy.PREFER_NPU ) opts.set_provider_selection_policy(policy) + logger.info("Using provider selection policy %s for device %s", policy, device) return opts @staticmethod - def _find_ep_device(ep_name: str) -> Any: - """Find the first OrtEpDevice matching the given EP name. - - Args: - ep_name: Full EP name (e.g., "DmlExecutionProvider"). - - Returns: - The matching OrtEpDevice, or None if not found. - """ - for ep_dev in ort.get_ep_devices(): - if ep_dev.ep_name == ep_name: - return ep_dev - return None - - @staticmethod - def _find_ep_for_device(device: str) -> Any: - """Find the first available OrtEpDevice for the given device type. - - Queries ``ort.get_ep_devices()`` and returns the first EP whose - hardware device type matches (e.g., device="gpu" matches GPU EPs). + def _find_ep_device(device: str, ep_name: str | None = None) -> Any: + """Find the first OrtEpDevice matching the given filters. + + Behavior: + - ``ep_name`` set, ``device == "auto"`` → first ep_device + matching ``ep_name`` (or None). + - ``ep_name`` unset, ``device == "auto"`` → ``None`` (no + effective filter — refuse to pick an arbitrary ep_device). + - ``ep_name`` unset, ``device`` is a concrete type → first + ep_device matching that device type (or None). + - Both set → ep_device must satisfy both (or None). Note: Selection order is determined by the ORT EP registry, which is not part of any documented contract. On systems where multiple EPs match the same device type (e.g., QNN and DML both appear as GPU), - the result is registry-order dependent. When a specific EP is - required, use ``self._ep`` to bypass this discovery path entirely. + a device-only query returns the first one in registry order. Pass + ``ep_name`` to disambiguate. + + Args: + device: Device policy ("cpu", "gpu", "npu", "auto"). ``"auto"`` + and unknown strings act as no-op device filters. + ep_name: Full EP name (e.g., "DmlExecutionProvider"), or None + to skip EP-name filtering. Returns: The matching OrtEpDevice, or None if not found. @@ -502,11 +513,17 @@ def _find_ep_for_device(device: str) -> Any: from ..utils.constants import DEVICE_TO_DEVICE_TYPE device_type = DEVICE_TO_DEVICE_TYPE.get(device.upper()) - if device_type is None: + + # No effective filter — refuse to pick an arbitrary ep_device. + if not ep_name and device_type is None: return None + for ep_dev in ort.get_ep_devices(): - if ep_dev.device.type == device_type: - return ep_dev + if ep_name and ep_dev.ep_name != ep_name: + continue + if device_type is not None and ep_dev.device.type != device_type: + continue + return ep_dev return None def _validate_inputs(self, inputs: dict[str, Any]) -> None: diff --git a/src/winml/modelkit/sysinfo/device.py b/src/winml/modelkit/sysinfo/device.py index 35ad1d403..c7fc1a942 100644 --- a/src/winml/modelkit/sysinfo/device.py +++ b/src/winml/modelkit/sysinfo/device.py @@ -188,6 +188,11 @@ def resolve_device(device: str = "auto") -> tuple[str, list[str]]: for dev in available_devices: compatible_eps = _DEVICE_EP_MAP.get(dev, []) if any(ep in available_eps for ep in compatible_eps): + logger.info( + "Auto-selected device '%s' with compatible EPs: %s for auto device", + dev, + sorted(ep for ep in compatible_eps if ep in available_eps), + ) return dev, available_devices # Fallback: CPU is always valid return "cpu", available_devices diff --git a/src/winml/modelkit/telemetry/consent.py b/src/winml/modelkit/telemetry/consent.py index d484b930f..359c61b9e 100644 --- a/src/winml/modelkit/telemetry/consent.py +++ b/src/winml/modelkit/telemetry/consent.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -r"""Consent decision for ModelKit telemetry. +r"""Consent decision for WinML CLI telemetry. A first-run interactive prompt collects user consent (default: accept) and persists it to ``%USERPROFILE%\.winml\config.json``. This module @@ -48,12 +48,12 @@ def _default_config_path() -> Path | None: _CONSENT_VERSION: int = 1 _PROMPT_TEXT = """\ -ModelKit can collect anonymous usage data to help improve the product. +WinML CLI can collect anonymous usage data to help improve the product. What is collected: - Command name, duration, success/failure - Target device/EP (when the command specifies them) - - OS, architecture, ModelKit version + - OS, architecture, WinML CLI version - Unhandled exception types, code locations, and scrubbed error messages (paths trimmed, length capped, PII patterns scrubbed) diff --git a/src/winml/modelkit/telemetry/library/exporter.py b/src/winml/modelkit/telemetry/library/exporter.py index 5630606e3..49e86c076 100644 --- a/src/winml/modelkit/telemetry/library/exporter.py +++ b/src/winml/modelkit/telemetry/library/exporter.py @@ -8,6 +8,7 @@ from __future__ import annotations import logging +import time from datetime import datetime, timezone from typing import TYPE_CHECKING @@ -15,7 +16,7 @@ from opentelemetry.sdk._logs.export import LogRecordExporter, LogRecordExportResult from .._cache import _PersistentCache -from .serialization import _build_envelope, _serialize_batch +from .serialization import _build_envelope, _envelope_ikey, _serialize_batch if TYPE_CHECKING: @@ -26,6 +27,9 @@ _LOGGER = logging.getLogger(__name__) _HTTP_TIMEOUT = 10.0 +# Truncate response body excerpts in DEBUG logs so a backend that decides +# to dump a large diagnostic payload doesn't flood the log. +_RESPONSE_BODY_LOG_LIMIT = 200 class OneCollectorLogExporter(LogRecordExporter): @@ -46,12 +50,26 @@ def __init__( raise ValueError("ikey must be non-empty") if not endpoint: raise ValueError("endpoint must be non-empty") - self._ikey = ikey + # OneCollector requires the envelope's iKey field to be + # ``o:`` (just the prefix portion of the full ikey), + # while the x-apikey HTTP header carries the full ikey. Compute + # the envelope form once and cache it; a malformed ikey raises + # ValueError, which Telemetry._try_init catches to disable + # telemetry rather than crash the CLI. + self._envelope_ikey = _envelope_ikey(ikey) + # Bare tenant_token (no "o:" prefix) is what the ``kill-tokens`` + # response header lists, so we keep it around for membership checks. + self._tenant_token = self._envelope_ikey.removeprefix("o:") self._endpoint = endpoint self._cache = cache if cache is not None else _PersistentCache() # First export() flushes the cache before sending the new batch; # subsequent exports go straight through. self._cache_flushed = False + # OneCollector's ``kill-tokens`` directive: when set, our tenant + # is on the backend's deny list and we must stop sending until + # this epoch second. In-memory only -- process restart re-hits + # the 401 once and re-records the kill before short-circuiting. + self._killed_until: float | None = None # _shutdown is read on the BatchLogRecordProcessor export thread and # written on the shutdown thread; bool assignment is atomic under the # CPython GIL, so no lock is needed. @@ -82,10 +100,22 @@ def export(self, batch: Sequence[ReadableLogRecord]) -> LogRecordExportResult: call. Persists the current batch to the cache on POST failure so the next process can retry — :class:`BatchLogRecordProcessor` only re-queues in memory and loses the queue on process exit. + + While the tenant is under ``kill-tokens``, this is a no-op that + returns ``SUCCESS`` to keep the BatchLogRecordProcessor from + re-queueing events the backend has explicitly told us to stop + sending; envelopes for that window are dropped, not cached. """ if self._shutdown or not batch: return LogRecordExportResult.SUCCESS + if self._is_killed(): + _LOGGER.debug( + "telemetry export skipped: tenant under kill-tokens for %.0fs more", + (self._killed_until or 0) - time.time(), + ) + return LogRecordExportResult.SUCCESS + try: envelopes = [self._to_envelope(ld) for ld in batch] except Exception: @@ -94,20 +124,37 @@ def export(self, batch: Sequence[ReadableLogRecord]) -> LogRecordExportResult: # First-call cache flush: try to send anything left over from a # previous run. Best-effort, single shot — don't loop if the - # backend is still down. + # backend is still down. If the failure is because we just got + # killed, drop the cached batch instead of looping it forever. if not self._cache_flushed: self._cache_flushed = True cached = self._cache.drain() - if cached and not self._post_envelopes(cached): + if cached and not self._post_envelopes(cached) and not self._is_killed(): self._cache.append(cached) + # If the cache-flush POST just killed us, skip the new-batch POST + # too -- another network round-trip would just re-confirm the + # kill. Drop the envelopes (don't cache, same rationale as the + # top-of-export guard) and return SUCCESS so the + # BatchLogRecordProcessor doesn't re-queue them. + if self._is_killed(): + return LogRecordExportResult.SUCCESS + if not self._post_envelopes(envelopes): - self._cache.append(envelopes) + if not self._is_killed(): + self._cache.append(envelopes) return LogRecordExportResult.FAILURE return LogRecordExportResult.SUCCESS def _post_envelopes(self, envelopes: list[dict]) -> bool: - """POST a list of envelopes; return True on 2xx, False otherwise.""" + """POST a list of envelopes; return True on 2xx, False otherwise. + + On non-2xx, parses ``kill-tokens`` / ``kill-duration`` to honor the + backend's tenant-level backoff, and emits a DEBUG log line that + captures the ``Collector-Error`` header and a body excerpt — the + two pieces of info the OneCollector backend uses to communicate + the actual rejection reason. + """ if not envelopes: return True try: @@ -126,9 +173,43 @@ def _post_envelopes(self, envelopes: list[dict]) -> bool: return False if 200 <= response.status_code < 300: return True - _LOGGER.debug("telemetry backend returned %s", response.status_code) + self._record_kill_if_present(response) + _LOGGER.debug( + "telemetry backend returned %s: error=%r body=%s", + response.status_code, + response.headers.get("Collector-Error"), + (response.text or "")[:_RESPONSE_BODY_LOG_LIMIT].replace("\n", " "), + ) return False + def _is_killed(self) -> bool: + """True iff our tenant is currently under a ``kill-tokens`` window.""" + return self._killed_until is not None and time.time() < self._killed_until + + def _record_kill_if_present(self, response: requests.Response) -> None: + """Honor an inbound ``kill-tokens`` directive that names our tenant. + + No-op for any other 4xx/5xx response. + """ + kill_header = response.headers.get("kill-tokens") + duration_header = response.headers.get("kill-duration") + if not kill_header or not duration_header: + return + try: + duration_s = int(duration_header) + except ValueError: + return + if duration_s <= 0: + return + if self._tenant_token not in _parse_kill_tokens(kill_header): + return + self._killed_until = time.time() + duration_s + _LOGGER.debug( + "telemetry tenant under kill-tokens for %ss (until epoch %.0f)", + duration_s, + self._killed_until, + ) + def force_flush(self, timeout_millis: int = 30_000) -> bool: """No-op: all exports are synchronous.""" return True @@ -150,13 +231,37 @@ def _to_envelope(self, ld: ReadableLogRecord) -> dict: ext = _resource_to_ext(ld.resource) return _build_envelope( name=str(record.body), - ikey=self._ikey, + ikey=self._envelope_ikey, timestamp=timestamp, data=data, ext=ext, ) +def _parse_kill_tokens(header_value: str) -> set[str]: + """Parse the OneCollector ``kill-tokens`` header into a set of tenant_token strings. + + The header is a comma-separated list. Each entry is in the form + ``o:`` optionally followed by ``:`` (e.g. + ``o:abc:all`` or ``o:abc:event_name``). We treat any entry naming + a tenant as a full kill for that tenant — per-event kills aren't + something we exploit today. + """ + if not header_value: + return set() + tokens: set[str] = set() + for raw in header_value.split(","): + entry = raw.strip() + if not entry.startswith("o:"): + continue + rest = entry[2:] + # Strip optional ":" suffix. + tenant = rest.split(":", 1)[0] + if tenant: + tokens.add(tenant) + return tokens + + def _ns_to_datetime(ts_ns: int) -> datetime: return datetime.fromtimestamp(ts_ns / 1_000_000_000, tz=timezone.utc) diff --git a/src/winml/modelkit/telemetry/library/serialization.py b/src/winml/modelkit/telemetry/library/serialization.py index 3dcc966e2..3fdcb91d0 100644 --- a/src/winml/modelkit/telemetry/library/serialization.py +++ b/src/winml/modelkit/telemetry/library/serialization.py @@ -15,6 +15,29 @@ from datetime import datetime +def _envelope_ikey(full_ikey: str) -> str: + """Derive the envelope ``iKey`` value from a OneCollector instrumentation key. + + OneCollector expects two distinct iKey forms on the wire: + + - ``x-apikey`` HTTP header: the full instrumentation key + (``--``). + - Envelope ``iKey`` field: ``o:``, where ``tenant_token`` + is the portion of the ikey before the first ``-``. + + Embedding the full ikey in the envelope's ``iKey`` field is what causes + the backend to reject the request with + ``Collector-Error: Invalid Tenant Token``. + """ + dash = full_ikey.find("-") + if dash <= 0: + raise ValueError( + "OneCollector instrumentation key must contain a non-empty " + "tenant_token portion before the first '-'" + ) + return f"o:{full_ikey[:dash]}" + + def _build_envelope( name: str, ikey: str, @@ -28,7 +51,9 @@ def _build_envelope( ver: schema version, always "4.0" name: event name (e.g. "ModelKitAction") time: ISO8601 UTC, millisecond precision, trailing Z - iKey: OneCollector InstrumentationKey (in "o:" form) + iKey: envelope iKey value -- caller is responsible for supplying + it in the ``o:`` form (use + :func:`_envelope_ikey` to derive it from the full ikey) data: event-specific flat payload ext: common context slots (os, app, device) """ diff --git a/tests/unit/analyze/test_static_analyzer_cli.py b/tests/unit/analyze/test_static_analyzer_cli.py index aa1736630..b9f450906 100644 --- a/tests/unit/analyze/test_static_analyzer_cli.py +++ b/tests/unit/analyze/test_static_analyzer_cli.py @@ -767,3 +767,102 @@ def test_ep_without_device_skips_validation( # Should proceed to analysis (not fail on validation) assert result.exit_code == 0 assert mock_instance.analyze.called + + +class TestQDQNodeDisplayMapping: + """Tests for QDQ node result mapping in the op progress table. + + QDQ-wrapped ops (e.g. Conv surrounded by DQ/Q nodes) produce pattern IDs + like 'OP/ai.onnx/Conv (QDQ)'. The live table keys come from + metadata.operator_counts which uses bare op types ('Conv'). The + on_node_result callback must strip the ' (QDQ)' suffix so results are + attributed to the right row instead of being silently dropped. + """ + + def test_qdq_pattern_id_maps_to_base_op_for_table_key(self) -> None: + """_display_name + removesuffix(QDQ_SUFFIX) maps QDQ pattern IDs to base + op types so instance_counts keys match all_op_counts keys.""" + from winml.modelkit.analyze import QDQ_SUFFIX + from winml.modelkit.commands.analyze import _display_name + + assert _display_name("OP/ai.onnx/Conv (QDQ)").removesuffix(QDQ_SUFFIX) == "Conv" + assert _display_name("OP/ai.onnx/Add (QDQ)").removesuffix(QDQ_SUFFIX) == "Add" + assert _display_name("OP/ai.onnx/Pad (QDQ)").removesuffix(QDQ_SUFFIX) == "Pad" + assert ( + _display_name("OP/ai.onnx/DequantizeLinear").removesuffix(QDQ_SUFFIX) + == "DequantizeLinear" + ) + assert _display_name("OP/ai.onnx/Reshape").removesuffix(QDQ_SUFFIX) == "Reshape" + + @patch("winml.modelkit.commands.analyze.Live") + @patch("winml.modelkit.commands.analyze.Console") + @patch("winml.modelkit.analyze.ONNXStaticAnalyzer") + def test_qdq_wrapped_ops_tracked_under_base_type( + self, + mock_analyzer_class: MagicMock, + mock_console_class: MagicMock, + mock_live_class: MagicMock, + runner: CliRunner, + tmp_path: Path, + mock_analyzer_result: Mock, + ) -> None: + """on_node_result must map 'Conv (QDQ)' → 'Conv' so the table row + shows support counts instead of '...'.""" + model_file = tmp_path / "test.onnx" + model_file.write_bytes(b"dummy") + + # Accumulate per-EP instance counts written by on_node_result so we + # can assert that QDQ-wrapped ops land under the base op type key. + captured_ep_counts: dict = {} + + mock_console = MagicMock() + mock_console_class.return_value = mock_console + + ep_support_mock = Mock() + ep_support_mock.ep_type = "QNNExecutionProvider" + ep_support_mock.classification = {} + ep_support_mock.information = [] + mock_analyzer_result.output.results = [ep_support_mock] + + def invoke_callbacks(**kwargs): + on_ep_start = kwargs.get("on_ep_start") + on_node_result = kwargs.get("on_node_result") + if on_ep_start: + on_ep_start("QNNExecutionProvider", {"Conv": 2, "DequantizeLinear": 4}) + if on_node_result: + for _ in range(2): + pr = Mock() + pr.pattern_id = "OP/ai.onnx/Conv (QDQ)" + pr.result.classification.value = "supported" + on_node_result(pr) + for _ in range(4): + pr = Mock() + pr.pattern_id = "OP/ai.onnx/DequantizeLinear" + pr.result.classification.value = "supported" + on_node_result(pr) + # Capture the instance_counts via _render_analysis_summary call args + return mock_analyzer_result + + mock_instance = Mock() + mock_instance.analyze.side_effect = invoke_callbacks + mock_analyzer_class.return_value = mock_instance + + # Intercept _render_analysis_summary to capture ep_instance_counts + with patch("winml.modelkit.commands.analyze._render_analysis_summary") as mock_summary: + result = runner.invoke( + analyze, + ["--model", str(model_file), "--ep", "QNNExecutionProvider", "--device", "NPU"], + ) + if mock_summary.called: + captured_ep_counts = mock_summary.call_args[0][2] # 3rd positional arg + + assert result.exit_code == 0 + # After the fix, 'Conv (QDQ)' is keyed as 'Conv' in instance_counts. + # ep_instance_counts['QNNExecutionProvider']['Conv'] must be populated + # (not 'Conv (QDQ)') so the Conv row shows counts instead of '...'. + assert mock_summary.called + qnn_counts = captured_ep_counts.get("QNNExecutionProvider", {}) + assert "Conv" in qnn_counts, "Conv (QDQ) results must be stored under 'Conv'" + assert "Conv (QDQ)" not in qnn_counts, "QDQ suffix must be stripped" + assert qnn_counts["Conv"] == {"supported": 2} + assert qnn_counts["DequantizeLinear"] == {"supported": 4} diff --git a/tests/unit/commands/test_perf_module.py b/tests/unit/commands/test_perf_module.py index 46ae2c2c3..b0fee4bc3 100644 --- a/tests/unit/commands/test_perf_module.py +++ b/tests/unit/commands/test_perf_module.py @@ -6,12 +6,19 @@ from __future__ import annotations +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + from click.testing import CliRunner from winml.modelkit.cli import main from winml.modelkit.commands.perf import generate_output_path +if TYPE_CHECKING: + from pathlib import Path + + class TestPerfModuleFlag: """Tests for --module flag on winml perf.""" @@ -44,3 +51,85 @@ def test_module_default_output_includes_class_name(self) -> None: module_path = Path(f"{slug}_{module_class}_perf.json") assert module_class in str(module_path) assert str(module_path) != str(path) + + +class TestPerfModuleParameterForwarding: + """Verify --device/--ep/--precision flow from CLI through _perf_modules + into generate_hf_build_config, build_hf_model, and WinMLSession. + + Regression guard: these kwargs were silently dropped before. + """ + + def test_device_and_ep_forwarded_through_module_path(self, tmp_path: Path) -> None: + # Fake module config -- only the attributes _perf_modules touches + fake_cfg = MagicMock() + fake_cfg.loader.model_type = "bert" + fake_cfg.loader.module_path = "encoder.layer.0" + + fake_build_result = MagicMock() + fake_build_result.final_onnx_path = tmp_path / "model.onnx" + + # Make WinMLSession.perf() raise so the benchmark loop is short-circuited + # via the existing try/except in _perf_modules. We still capture the + # constructor kwargs, which is what we care about. + fake_session = MagicMock() + fake_session.perf.side_effect = RuntimeError("test-skip-benchmark") + + with ( + patch( + "winml.modelkit.sysinfo.resolve_device", + return_value=("npu", "qnn"), + ), + patch( + "winml.modelkit.config.generate_hf_build_config", + return_value=[fake_cfg], + ) as mock_gen, + patch( + "winml.modelkit.commands.build._instantiate_parent_model", + return_value=MagicMock(), + ), + patch( + "winml.modelkit.build.build_hf_model", + return_value=fake_build_result, + ) as mock_build, + patch( + "winml.modelkit.session.WinMLSession", + return_value=fake_session, + ) as mock_session_cls, + ): + runner = CliRunner() + result = runner.invoke( + main, + [ + "perf", + "-m", + "fake/model", + "--module", + "BertLayer", + "--device", + "npu", + "--ep", + "qnn", + "--iterations", + "1", + "--warmup", + "0", + "-o", + str(tmp_path / "out.json"), + ], + ) + + assert result.exit_code == 0, result.output + + gen_kwargs = mock_gen.call_args.kwargs + assert gen_kwargs["device"] == "npu" + assert gen_kwargs["ep"] == "qnn" + assert gen_kwargs["precision"] == "auto" + + build_kwargs = mock_build.call_args.kwargs + assert build_kwargs["ep"] == "qnn" + assert build_kwargs["device"] == "npu" + + session_kwargs = mock_session_cls.call_args.kwargs + assert session_kwargs["device"] == "npu" + assert session_kwargs["ep"] == "qnn" diff --git a/tests/unit/eval/test_eval.py b/tests/unit/eval/test_eval.py index b70e6785c..e3ebe3377 100644 --- a/tests/unit/eval/test_eval.py +++ b/tests/unit/eval/test_eval.py @@ -693,6 +693,200 @@ def test_cli_evaluate_exception_shown_to_user(self): assert result.exit_code != 0 assert "broken model" in result.output + def test_cli_ep_passed_through(self): + """`--ep ` must propagate to WinMLEvaluationConfig.ep.""" + from winml.modelkit.commands.eval import eval as eval_cmd + + runner = CliRunner() + with ( + patch("winml.modelkit.sysinfo.resolve_device", return_value=("npu", ["npu", "cpu"])), + patch("winml.modelkit.eval.evaluate") as mock_evaluate, + ): + mock_evaluate.return_value = EvalResult( + config=WinMLEvaluationConfig(), + metrics={}, + ) + result = runner.invoke( + eval_cmd, + ["-m", "test/model", "--dataset", "imagenet-1k", "--ep", "qnn"], + catch_exceptions=False, + ) + + assert result.exit_code == 0, result.output + config = mock_evaluate.call_args[0][0] + assert config.ep == "qnn" + + def test_cli_ep_invalid_value_rejected(self): + """Unknown --ep value must be rejected by Click Choice validation.""" + from winml.modelkit.commands.eval import eval as eval_cmd + + runner = CliRunner() + result = runner.invoke( + eval_cmd, + ["-m", "test/model", "--dataset", "imagenet-1k", "--ep", "bogus_ep"], + ) + assert result.exit_code != 0 + assert "bogus_ep" in result.output.lower() or "invalid" in result.output.lower() + + def test_cli_ep_from_build_config(self, tmp_path): + """When --ep is omitted, ep is read from build_cfg.compile.ep_config.provider.""" + from winml.modelkit.commands.eval import eval as eval_cmd + + config_file = tmp_path / "build.yaml" + config_file.touch() + + fake_build_cfg = MagicMock() + fake_build_cfg.loader = None + fake_build_cfg.compile.ep_config.provider = "dml" + fake_build_cfg.quant = None + + runner = CliRunner() + with ( + patch("winml.modelkit.sysinfo.resolve_device", return_value=("gpu", ["gpu", "cpu"])), + patch( + "winml.modelkit.utils.cli.load_build_config", + return_value=fake_build_cfg, + ), + patch("winml.modelkit.eval.evaluate") as mock_evaluate, + ): + mock_evaluate.return_value = EvalResult( + config=WinMLEvaluationConfig(), + metrics={}, + ) + result = runner.invoke( + eval_cmd, + [ + "-m", + "test/model", + "--dataset", + "imagenet-1k", + "--config", + str(config_file), + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0, result.output + config = mock_evaluate.call_args[0][0] + assert config.ep == "dml" + + def test_cli_ep_overrides_build_config(self, tmp_path): + """Explicit --ep on the CLI must take precedence over build config value.""" + from winml.modelkit.commands.eval import eval as eval_cmd + + config_file = tmp_path / "build.yaml" + config_file.touch() + + fake_build_cfg = MagicMock() + fake_build_cfg.loader = None + fake_build_cfg.compile.ep_config.provider = "dml" + fake_build_cfg.quant = None + + runner = CliRunner() + with ( + patch("winml.modelkit.sysinfo.resolve_device", return_value=("npu", ["npu", "cpu"])), + patch( + "winml.modelkit.utils.cli.load_build_config", + return_value=fake_build_cfg, + ), + patch("winml.modelkit.eval.evaluate") as mock_evaluate, + ): + mock_evaluate.return_value = EvalResult( + config=WinMLEvaluationConfig(), + metrics={}, + ) + result = runner.invoke( + eval_cmd, + [ + "-m", + "test/model", + "--dataset", + "imagenet-1k", + "--config", + str(config_file), + "--ep", + "qnn", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0, result.output + config = mock_evaluate.call_args[0][0] + assert config.ep == "qnn" + + +class TestBuildEvalResultEpField: + """Tests for build_eval_result handling of the optional `ep` field.""" + + @staticmethod + def _load_reporter(): + """Load scripts/e2e_eval/utils/reporter.py via importlib (not on sys.path).""" + import importlib.util + import sys + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[3] + utils_dir = repo_root / "scripts" / "e2e_eval" / "utils" + + # Pre-load the sibling module reporter.py imports relatively. + if "_e2e_classifier" not in sys.modules: + spec_c = importlib.util.spec_from_file_location( + "_e2e_classifier", utils_dir / "classifier.py" + ) + mod_c = importlib.util.module_from_spec(spec_c) + sys.modules["_e2e_classifier"] = mod_c + spec_c.loader.exec_module(mod_c) + + # Stub the relative import target so reporter.py's `from .classifier ...` works. + pkg_name = "_e2e_reporter_pkg" + if pkg_name not in sys.modules: + pkg = type(sys)(pkg_name) + pkg.__path__ = [str(utils_dir)] + sys.modules[pkg_name] = pkg + sys.modules[f"{pkg_name}.classifier"] = sys.modules["_e2e_classifier"] + + spec = importlib.util.spec_from_file_location( + f"{pkg_name}.reporter", utils_dir / "reporter.py" + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + def _make_entry(self): + entry = MagicMock() + entry.hf_id = "test/model" + entry.task = "image-classification" + entry.model_type = "resnet" + entry.group = "Test" + entry.priority = "P0" + return entry + + def test_ep_omitted_when_none(self): + reporter = self._load_reporter() + + result = reporter.build_eval_result( + entry=self._make_entry(), + perf_proc=None, + device="cpu", + eval_types_run=["accuracy"], + accuracy_result=None, + ep=None, + ) + assert "ep" not in result + + def test_ep_present_when_provided(self): + reporter = self._load_reporter() + + result = reporter.build_eval_result( + entry=self._make_entry(), + perf_proc=None, + device="npu", + eval_types_run=["accuracy"], + accuracy_result=None, + ep="qnn", + ) + assert result["ep"] == "qnn" + class TestDefaultDatasetImmutability: """Tests that module-level _DEFAULT_DATASETS are not corrupted.""" @@ -836,6 +1030,7 @@ def test_load_model_from_pretrained(self): "test/model", task="image-classification", device="cpu", + ep=None, ) assert result is mock_model diff --git a/tests/unit/session/test_winml_session.py b/tests/unit/session/test_winml_session.py index a0cc54705..c82e0ddd6 100644 --- a/tests/unit/session/test_winml_session.py +++ b/tests/unit/session/test_winml_session.py @@ -680,3 +680,96 @@ def test_perf_stats_accessible_after_context( assert stats.count == 3 # 5 - 2 warmup assert stats.mean_ms > 0 assert stats.p99_ms > 0 + + +class TestFindEpDevice: + """Tests for WinMLSession._find_ep_device combined ep_name + device filter. + + Regression guard: when both filters are set, both must match (AND logic). + Previously these were two separate methods and the explicit-EP path + ignored device, so `--ep qnn --device cpu` could return QNN-on-NPU. + """ + + @staticmethod + def _ep_dev(name: str, dev_type) -> object: + """Build a fake OrtEpDevice-like object.""" + from types import SimpleNamespace + + return SimpleNamespace(ep_name=name, device=SimpleNamespace(type=dev_type)) + + def _patch_devices(self, devices: list) -> object: + """Return a contextmanager that patches ort.get_ep_devices.""" + from unittest.mock import patch + + return patch( + "winml.modelkit.session.session.ort.get_ep_devices", + return_value=devices, + ) + + def test_ep_name_only(self) -> None: + """ep_name filter with device='auto' returns first matching name.""" + import onnxruntime as ort + + devs = [ + self._ep_dev("DmlExecutionProvider", ort.OrtHardwareDeviceType.GPU), + self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU), + ] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(device="auto", ep_name="QNNExecutionProvider") + assert match is not None + assert match.ep_name == "QNNExecutionProvider" + + def test_device_only(self) -> None: + """device filter alone returns first matching device type.""" + import onnxruntime as ort + + devs = [ + self._ep_dev("CPUExecutionProvider", ort.OrtHardwareDeviceType.CPU), + self._ep_dev("DmlExecutionProvider", ort.OrtHardwareDeviceType.GPU), + ] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(device="gpu") + assert match is not None + assert match.ep_name == "DmlExecutionProvider" + + def test_ep_name_and_device_both_required(self) -> None: + """When both filters are set, both must match (AND logic).""" + import onnxruntime as ort + + devs = [ + # QNN-on-NPU comes first; user asks for QNN-on-CPU + self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU), + self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.CPU), + ] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(ep_name="QNNExecutionProvider", device="cpu") + assert match is not None + assert match.device.type == ort.OrtHardwareDeviceType.CPU + + def test_no_match_returns_none(self) -> None: + """Non-matching combination returns None even if individual filters would match.""" + import onnxruntime as ort + + devs = [self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU)] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(ep_name="QNNExecutionProvider", device="cpu") + assert match is None + + def test_auto_device_acts_as_no_filter(self) -> None: + """device='auto' falls back to ep_name-only matching.""" + import onnxruntime as ort + + devs = [self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU)] + with self._patch_devices(devs): + match = WinMLSession._find_ep_device(ep_name="QNNExecutionProvider", device="auto") + assert match is not None + assert match.ep_name == "QNNExecutionProvider" + + def test_ep_none_and_device_auto_returns_none(self) -> None: + """ep_name=None and device='auto' → None (no effective filter).""" + import onnxruntime as ort + + devs = [self._ep_dev("QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU)] + with self._patch_devices(devs): + assert WinMLSession._find_ep_device(device="auto") is None + assert WinMLSession._find_ep_device(device="auto", ep_name=None) is None diff --git a/tests/unit/telemetry/conftest.py b/tests/unit/telemetry/conftest.py index d73e8f4d3..d4d9e7b1c 100644 --- a/tests/unit/telemetry/conftest.py +++ b/tests/unit/telemetry/conftest.py @@ -46,7 +46,9 @@ def enabled_telemetry(monkeypatch, isolated_config, clean_env): ready instance should use :func:`running_telemetry`, or call ``Telemetry.get_or_init()`` from inside the test body. """ - monkeypatch.setattr("winml.modelkit.telemetry.constants.INSTRUMENTATION_KEY", "o:test-key") + monkeypatch.setattr( + "winml.modelkit.telemetry.constants.INSTRUMENTATION_KEY", "test-tenant-1234" + ) consent_mod._write_stored_consent("enabled") monkeypatch.setattr("sys.stdin.isatty", lambda: True) diff --git a/tests/unit/telemetry/library/test_exporter.py b/tests/unit/telemetry/library/test_exporter.py index bd1419c17..2765d7f3b 100644 --- a/tests/unit/telemetry/library/test_exporter.py +++ b/tests/unit/telemetry/library/test_exporter.py @@ -9,6 +9,8 @@ # LogRecordExportResult (the old names are deprecated aliases). Tests use the # current names throughout. +import logging +import time from unittest.mock import MagicMock, patch import pytest @@ -19,6 +21,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.util.instrumentation import InstrumentationScope +from winml.modelkit.telemetry._cache import _PersistentCache from winml.modelkit.telemetry.library import OneCollectorLogExporter @@ -36,18 +39,29 @@ def _make_log_data(body: str, attrs: dict, resource: Resource | None = None) -> @pytest.fixture -def exporter(): - return OneCollectorLogExporter( - ikey="o:abc", +def exporter(tmp_path): + # Use a tmp_path-scoped cache so the test never reads or writes the + # real user-scoped persistent cache (which can be polluted by prior + # modelkit runs on the dev machine and would inject extra POSTs into + # the test). + # Yield + shutdown so the underlying ``requests.Session`` (and its + # connection pool) is closed at teardown rather than leaked across + # tests. + cache = _PersistentCache(path=tmp_path / "modelkit.cache") + exp = OneCollectorLogExporter( + ikey="abc-def", endpoint="https://example.invalid/OneCollector/1.0/", + cache=cache, ) + yield exp + exp.shutdown() @pytest.mark.parametrize( "ikey,endpoint", [ ("", "https://example.invalid/"), - ("o:abc", ""), + ("abc-def", ""), ("", ""), ], ) @@ -59,6 +73,20 @@ def test_constructor_rejects_empty_ikey_or_endpoint(ikey, endpoint): OneCollectorLogExporter(ikey=ikey, endpoint=endpoint) +@pytest.mark.parametrize( + "ikey", + [ + "noseparator", # no dash at all + "-leading-dash", # empty tenant_token portion + ], +) +def test_constructor_rejects_malformed_ikey(ikey): + """The full ikey must contain a non-empty tenant_token portion before + the first '-', otherwise the envelope iKey can't be derived.""" + with pytest.raises(ValueError): + OneCollectorLogExporter(ikey=ikey, endpoint="https://example.invalid/OneCollector/1.0/") + + def test_export_success_returns_success(exporter): ld = _make_log_data( body="ModelKitAction", @@ -72,16 +100,22 @@ def test_export_success_returns_success(exporter): mock_post.assert_called_once() # OneCollector /OneCollector/1.0/ ingest only accepts x-json-stream # (NDJSON) or bond-compact-binary; application/json is rejected with - # HTTP 415. Auth is via the x-apikey header, not the envelope iKey. + # HTTP 415. Auth is via the x-apikey header (full ikey), and the + # envelope iKey field carries the "o:" form -- the two + # values are intentionally different on the wire. headers = exporter._session.headers assert headers["Content-Type"] == "application/x-json-stream; charset=utf-8" - assert headers["x-apikey"] == "o:abc" + assert headers["x-apikey"] == "abc-def" # Body is NDJSON (one envelope per line, no enclosing array). _, kwargs = mock_post.call_args body = kwargs["data"] assert not body.startswith(b"[") assert b'"ModelKitAction"' in body + # Regression guard: envelope iKey is "o:", NOT the full + # ikey. Sending the full ikey here triggers + # ``Collector-Error: Invalid Tenant Token`` from OneCollector. assert b'"iKey":"o:abc"' in body + assert b'"iKey":"abc-def"' not in body def test_export_connection_error_returns_failure(exporter): @@ -220,3 +254,212 @@ def test_resource_to_ext_unknown_attributes_are_ignored(): # Attributes outside the mapping table do not leak into ext. ext = _resource_to_ext(Resource.create({"os.name": "Windows", "custom.attr": "x"})) assert ext == {"os": {"name": "Windows"}} + + +# --- _parse_kill_tokens direct unit tests --- + + +from winml.modelkit.telemetry.library.exporter import _parse_kill_tokens # noqa: E402 + + +@pytest.mark.parametrize( + "header,expected", + [ + ("", set()), + ("o:abc:all", {"abc"}), + ("o:abc", {"abc"}), # no reason suffix + ("o:abc:all,o:def:event_x", {"abc", "def"}), + (" o:abc:all , o:def ", {"abc", "def"}), # whitespace tolerant + ("garbage,o:abc:all", {"abc"}), # entries without "o:" prefix are skipped + ("o::all", set()), # empty tenant_token portion is rejected + ], +) +def test_parse_kill_tokens(header, expected): + assert _parse_kill_tokens(header) == expected + + +# --- kill-tokens / DEBUG-log behavior on failed POSTs --- + + +def _killed_response(tenant: str = "abc", duration: int = 86_400): + """Build a 401 response that mimics OneCollector's tenant-killed reply.""" + resp = MagicMock(status_code=401) + resp.headers = { + "Collector-Error": "Invalid Tenant Token.", + "kill-tokens": f"o:{tenant}:all", + "kill-duration": str(duration), + } + resp.text = '{"acc":0,"efi":{"InvalidTenantToken":"all"}}' + return resp + + +def test_kill_tokens_recorded_on_failure(exporter): + """On a 401 with kill-tokens for our tenant, exporter records the + kill window and ``_is_killed()`` returns True until it expires.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + assert exporter._is_killed() is False + + with patch.object(exporter._session, "post", return_value=_killed_response()): + exporter.export([ld]) + + assert exporter._is_killed() is True + assert exporter._killed_until is not None + assert exporter._killed_until > time.time() + + +def test_kill_tokens_for_other_tenant_is_ignored(exporter): + """If kill-tokens names a different tenant, our exporter is unaffected.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + other = _killed_response(tenant="not-our-tenant") + with patch.object(exporter._session, "post", return_value=other): + exporter.export([ld]) + assert exporter._is_killed() is False + + +def test_export_skipped_during_kill_window(exporter): + """While killed, export() is a no-op that returns SUCCESS without + even touching the HTTP session.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + + # First export: triggers the kill window. + with patch.object(exporter._session, "post", return_value=_killed_response()) as p1: + exporter.export([ld]) + assert p1.call_count == 1 + assert exporter._is_killed() + + # Second export: must not POST at all. + with patch.object(exporter._session, "post") as p2: + result = exporter.export([ld]) + assert result == LogRecordExportResult.SUCCESS + p2.assert_not_called() + + +def test_kill_drops_envelopes_instead_of_caching(exporter, tmp_path): + """A failed POST that triggered a kill must NOT enqueue the batch in + the persistent cache — caching it just guarantees forever-failure on + future startups within the kill window.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + cache_path = tmp_path / "modelkit.cache" + with patch.object(exporter._session, "post", return_value=_killed_response()): + exporter.export([ld]) + assert exporter._is_killed() + assert not cache_path.exists(), "kill-induced failures must not persist to cache" + + +def test_cache_flush_kill_skips_new_batch_post(tmp_path): + """If the cache-flush POST triggers a kill, the new-batch POST in + the same export() call must be skipped — otherwise we waste a + network round-trip just to re-confirm the kill, and the new + envelopes would either be dropped or wrongly re-cached.""" + cache_path = tmp_path / "modelkit.cache" + cache = _PersistentCache(path=cache_path) + cache.append([{"name": "ModelKitHeartbeat", "iKey": "o:abc"}]) + + exp = OneCollectorLogExporter( + ikey="abc-def", + endpoint="https://example.invalid/OneCollector/1.0/", + cache=cache, + ) + try: + ld = _make_log_data("ModelKitHeartbeat", {}) + with patch.object(exp._session, "post", return_value=_killed_response()) as p: + result = exp.export([ld]) + + # Exactly one POST: the cache flush. The new-batch POST is + # short-circuited by the mid-export kill check. + assert p.call_count == 1 + assert exp._is_killed() + # Returned SUCCESS so BatchLogRecordProcessor doesn't re-queue. + assert result == LogRecordExportResult.SUCCESS + finally: + exp.shutdown() + + +def test_kill_window_expiry_re_enables_post(exporter): + """Past the kill window, export() resumes normal POST behavior.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + + # Use a short window so we can fast-forward past it via monkeypatching. + short_kill = _killed_response(duration=10) + with patch.object(exporter._session, "post", return_value=short_kill): + exporter.export([ld]) + assert exporter._is_killed() + + # Fast-forward: pretend we're 11s past the kill window. + fake_now = (exporter._killed_until or 0) + 1 + with ( + patch("winml.modelkit.telemetry.library.exporter.time.time", return_value=fake_now), + patch.object(exporter._session, "post", return_value=MagicMock(status_code=200)) as p, + ): + result = exporter.export([ld]) + + assert result == LogRecordExportResult.SUCCESS + p.assert_called_once() + + +@pytest.mark.parametrize( + "kill_duration_value", + [ + None, # header absent entirely + "", # header present but empty + "0", # non-positive + "abc", # non-numeric + ], +) +def test_kill_tokens_with_unusable_duration_is_ignored(exporter, kill_duration_value): + """``kill-tokens`` is meaningless without a positive integer + ``kill-duration``. Any of: absent, empty, non-positive, or non-numeric + must leave the exporter unkilled.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + resp = MagicMock(status_code=401) + headers = {"kill-tokens": "o:abc:all"} + if kill_duration_value is not None: + headers["kill-duration"] = kill_duration_value + resp.headers = headers + resp.text = "" + with patch.object(exporter._session, "post", return_value=resp): + exporter.export([ld]) + assert exporter._is_killed() is False + + +def test_post_failure_logs_collector_error_and_body_excerpt(exporter, caplog): + """The DEBUG log on non-2xx must capture both the ``Collector-Error`` + header and a body excerpt — the two pieces OneCollector uses to + communicate the actual rejection reason. Without these in the log, + diagnosing tenant/format misconfigurations requires a live probe.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + resp = MagicMock(status_code=401) + resp.headers = {"Collector-Error": "Invalid Tenant Token."} + resp.text = '{"acc":0,"rej":1,"efi":{"InvalidTenantToken":[0]}}' + + caplog.set_level(logging.DEBUG, logger="winml.modelkit.telemetry.library.exporter") + with patch.object(exporter._session, "post", return_value=resp): + exporter.export([ld]) + + backend_logs = [ + r.getMessage() for r in caplog.records if "telemetry backend returned" in r.getMessage() + ] + assert backend_logs, "expected a DEBUG log line for the 401" + msg = backend_logs[0] + assert "401" in msg + assert "Invalid Tenant Token." in msg + assert "InvalidTenantToken" in msg + + +def test_post_failure_log_truncates_long_body(exporter, caplog): + """A backend that returns a huge body shouldn't flood the DEBUG log.""" + ld = _make_log_data("ModelKitHeartbeat", {}) + resp = MagicMock(status_code=500) + resp.headers = {} + resp.text = "x" * 10_000 + + caplog.set_level(logging.DEBUG, logger="winml.modelkit.telemetry.library.exporter") + with patch.object(exporter._session, "post", return_value=resp): + exporter.export([ld]) + + msg = next( + r.getMessage() for r in caplog.records if "telemetry backend returned" in r.getMessage() + ) + # The truncation cap is _RESPONSE_BODY_LOG_LIMIT (200) bytes. + assert "x" * 200 in msg + assert "x" * 1_000 not in msg diff --git a/tests/unit/telemetry/library/test_factory.py b/tests/unit/telemetry/library/test_factory.py index 727759058..2ace2fccc 100644 --- a/tests/unit/telemetry/library/test_factory.py +++ b/tests/unit/telemetry/library/test_factory.py @@ -43,14 +43,14 @@ def test_default_endpoint_points_at_one_collector(): def test_create_logger_provider_returns_configured_provider(make_provider): resource = Resource.create({"app_version": "0.0.1"}) - provider = make_provider(ikey="o:test", resource=resource) + provider = make_provider(ikey="test-tenant-1234", resource=resource) assert isinstance(provider, LoggerProvider) assert provider.resource.attributes.get("app_version") == "0.0.1" def test_create_logger_provider_with_custom_endpoint(make_provider): provider = make_provider( - ikey="o:test", + ikey="test-tenant-1234", endpoint="https://example.invalid/OneCollector/1.0/", ) assert isinstance(provider, LoggerProvider) diff --git a/tests/unit/telemetry/library/test_serialization.py b/tests/unit/telemetry/library/test_serialization.py index 85ed984f6..acc358a30 100644 --- a/tests/unit/telemetry/library/test_serialization.py +++ b/tests/unit/telemetry/library/test_serialization.py @@ -7,7 +7,11 @@ import pytest -from winml.modelkit.telemetry.library.serialization import _build_envelope, _serialize_batch +from winml.modelkit.telemetry.library.serialization import ( + _build_envelope, + _envelope_ikey, + _serialize_batch, +) def test_build_envelope_basic_shape(): @@ -74,3 +78,35 @@ def test_timestamp_millisecond_precision(microsecond, expected_ms): ts = datetime(2026, 4, 17, 10, 30, 0, microsecond, tzinfo=timezone.utc) envelope = _build_envelope("X", "o:k", ts, {}, {}) assert envelope["time"] == f"2026-04-17T10:30:00.{expected_ms}Z" + + +@pytest.mark.parametrize( + "full_ikey,expected", + [ + # Realistic OneCollector iKey shape: <32hex>--. + ( + "abc123abc123abc123abc123abc12345-aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee-1234", + "o:abc123abc123abc123abc123abc12345", + ), + # Minimal valid form: anything non-empty before the first dash. + ("abc-def", "o:abc"), + ("token-rest-of-key", "o:token"), + ], +) +def test_envelope_ikey_extracts_tenant_token_and_prefixes(full_ikey, expected): + """The envelope iKey is ``o:``; the suffix + (ingestion token + GUID) only goes in the ``x-apikey`` header.""" + assert _envelope_ikey(full_ikey) == expected + + +@pytest.mark.parametrize( + "bad_ikey", + [ + "noseparator", # no dash at all + "-leading-dash", # empty tenant_token portion + "", # empty (defense in depth; exporter rejects this earlier) + ], +) +def test_envelope_ikey_rejects_malformed(bad_ikey): + with pytest.raises(ValueError, match="tenant_token"): + _envelope_ikey(bad_ikey) diff --git a/tests/unit/telemetry/test_cache_integration.py b/tests/unit/telemetry/test_cache_integration.py index 2d7fdc861..f3a8ed11b 100644 --- a/tests/unit/telemetry/test_cache_integration.py +++ b/tests/unit/telemetry/test_cache_integration.py @@ -53,7 +53,7 @@ def test_network_failure_persists_envelopes_to_cache(cache, cache_path): """Process 1: net is down. The exporter must write the envelope to disk so process 2 can recover it.""" exporter = OneCollectorLogExporter( - ikey="o:abc", + ikey="abc-def", endpoint="https://example.invalid/", cache=cache, ) @@ -81,7 +81,7 @@ def test_next_process_flushes_cached_envelopes_on_first_export(cache, cache_path assert cache_path.exists() exporter = OneCollectorLogExporter( - ikey="o:abc", + ikey="abc-def", endpoint="https://example.invalid/", cache=cache, ) @@ -110,7 +110,7 @@ def test_cache_flush_only_runs_on_first_export(cache): cache.append([{"name": "stale", "iKey": "o:abc"}]) exporter = OneCollectorLogExporter( - ikey="o:abc", + ikey="abc-def", endpoint="https://example.invalid/", cache=cache, ) @@ -133,7 +133,7 @@ def test_cached_envelopes_re_persisted_on_recovery_failure(cache, cache_path): cache.append(seeded) exporter = OneCollectorLogExporter( - ikey="o:abc", + ikey="abc-def", endpoint="https://example.invalid/", cache=cache, ) diff --git a/tests/unit/telemetry/test_telemetry_init.py b/tests/unit/telemetry/test_telemetry_init.py index 91eca6497..3e8b1a6af 100644 --- a/tests/unit/telemetry/test_telemetry_init.py +++ b/tests/unit/telemetry/test_telemetry_init.py @@ -24,7 +24,9 @@ def test_empty_ikey_makes_telemetry_disabled(clean_env, isolated_config, monkeyp def test_consent_disabled_makes_telemetry_disabled(clean_env, isolated_config, monkeypatch): - monkeypatch.setattr("winml.modelkit.telemetry.constants.INSTRUMENTATION_KEY", "o:test-key") + monkeypatch.setattr( + "winml.modelkit.telemetry.constants.INSTRUMENTATION_KEY", "test-tenant-1234" + ) consent_mod._write_stored_consent("disabled") monkeypatch.setattr("sys.stdin.isatty", lambda: True) t = Telemetry.get_or_init() @@ -70,7 +72,9 @@ def test_init_swallows_resource_build_errors(clean_env, isolated_config, monkeyp rather than raise. Without this guard a registry permission error or transient OS failure would crash every CLI invocation. """ - monkeypatch.setattr("winml.modelkit.telemetry.constants.INSTRUMENTATION_KEY", "o:test-key") + monkeypatch.setattr( + "winml.modelkit.telemetry.constants.INSTRUMENTATION_KEY", "test-tenant-1234" + ) consent_mod._write_stored_consent("enabled") monkeypatch.setattr("sys.stdin.isatty", lambda: True) From 83d5a7c74497c0ed3762b2e5c2e3eb27062b5a3b Mon Sep 17 00:00:00 2001 From: Shiyi Zheng Date: Fri, 8 May 2026 13:38:58 +0800 Subject: [PATCH 12/12] test(eval): add precedence unit test for cli over config defaults --- tests/unit/commands/test_eval.py | 96 ++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/tests/unit/commands/test_eval.py b/tests/unit/commands/test_eval.py index 3cb683400..1853f5294 100644 --- a/tests/unit/commands/test_eval.py +++ b/tests/unit/commands/test_eval.py @@ -7,8 +7,12 @@ from __future__ import annotations +import json +from unittest.mock import patch + import click import pytest +from click.testing import CliRunner from winml.modelkit.commands.eval import _resolve_model_path @@ -184,3 +188,95 @@ def test_plain_and_role_path_mixed_raises(self, onnx_file, onnx_vision): model=(str(onnx_file), f"text-encoder={onnx_vision}"), model_id="some/id", ) + + +# --------------------------------------------------------------------------- +# Config precedence (CLI > config file > dataclass defaults) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def runner() -> CliRunner: + return CliRunner() + + +@pytest.fixture +def eval_config_file(tmp_path): + config = { + "loader": { + "task": "feature-extraction", + }, + "eval": { + "task": "image-classification", + "device": "gpu", + "dataset": { + "path": "timm/mini-imagenet", + "split": "test", + "samples": 33, + }, + }, + } + cfg_path = tmp_path / "eval_config.json" + cfg_path.write_text(json.dumps(config), encoding="utf-8") + return cfg_path + + +class TestEvalConfigPrecedence: + def test_cli_overrides_config_and_config_overrides_defaults( + self, + runner: CliRunner, + eval_config_file, + ): + """Validate precedence: CLI > config file > dataclass defaults.""" + from winml.modelkit.commands.eval import eval as eval_cmd + + captured_cfg = {} + + def _fake_evaluate(cfg): + captured_cfg["cfg"] = cfg + + class _FakeResult: + def __init__(self, config): + self.config = config + self.metrics = {"accuracy": 1.0} + + def to_dict(self): + return { + "metrics": self.metrics, + "config": self.config.to_dict(), + } + + return _FakeResult(cfg) + + with ( + patch("winml.modelkit.eval.evaluate", side_effect=_fake_evaluate), + patch("winml.modelkit.commands.eval._resolve_device", return_value=None), + patch("winml.modelkit.commands.eval._write_and_display", return_value=None), + ): + result = runner.invoke( + eval_cmd, + [ + "--config", + str(eval_config_file), + "-m", + "microsoft/resnet-50", + "--device", + "cpu", + "--samples", + "7", + "--split", + "train", + ], + obj={"debug": False}, + ) + + assert result.exit_code == 0, result.output + cfg = captured_cfg["cfg"] + + # CLI > config + assert cfg.device == "cpu" + assert cfg.dataset.samples == 7 + assert cfg.dataset.split == "train" + + # config > dataclass defaults (task default is None) + assert cfg.task == "image-classification"