Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import click
from rich.logging import RichHandler

from ..utils import cli as cli_utils
from ..utils.console import (
detect_model_source,
get_console,
Expand Down Expand Up @@ -295,11 +296,8 @@ def _build_modules(
default=None,
help="Maximum autoconf re-optimization rounds (default: 3). --no-analyze sets this to 0.",
)
@click.option(
"--trust-remote-code",
is_flag=True,
default=False,
help="Trust remote code for custom model architectures (e.g., Mu2).",
@cli_utils.trust_remote_code_option(
optional_message="Trust remote code for custom model architectures (e.g., Mu2)."
)
@click.option(
"-v",
Expand Down
8 changes: 2 additions & 6 deletions src/winml/modelkit/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import click

from ..utils import cli as cli_utils
from ..utils.console import (
get_console,
print_command_header,
Expand Down Expand Up @@ -169,12 +170,7 @@ def _is_onnx_file(model_input: str) -> bool:
default=True,
help="Exclude compilation from generated config (sets compile=None). Default: exclude.",
)
@click.option(
"--trust-remote-code",
is_flag=True,
default=False,
help="Allow running custom code from model repository",
)
@cli_utils.trust_remote_code_option()
def config(
hf_model: str | None,
task: str | None,
Expand Down
252 changes: 183 additions & 69 deletions src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@
default=False,
help="Enable verbose output.",
)
@click.option(
"--dataset-script",
type=str,
default=None,
help="Path to a Python script that builds the evaluation dataset.",
)
@cli_utils.trust_remote_code_option(
optional_message="Required when --dataset-script is used."
)
@click.option(
"--schema",
"show_schema",
Expand All @@ -145,6 +154,8 @@ def eval(
label_mapping: Path | None,
output: Path | None,
verbose: bool,
dataset_script: str | None,
trust_remote_code: bool,
show_schema: bool,
config_file: Path | None,
) -> None:
Expand All @@ -170,97 +181,200 @@ def eval(
if verbose or (ctx.obj and ctx.obj.get("debug")):
logging.getLogger("winml.modelkit").setLevel(logging.DEBUG)

# Apply build config defaults (CLI explicit options take precedence)
if config_file is not None:
build_cfg = cli_utils.load_build_config(config_file)
if build_cfg.loader and not cli_utils.is_cli_provided(ctx, "task"):
task = build_cfg.loader.task
if build_cfg.compile and not cli_utils.is_cli_provided(ctx, "ep"):
ep = build_cfg.compile.ep_config.provider
if build_cfg.quant:
if not cli_utils.is_cli_provided(ctx, "samples"):
samples = build_cfg.quant.samples
if not cli_utils.is_cli_provided(ctx, "dataset_name"):
dataset_name = build_cfg.quant.dataset_name
from ..eval import evaluate

# ── 1. Build config: defaults ← config file ← CLI ──
cfg = _build_eval_config(ctx, config_file, column, label_mapping)

if show_schema:
from ..eval import WinMLEvaluator
from ..eval.evaluate import _EVALUATOR_REGISTRY

if task is None:
if cfg.task is None:
raise click.UsageError(
"--schema requires --task. Example: winml eval --schema --task object-detection"
)
cls = _EVALUATOR_REGISTRY.get(task, WinMLEvaluator)
_print_schema(task, cls.schema_info())
cls = _EVALUATOR_REGISTRY.get(cfg.task, WinMLEvaluator)
_print_schema(cfg.task, cls.schema_info())
return

model_path, model_id = _resolve_model_path(
model=model,
model_id=model_id,
)
# ── 2. Resolve in place ──
_resolve_model(cfg, model, model_id)
_resolve_device(cfg)
_resolve_label_mapping(cfg)
_run_dataset_script(cfg, trust_remote_code)

# Parse column mappings from --column key=value pairs
columns_mapping: dict[str, str] = {}
for c in column:
if "=" not in c:
raise click.BadParameter(
f"Invalid column format: '{c}'. Use key=value.",
param_hint="--column",
)
k, v = c.split("=", 1)
columns_mapping[k] = v
logger.debug("Effective eval config: %s", cfg.to_dict())

# ── 3. Evaluate ──
try:
result = evaluate(cfg)
_write_and_display(result, cfg.output_path)
except Exception as e:
logger.exception("Evaluation failed")
raise click.ClickException(f"Evaluation failed: {e}") from e

# Parse label mapping from JSON file
parsed_label_mapping = None
if label_mapping:
with Path(label_mapping).open() as f:
parsed_label_mapping = json.load(f)

def _build_eval_config(
    ctx: click.Context,
    config_file: Path | None,
    column: tuple[str, ...],
    label_mapping: Path | None,
) -> object:
    """Build a WinMLEvaluationConfig with precedence: defaults ← config file ← CLI.

    Layers are applied lowest priority first:

    1. A default-constructed ``WinMLEvaluationConfig``.
    2. Values from *config_file* (if given): loader task, compile EP and
       quant dataset fields act as fallbacks, then the raw ``"eval"``
       section of the file is merged on top of them.
    3. CLI overrides gathered via ``cli_utils.collect_cli_overrides`` plus
       the explicitly handled ``--column`` / ``--label-mapping`` values.

    The ``"eval"`` section is read from the raw JSON so that only keys
    explicitly present in the file are applied (avoids overriding with
    dataclass defaults).

    Args:
        ctx: Click context of the running command; passed to
            ``collect_cli_overrides``.
        config_file: Optional build-config file. ``None`` skips layer 2.
        column: Raw ``key=value`` strings; an empty tuple means the option
            was not provided.
        label_mapping: Optional path stored as ``dataset.label_mapping_file``.

    Returns:
        The merged evaluation config object.

    Raises:
        click.BadParameter: If an entry in *column* contains no ``=``.
    """
    from ..datasets import DatasetConfig
    from ..eval import WinMLEvaluationConfig
    from ..utils.config_utils import merge_config

    cfg = WinMLEvaluationConfig()

    # ── Config file layer (only explicitly-present keys) ──
    if config_file is not None:
        build_cfg = cli_utils.load_build_config(config_file)

        # Loader task as lowest-priority fallback
        if build_cfg.loader and build_cfg.loader.task:
            cfg.task = build_cfg.loader.task

        # Compile EP as fallback for --ep
        if build_cfg.compile and build_cfg.compile.ep_config:
            cfg.ep = build_cfg.compile.ep_config.provider

        # Quant fields as fallback
        if build_cfg.quant:
            quant_overrides: dict = {}
            # NOTE(review): 100 is presumably the quant config's default sample
            # count, so only a user-chosen value is carried over — confirm.
            if build_cfg.quant.samples != 100:  # non-default
                quant_overrides.setdefault("dataset", {})["samples"] = build_cfg.quant.samples
            if build_cfg.quant.dataset_name:
                quant_overrides.setdefault("dataset", {})["name"] = build_cfg.quant.dataset_name
            if quant_overrides:
                cfg = merge_config(cfg, quant_overrides)

        # Eval section overrides quant/loader (read raw JSON for this).
        # Decode explicitly as UTF-8 rather than the locale default, and keep
        # the read best-effort: ValueError covers both JSONDecodeError and
        # UnicodeDecodeError.
        try:
            raw = json.loads(config_file.read_text(encoding="utf-8"))
        except (ValueError, OSError) as exc:
            logger.warning(
                "Could not re-read %s for its 'eval' section: %s", config_file, exc
            )
            raw = {}
        # A non-object JSON root has no sections to look up, and a non-mapping
        # "eval" value cannot be merged.
        eval_data = raw.get("eval") if isinstance(raw, dict) else None
        if eval_data and isinstance(eval_data, dict):
            cfg = merge_config(cfg, eval_data)

    # ── CLI layer (highest priority, auto-mapped via metadata) ──
    overrides = cli_utils.collect_cli_overrides(ctx, type(cfg))
    ds_overrides = cli_utils.collect_cli_overrides(ctx, DatasetConfig)

    # --column is multiple=True; non-empty tuple means user provided it
    if column:
        columns_mapping: dict[str, str] = {}
        for c in column:
            if "=" not in c:
                raise click.BadParameter(
                    f"Invalid column format: '{c}'. Use key=value.",
                    param_hint="--column",
                )
            # Split on the first "=" only so values may themselves contain "=".
            k, v = c.split("=", 1)
            columns_mapping[k] = v
        ds_overrides["columns_mapping"] = columns_mapping

    if label_mapping is not None:
        ds_overrides["label_mapping_file"] = str(label_mapping)

    if ds_overrides:
        overrides["dataset"] = ds_overrides

    if overrides:
        cfg = merge_config(cfg, overrides)

    return cfg


def _resolve_model(
    cfg: object,
    model: tuple[str, ...],
    model_id: str | None,
) -> None:
    """Store the resolved model location on *cfg*.

    Delegates to ``_resolve_model_path`` with the ``-m`` / ``--model-id``
    inputs and writes the returned pair to ``cfg.model_path`` and
    ``cfg.model_id``; *cfg* is modified in place.
    """
    path_and_id = _resolve_model_path(model=model, model_id=model_id)
    cfg.model_path, cfg.model_id = path_and_id


def _resolve_device(cfg: object) -> None:
    """Resolve ``'auto'`` → concrete device string on *cfg* in place.

    Passes ``cfg.device`` to ``sysinfo.resolve_device`` and writes the first
    element of its result back to ``cfg.device``. The second element is
    not needed here and is discarded.
    """
    # Imported lazily, matching the other function-scope imports in this module.
    from ..sysinfo import resolve_device

    resolved, _ = resolve_device(cfg.device)
    cfg.device = resolved

config = WinMLEvaluationConfig(
model_path=model_path,
model_id=model_id,
task=task,
device=resolved_device,
ep=ep,
dataset=ds_config,
output_path=output,
)

try:
result = evaluate(config)
def _resolve_label_mapping(cfg: object) -> None:
    """Load label-mapping JSON file (if any) into ``cfg.dataset.label_mapping``.

    Does nothing when ``cfg.dataset.label_mapping_file`` is unset or empty.
    Otherwise the file is parsed as JSON and the result is stored on
    ``cfg.dataset.label_mapping``; *cfg* is modified in place.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid UTF-8 or not valid JSON.
    """
    mapping_file = cfg.dataset.label_mapping_file
    if not mapping_file:
        return

    # Decode explicitly as UTF-8: the locale default encoding would
    # mis-decode non-ASCII label names on systems where it is not UTF-8.
    with Path(mapping_file).open(encoding="utf-8") as f:
        cfg.dataset.label_mapping = json.load(f)

from rich.console import Console

console = Console()
display_eval_report(result, console)
def _run_dataset_script(cfg: object, trust_remote_code: bool) -> None:
    """Run the dataset build script referenced by *cfg*, if any.

    The script is invoked with ``--output <dataset.path>`` so the built
    dataset lands at the path already configured in the config file.
    It runs in a child process using the current Python interpreter, with
    its output captured.

    Args:
        cfg: Evaluation config; ``cfg.dataset.build_script`` and
            ``cfg.dataset.path`` are read.
        trust_remote_code: Must be true for the script to be executed.

    Raises:
        click.UsageError: If ``dataset.path`` is missing or
            *trust_remote_code* is false while a build script is set.
        click.BadParameter: If the script path does not exist.
        click.ClickException: If the script cannot be launched, times out,
            or exits with a non-zero status.
    """
    if not cfg.dataset.build_script:
        return

    if not cfg.dataset.path:
        raise click.UsageError(
            "dataset.path is required when dataset.build_script is set. "
            "The path tells the script where to write the built dataset."
        )

    # Executing a user-supplied script is arbitrary code execution, so it
    # requires explicit opt-in.
    if not trust_remote_code:
        raise click.UsageError(
            "--trust-remote-code is required to execute a dataset script."
        )

    import subprocess
    import sys

    script_path = Path(cfg.dataset.build_script)
    if not script_path.exists():
        raise click.BadParameter(f"Dataset script not found: {script_path}")

    cmd = [
        sys.executable,
        str(script_path),
        "--output",
        str(Path(cfg.dataset.path).expanduser()),
    ]
    timeout_s = 300

    logger.info("Building dataset via %s ...", script_path.name)
    try:
        result = subprocess.run(  # noqa: S603
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired as exc:
        # Surface a timeout as a clean CLI error instead of a traceback.
        raise click.ClickException(
            f"Dataset script timed out after {timeout_s}s: {script_path}"
        ) from exc
    except OSError as exc:
        raise click.ClickException(
            f"Could not launch dataset script: {exc}"
        ) from exc

    if result.returncode != 0:
        # Only the tail of stderr is shown to keep the error message short.
        raise click.ClickException(
            f"Dataset script failed (exit {result.returncode}): "
            f"{result.stderr.strip()[-200:] or '(no stderr)'}"
        )


def _write_and_display(result: object, output_path: Path | None) -> None:
    """Show the evaluation report on the console and, if requested, save it as JSON.

    The report is always rendered through ``display_eval_report``. When
    *output_path* is given, its parent directories are created as needed,
    ``result.to_dict()`` is written there as indented JSON, and a
    confirmation line is printed.
    """
    from rich.console import Console

    report_console = Console()
    display_eval_report(result, report_console)

    # Nothing to persist unless the caller asked for an output file.
    if output_path is None:
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w") as sink:
        json.dump(result.to_dict(), sink, indent=2, default=_json_default)
    report_console.print(f"[green]Results saved to:[/green] {output_path}")


def _resolve_model_path(
Expand Down
Loading
Loading