Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9072792
validate model task in config
chinazhangchao May 25, 2026
0b90129
Merge branch 'main' of https://github.com/microsoft/WinML-ModelKit in…
chinazhangchao May 25, 2026
3350150
Merge branch 'main' into chao/validatetask
chinazhangchao May 25, 2026
b98bf6d
fix test
chinazhangchao May 25, 2026
d5ae8e4
revert test
chinazhangchao May 25, 2026
39fb54d
Merge branch 'main' of https://github.com/microsoft/WinML-ModelKit in…
chinazhangchao May 25, 2026
ffeefe3
Merge branch 'main' of https://github.com/microsoft/WinML-ModelKit in…
chinazhangchao May 25, 2026
48df30a
Merge branch 'main' into chao/validatetask
chinazhangchao May 25, 2026
10562b3
Merge branch 'main' into chao/validatetask
chinazhangchao May 26, 2026
5b6b7d7
Merge branch 'main' of https://github.com/microsoft/WinML-ModelKit in…
chinazhangchao May 26, 2026
42e58c3
fix comments
chinazhangchao May 26, 2026
97433f3
Merge branch 'main' into chao/validatetask
chinazhangchao May 26, 2026
d0f104a
Merge branch 'main' into chao/validatetask
chinazhangchao May 26, 2026
a6a467a
Merge branch 'main' into chao/validatetask
chinazhangchao May 27, 2026
f4ec418
Merge branch 'main' into chao/validatetask
chinazhangchao May 27, 2026
5168a67
Merge branch 'main' into chao/validatetask
chinazhangchao May 27, 2026
7c6ed9b
fix comments
chinazhangchao May 28, 2026
acaf432
Merge branch 'main' of https://github.com/microsoft/WinML-ModelKit in…
chinazhangchao May 28, 2026
e5d0fbf
Merge branch 'main' into chao/validatetask
chinazhangchao May 28, 2026
53e0e7c
Merge branch 'main' into chao/validatetask
chinazhangchao May 28, 2026
6384905
Merge branch 'main' into chao/validatetask
chinazhangchao May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/winml/modelkit/build/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,14 +435,28 @@ def _load_model(
model_id: str | None,
trust_remote_code: bool,
random_init: bool = False,
hf_config: Any | None = None,
) -> Any:
"""Load PyTorch model — pretrained or random weights."""
"""Load PyTorch model — pretrained or random weights.

Args:
config: Build config (loader fields used).
model_id: HuggingFace model ID or local path.
trust_remote_code: Whether to trust remote code.
random_init: If True, build with random weights (no download).
hf_config: Optional pre-loaded ``PretrainedConfig`` to reuse. When
provided, skips the ``AutoConfig.from_pretrained`` round-trip in
both the random-init path and the pretrained ``load_hf_model``
path (PR #719 dedup pattern).
"""
task = config.loader.task

if random_init:
from transformers import AutoConfig

if model_id is not None:
if hf_config is not None:
pass
elif model_id is not None:
hf_config = AutoConfig.from_pretrained(model_id)
else:
logger.warning(
Expand Down Expand Up @@ -493,6 +507,7 @@ def _load_model(
model_name_or_path=model_id,
task=task,
trust_remote_code=effective_trust,
hf_config=hf_config,
)
return pytorch_model

Expand Down
191 changes: 187 additions & 4 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,179 @@ def _build_modules(
return results


def _validate_task_supported_for_model(
model_id: str,
task: str,
*,
task_field_name: str = "task",
trust_remote_code: bool = False,
library_name: str = "transformers",
hf_config: Any | None = None,
) -> Any:
"""Validate that a task is supported for a model's architecture.

Private helper for ``winml build`` only. Loads HuggingFace config metadata
and validates against ``TasksManager`` supported-task mapping.

Why this lives here and not in ``loader/`` as public API:
Only ``winml build`` accepts task and model from independent sources
(config JSON's ``loader.task`` + ``--model``) and runs the full
export+optimize+quantize+compile pipeline that benefits from a fast
upfront fail. Other CLI entrypoints get equivalent coverage through
their existing resolution paths:

- ``winml config`` derives task from the model when both are present,
so the mismatch can't be silently constructed.
- ``winml export`` / ``winml perf`` surface incompatibilities through
``resolve_cfg`` -> ``ONNXConfigNotFoundError`` later in the call.

Promoting this to public API would signal that any command should
wire it in, which is not the current design. If a second caller
appears, move this back to ``loader/`` and re-export it.

Args:
model_id: HuggingFace model ID or local path.
task: Requested task name.
task_field_name: Field label used in user-facing error messages.
trust_remote_code: Whether to trust remote/custom code while loading config.
library_name: Source library for TasksManager lookup.
hf_config: Optional pre-loaded HF config. When supplied, the
``AutoConfig.from_pretrained`` round-trip is skipped. Used by
``_validate_loader_tasks_for_model`` to preflight multiple tasks
against the same model without re-fetching.

Returns:
The loaded (or passed-through) HuggingFace config. Callers can reuse
this to avoid a duplicate ``AutoConfig.from_pretrained`` later
(see PR #719 -- same deduping pattern as ``resolve_loader_config``).

Raises:
ValueError: If the task is not supported for the model architecture.
"""
from ..export.io import TASK_SYNONYM_EXTENSIONS, ensure_hf_models_registered
from ..loader.task import get_supported_tasks, normalize_task

if hf_config is None:
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(
model_id,
trust_remote_code=trust_remote_code,
)
model_type = getattr(hf_config, "model_type", None)
if not model_type:
return hf_config

# Ensure optimum.exporters.onnx.model_configs is imported before querying
# the registry. TasksManager._SUPPORTED_MODEL_TYPE is populated lazily
# when optimum's ONNX model_configs module is first imported (triggered by
# any import of optimum.exporters.onnx). Without this, get_supported_tasks
# returns [] for models like resnet that are registered there, not in the
# winml custom registry.
ensure_hf_models_registered()

supported_tasks = get_supported_tasks(model_type, library_name=library_name)
# If the upstream registry has no task list for this architecture,
# defer to downstream loader resolution instead of hard-failing here.
if not supported_tasks:
return hf_config

# [1] Verbatim canonical match — definitive accept. Comparing without
# normalization first means an arch that lists `image-feature-extraction`
# in its supported set accepts that name as-is, while a text-only arch
# that lists only `feature-extraction` does not silently accept it via
# Optimum's synonym collapse on this branch.
if task in supported_tasks:
return hf_config

# [2] HF-pipeline-only task names that Optimum's TasksManager does not
# know but the rest of the CLI accepts (e.g. ``next-sentence-prediction``
# handled via HF_TASK_DEFAULTS, ``mask-generation`` preserved for SAM2).
# These are routed downstream by export/io.py::_map_task_synonym, so
# rejecting here would break invocations that ``winml config`` and
# ``winml export`` accept.
if task in TASK_SYNONYM_EXTENSIONS:
return hf_config

# [3] Optimum synonym fallback — e.g. ``masked-lm`` -> ``fill-mask``.
# Accept, but warn so users converge on the canonical spelling.
#
# Known limitation: Optimum collapses text/image variants of
# feature-extraction (``image-feature-extraction`` -> ``feature-extraction``)
# and routes ``sentence-similarity`` -> ``feature-extraction``. This
# branch therefore silently accepts cross-modality combinations such as
# ``--task image-feature-extraction`` against a text-only arch. Such
# mismatches must be caught downstream where the HF-pipeline-keyed
# registries see the un-collapsed ``loader.task`` value.
normalized = normalize_task(task)
normalized_supported = {normalize_task(t) for t in supported_tasks}
if normalized in normalized_supported:
if normalized != task:
logger.warning(
"%s=%r matches via Optimum synonym mapping; consider using the canonical name %r.",
task_field_name,
task,
normalized,
)
return hf_config

supported_list = ", ".join(supported_tasks)
raise ValueError(
f"{task_field_name}='{task}' is not supported for --model {model_id} "
f"(architecture: {model_type}).\n"
f"Supported tasks: {supported_list}."
)


def _validate_loader_tasks_for_model(
Comment thread
chinazhangchao marked this conversation as resolved.
*,
model_id: str | None,
configs: list[WinMLBuildConfig],
trust_remote_code: bool,
) -> Any | None:
"""Validate config loader task(s) against --model architecture.

This runs at command entry before setup/stage output so incompatible
config/model combinations fail with an actionable one-line error.

Loads ``AutoConfig`` at most once and reuses it across every per-task
check, then returns it so the build pipeline can plumb it down to
``load_hf_model`` and avoid the second/third round-trip that PR #719
deduped on the inspect path.

See ``_validate_task_supported_for_model`` for the rationale on why this
preflight is wired into ``winml build`` only.

Returns:
Pre-loaded ``PretrainedConfig`` (caller should pass this into
``_run_single_build`` so ``load_hf_model`` skips its own
``AutoConfig.from_pretrained`` call), or ``None`` when no model_id
was provided / model_id is an ONNX file / no task to validate.
"""
if model_id is None:
return None

if cli_utils.is_onnx_file_path(model_id):
return None

tasks = {
cfg.loader.task for cfg in configs if cfg.loader is not None and cfg.loader.task is not None
}
if not tasks:
return None

hf_config: Any | None = None
for task in sorted(tasks):
hf_config = _validate_task_supported_for_model(
model_id=model_id,
task=task,
task_field_name="config.loader.task",
trust_remote_code=trust_remote_code,
hf_config=hf_config,
)
return hf_config


# =============================================================================
# CLI COMMAND
# =============================================================================
Expand Down Expand Up @@ -476,6 +649,12 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
except ValueError as e:
raise click.UsageError(f"Config validation failed: {e}") from e

preloaded_hf_config = _validate_loader_tasks_for_model(
model_id=model_id,
configs=_configs_to_validate,
trust_remote_code=trust_remote_code,
)

# Build extra kwargs for pipeline control
extra_kwargs: dict[str, Any] = {}
if no_optimize:
Expand Down Expand Up @@ -593,6 +772,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
ep=ep,
device=device,
extra_kwargs=extra_kwargs,
preloaded_hf_config=preloaded_hf_config,
)

except click.UsageError:
Expand Down Expand Up @@ -639,11 +819,10 @@ def _run_single_build(
ep: EPNameOrAlias | None,
device: str | None,
extra_kwargs: dict[str, Any],
preloaded_hf_config: Any | None = None,
) -> None:
"""Run single-model build with Rich Live progress per stage."""
from .config import _is_onnx_file

_is_onnx = model_id is not None and _is_onnx_file(model_id)
_is_onnx = model_id is not None and cli_utils.is_onnx_file_path(model_id)
# Derive source from _is_onnx to guarantee header label matches pipeline
source = "ONNX" if _is_onnx else detect_model_source(model_id)

Expand Down Expand Up @@ -708,6 +887,7 @@ def _run_single_build(
ep=ep,
device=device,
extra_kwargs=extra_kwargs,
preloaded_hf_config=preloaded_hf_config,
)

elapsed = time.monotonic() - start_time
Expand Down Expand Up @@ -1098,6 +1278,7 @@ def _build_hf_pipeline(
ep: EPNameOrAlias | None,
device: str | None,
extra_kwargs: dict[str, Any],
preloaded_hf_config: Any | None = None,
) -> list[tuple[str, float | None]] | None:
"""HF build pipeline with cascading StageLive per stage.

Expand Down Expand Up @@ -1151,7 +1332,9 @@ def _name(base: str) -> str:
sl.set_status("Exporting to ONNX...")

# Load + export (blocking)
pytorch_model = _load_model(config, model_id, trust_remote_code=False)
pytorch_model = _load_model(
config, model_id, trust_remote_code=False, hf_config=preloaded_hf_config
)
t0 = time.monotonic()
export_onnx(
model=pytorch_model,
Expand Down
10 changes: 2 additions & 8 deletions src/winml/modelkit/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ def _apply_stage_overrides(cfg: Any, *, no_quant: bool, no_compile: bool) -> Non
cfg.compile = None


def _is_onnx_file(model_input: str) -> bool:
"""Check if input is a path to an existing .onnx file."""
path = Path(model_input)
return path.suffix == ".onnx" and path.exists()


@click.command("config")
@cli_utils.model_option(required=False, optional_message="Optional when --model-type is provided.")
@click.option(
Expand Down Expand Up @@ -279,12 +273,12 @@ def config(
_shape_config_file = shape_config_path.name

# ONNX file detection: generate simpler config without loader/export
if hf_model and _is_onnx_file(hf_model) and module:
if hf_model and cli_utils.is_onnx_file_path(hf_model) and module:
raise click.UsageError(
"--module is not supported with ONNX file input. "
"Module discovery requires a HuggingFace model."
)
if hf_model and _is_onnx_file(hf_model):
if hf_model and cli_utils.is_onnx_file_path(hf_model):
config_obj = generate_onnx_build_config(
hf_model,
task=task,
Expand Down
13 changes: 9 additions & 4 deletions src/winml/modelkit/loader/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def load_hf_model(
model_class: str | None = None,
user_script: str | None = None,
trust_remote_code: bool = False,
hf_config: PretrainedConfig | None = None,
) -> tuple[nn.Module, PretrainedConfig, str]:
"""Load, detect task, and prepare HuggingFace model.

Expand All @@ -173,6 +174,9 @@ def load_hf_model(
The script must define a class matching `model_class` at module level.
Requires trust_remote_code=True for security.
trust_remote_code: Whether to trust remote code (required for user_script)
hf_config: Optional pre-loaded HF config. When supplied, the
``AutoConfig.from_pretrained`` round-trip is skipped — same dedup
pattern as ``resolve_loader_config(hf_config=...)`` from PR #719.

Returns:
Tuple of (model, hf_config, task)
Expand Down Expand Up @@ -214,10 +218,11 @@ def load_hf_model(
raise ValueError("model_class must be specified when using user_script")

# [1] Load HF Config
hf_config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
)
if hf_config is None:
hf_config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
)

# [2] Task & Model Class Resolution
if user_script is not None:
Expand Down
10 changes: 10 additions & 0 deletions src/winml/modelkit/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,16 @@ def load_build_config(config_path: Path) -> tuple[WinMLBuildConfig, dict]:
return WinMLBuildConfig.from_dict(data), data


def is_onnx_file_path(model_input: str) -> bool:
"""Check if input is a path to an existing ``.onnx`` file.

Shared helper for CLI commands that accept either a HuggingFace model ID
or a local ``.onnx`` file path for the ``-m/--model`` option.
"""
path = Path(model_input)
return path.suffix == ".onnx" and path.exists()


def is_cli_provided(ctx: click.Context, param_name: str) -> bool:
"""Check whether a CLI parameter was explicitly provided by the user.

Expand Down
Loading
Loading