diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 627f8a1c6..e9e67aeb3 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -435,14 +435,28 @@ def _load_model( model_id: str | None, trust_remote_code: bool, random_init: bool = False, + hf_config: Any | None = None, ) -> Any: - """Load PyTorch model — pretrained or random weights.""" + """Load PyTorch model — pretrained or random weights. + + Args: + config: Build config (loader fields used). + model_id: HuggingFace model ID or local path. + trust_remote_code: Whether to trust remote code. + random_init: If True, build with random weights (no download). + hf_config: Optional pre-loaded ``PretrainedConfig`` to reuse. When + provided, skips the ``AutoConfig.from_pretrained`` round-trip in + both the random-init path and the pretrained ``load_hf_model`` + path (PR #719 dedup pattern). + """ task = config.loader.task if random_init: from transformers import AutoConfig - if model_id is not None: + if hf_config is not None: + pass + elif model_id is not None: hf_config = AutoConfig.from_pretrained(model_id) else: logger.warning( @@ -493,6 +507,7 @@ def _load_model( model_name_or_path=model_id, task=task, trust_remote_code=effective_trust, + hf_config=hf_config, ) return pytorch_model diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 11db0f68f..1a3cb68f2 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -224,6 +224,179 @@ def _build_modules( return results +def _validate_task_supported_for_model( + model_id: str, + task: str, + *, + task_field_name: str = "task", + trust_remote_code: bool = False, + library_name: str = "transformers", + hf_config: Any | None = None, +) -> Any: + """Validate that a task is supported for a model's architecture. + + Private helper for ``winml build`` only. Loads HuggingFace config metadata + and validates against ``TasksManager`` supported-task mapping. + + Why this lives here and not in ``loader/`` as public API: + Only ``winml build`` accepts task and model from independent sources + (config JSON's ``loader.task`` + ``--model``) and runs the full + export+optimize+quantize+compile pipeline that benefits from a fast + upfront fail. Other CLI entrypoints get equivalent coverage through + their existing resolution paths: + + - ``winml config`` derives task from the model when both are present, + so the mismatch can't be silently constructed. + - ``winml export`` / ``winml perf`` surface incompatibilities through + ``resolve_cfg`` -> ``ONNXConfigNotFoundError`` later in the call. + + Promoting this to public API would signal that any command should + wire it in, which is not the current design. If a second caller + appears, move this back to ``loader/`` and re-export it. + + Args: + model_id: HuggingFace model ID or local path. + task: Requested task name. + task_field_name: Field label used in user-facing error messages. + trust_remote_code: Whether to trust remote/custom code while loading config. + library_name: Source library for TasksManager lookup. + hf_config: Optional pre-loaded HF config. When supplied, the + ``AutoConfig.from_pretrained`` round-trip is skipped. Used by + ``_validate_loader_tasks_for_model`` to preflight multiple tasks + against the same model without re-fetching. + + Returns: + The loaded (or passed-through) HuggingFace config. Callers can reuse + this to avoid a duplicate ``AutoConfig.from_pretrained`` later + (see PR #719 -- same deduping pattern as ``resolve_loader_config``). + + Raises: + ValueError: If the task is not supported for the model architecture. + """ + from ..export.io import TASK_SYNONYM_EXTENSIONS, ensure_hf_models_registered + from ..loader.task import get_supported_tasks, normalize_task + + if hf_config is None: + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained( + model_id, + trust_remote_code=trust_remote_code, + ) + model_type = getattr(hf_config, "model_type", None) + if not model_type: + return hf_config + + # Ensure optimum.exporters.onnx.model_configs is imported before querying + # the registry. TasksManager._SUPPORTED_MODEL_TYPE is populated lazily + # when optimum's ONNX model_configs module is first imported (triggered by + # any import of optimum.exporters.onnx). Without this, get_supported_tasks + # returns [] for models like resnet that are registered there, not in the + # winml custom registry. + ensure_hf_models_registered() + + supported_tasks = get_supported_tasks(model_type, library_name=library_name) + # If the upstream registry has no task list for this architecture, + # defer to downstream loader resolution instead of hard-failing here. + if not supported_tasks: + return hf_config + + # [1] Verbatim canonical match — definitive accept. Comparing without + # normalization first means an arch that lists `image-feature-extraction` + # in its supported set accepts that name as-is, while a text-only arch + # that lists only `feature-extraction` does not silently accept it via + # Optimum's synonym collapse on this branch. + if task in supported_tasks: + return hf_config + + # [2] HF-pipeline-only task names that Optimum's TasksManager does not + # know but the rest of the CLI accepts (e.g. ``next-sentence-prediction`` + # handled via HF_TASK_DEFAULTS, ``mask-generation`` preserved for SAM2). + # These are routed downstream by export/io.py::_map_task_synonym, so + # rejecting here would break invocations that ``winml config`` and + # ``winml export`` accept. + if task in TASK_SYNONYM_EXTENSIONS: + return hf_config + + # [3] Optimum synonym fallback — e.g. ``masked-lm`` -> ``fill-mask``. + # Accept, but warn so users converge on the canonical spelling. + # + # Known limitation: Optimum collapses text/image variants of + # feature-extraction (``image-feature-extraction`` -> ``feature-extraction``) + # and routes ``sentence-similarity`` -> ``feature-extraction``. This + # branch therefore silently accepts cross-modality combinations such as + # ``--task image-feature-extraction`` against a text-only arch. Such + # mismatches must be caught downstream where the HF-pipeline-keyed + # registries see the un-collapsed ``loader.task`` value. + normalized = normalize_task(task) + normalized_supported = {normalize_task(t) for t in supported_tasks} + if normalized in normalized_supported: + if normalized != task: + logger.warning( + "%s=%r matches via Optimum synonym mapping; consider using the canonical name %r.", + task_field_name, + task, + normalized, + ) + return hf_config + + supported_list = ", ".join(supported_tasks) + raise ValueError( + f"{task_field_name}='{task}' is not supported for --model {model_id} " + f"(architecture: {model_type}).\n" + f"Supported tasks: {supported_list}." + ) + + +def _validate_loader_tasks_for_model( + *, + model_id: str | None, + configs: list[WinMLBuildConfig], + trust_remote_code: bool, +) -> Any | None: + """Validate config loader task(s) against --model architecture. + + This runs at command entry before setup/stage output so incompatible + config/model combinations fail with an actionable one-line error. + + Loads ``AutoConfig`` at most once and reuses it across every per-task + check, then returns it so the build pipeline can plumb it down to + ``load_hf_model`` and avoid the second/third round-trip that PR #719 + deduped on the inspect path. + + See ``_validate_task_supported_for_model`` for the rationale on why this + preflight is wired into ``winml build`` only. + + Returns: + Pre-loaded ``PretrainedConfig`` (caller should pass this into + ``_run_single_build`` so ``load_hf_model`` skips its own + ``AutoConfig.from_pretrained`` call), or ``None`` when no model_id + was provided / model_id is an ONNX file / no task to validate. + """ + if model_id is None: + return None + + if cli_utils.is_onnx_file_path(model_id): + return None + + tasks = { + cfg.loader.task for cfg in configs if cfg.loader is not None and cfg.loader.task is not None + } + if not tasks: + return None + + hf_config: Any | None = None + for task in sorted(tasks): + hf_config = _validate_task_supported_for_model( + model_id=model_id, + task=task, + task_field_name="config.loader.task", + trust_remote_code=trust_remote_code, + hf_config=hf_config, + ) + return hf_config + + # ============================================================================= # CLI COMMAND # ============================================================================= @@ -476,6 +649,12 @@ def _patch_device(cfg: WinMLBuildConfig) -> None: except ValueError as e: raise click.UsageError(f"Config validation failed: {e}") from e + preloaded_hf_config = _validate_loader_tasks_for_model( + model_id=model_id, + configs=_configs_to_validate, + trust_remote_code=trust_remote_code, + ) + # Build extra kwargs for pipeline control extra_kwargs: dict[str, Any] = {} if no_optimize: @@ -593,6 +772,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None: ep=ep, device=device, extra_kwargs=extra_kwargs, + preloaded_hf_config=preloaded_hf_config, ) except click.UsageError: @@ -639,11 +819,10 @@ def _run_single_build( ep: EPNameOrAlias | None, device: str | None, extra_kwargs: dict[str, Any], + preloaded_hf_config: Any | None = None, ) -> None: """Run single-model build with Rich Live progress per stage.""" - from .config import _is_onnx_file - - _is_onnx = model_id is not None and _is_onnx_file(model_id) + _is_onnx = model_id is not None and cli_utils.is_onnx_file_path(model_id) # Derive source from _is_onnx to guarantee header label matches pipeline source = "ONNX" if _is_onnx else detect_model_source(model_id) @@ -708,6 +887,7 @@ def _run_single_build( ep=ep, device=device, extra_kwargs=extra_kwargs, + preloaded_hf_config=preloaded_hf_config, ) elapsed = time.monotonic() - start_time @@ -1098,6 +1278,7 @@ def _build_hf_pipeline( ep: EPNameOrAlias | None, device: str | None, extra_kwargs: dict[str, Any], + preloaded_hf_config: Any | None = None, ) -> list[tuple[str, float | None]] | None: """HF build pipeline with cascading StageLive per stage. @@ -1151,7 +1332,9 @@ def _name(base: str) -> str: sl.set_status("Exporting to ONNX...") # Load + export (blocking) - pytorch_model = _load_model(config, model_id, trust_remote_code=False) + pytorch_model = _load_model( + config, model_id, trust_remote_code=False, hf_config=preloaded_hf_config + ) t0 = time.monotonic() export_onnx( model=pytorch_model, diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index 4172c6573..489344e68 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -57,12 +57,6 @@ def _apply_stage_overrides(cfg: Any, *, no_quant: bool, no_compile: bool) -> Non cfg.compile = None -def _is_onnx_file(model_input: str) -> bool: - """Check if input is a path to an existing .onnx file.""" - path = Path(model_input) - return path.suffix == ".onnx" and path.exists() - - @click.command("config") @cli_utils.model_option(required=False, optional_message="Optional when --model-type is provided.") @click.option( @@ -279,12 +273,12 @@ def config( _shape_config_file = shape_config_path.name # ONNX file detection: generate simpler config without loader/export - if hf_model and _is_onnx_file(hf_model) and module: + if hf_model and cli_utils.is_onnx_file_path(hf_model) and module: raise click.UsageError( "--module is not supported with ONNX file input. " "Module discovery requires a HuggingFace model." ) - if hf_model and _is_onnx_file(hf_model): + if hf_model and cli_utils.is_onnx_file_path(hf_model): config_obj = generate_onnx_build_config( hf_model, task=task, diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 8c0063dd2..5a90b5828 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -149,6 +149,7 @@ def load_hf_model( model_class: str | None = None, user_script: str | None = None, trust_remote_code: bool = False, + hf_config: PretrainedConfig | None = None, ) -> tuple[nn.Module, PretrainedConfig, str]: """Load, detect task, and prepare HuggingFace model. @@ -173,6 +174,9 @@ def load_hf_model( The script must define a class matching `model_class` at module level. Requires trust_remote_code=True for security. trust_remote_code: Whether to trust remote code (required for user_script) + hf_config: Optional pre-loaded HF config. When supplied, the + ``AutoConfig.from_pretrained`` round-trip is skipped — same dedup + pattern as ``resolve_loader_config(hf_config=...)`` from PR #719. Returns: Tuple of (model, hf_config, task) @@ -214,10 +218,11 @@ def load_hf_model( raise ValueError("model_class must be specified when using user_script") # [1] Load HF Config - hf_config = AutoConfig.from_pretrained( - model_name_or_path, - trust_remote_code=trust_remote_code, - ) + if hf_config is None: + hf_config = AutoConfig.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + ) # [2] Task & Model Class Resolution if user_script is not None: diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index b6967c54c..5d5d07ca9 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -292,6 +292,16 @@ def load_build_config(config_path: Path) -> tuple[WinMLBuildConfig, dict]: return WinMLBuildConfig.from_dict(data), data +def is_onnx_file_path(model_input: str) -> bool: + """Check if input is a path to an existing ``.onnx`` file. + + Shared helper for CLI commands that accept either a HuggingFace model ID + or a local ``.onnx`` file path for the ``-m/--model`` option. + """ + path = Path(model_input) + return path.suffix == ".onnx" and path.exists() + + def is_cli_provided(ctx: click.Context, param_name: str) -> bool: """Check whether a CLI parameter was explicitly provided by the user. diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index a444b3463..d1a728211 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -78,6 +78,20 @@ def mock_resolve_device(): yield +@pytest.fixture(autouse=True) +def mock_task_model_compatibility_validator(): + """Default to no-op for preflight task/model compatibility checks. + + Most build command unit tests are CLI plumbing tests and should not hit + HuggingFace config resolution paths. + """ + with patch( + "winml.modelkit.commands.build._validate_task_supported_for_model", + return_value=None, + ): + yield + + @pytest.fixture def runner() -> CliRunner: """Create a CLI test runner.""" @@ -357,6 +371,35 @@ def test_module_array_non_object_entry(self, tmp_path: Path): assert result.exit_code != 0 assert "object" in result.output.lower() + def test_rejects_incompatible_config_task_and_model(self, tmp_path: Path): + """config.loader.task + --model mismatch fails before pipeline starts.""" + cfg = _make_minimal_config_file(tmp_path, task="text-generation") + msg = ( + "config.loader.task='text-generation' is not supported for " + "--model microsoft/resnet-50 (architecture: resnet). " + "Supported tasks: image-classification, image-feature-extraction." + ) + + with ( + patch( + "winml.modelkit.commands.build._validate_task_supported_for_model", + side_effect=ValueError(msg), + ) as mock_validate, + patch("winml.modelkit.commands.build._run_single_build") as mock_run, + ): + result = _invoke(["-c", cfg, "-m", "microsoft/resnet-50", "-o", str(tmp_path / "out")]) + + assert result.exit_code != 0 + assert msg in result.output + mock_validate.assert_called_once_with( + model_id="microsoft/resnet-50", + task="text-generation", + task_field_name="config.loader.task", + trust_remote_code=False, + hf_config=None, + ) + mock_run.assert_not_called() + def test_help_lists_all_options(self): """``--help`` must surface every behavior-bearing option.""" result = _invoke(["--help"]) @@ -1267,7 +1310,7 @@ def test_build_onnx_suffix_but_not_exists_uses_hf( ["-c", str(sample_config_file), "-m", "nonexistent.onnx", "-o", str(output_dir)], obj={"debug": False}, ) - # _is_onnx_file checks suffix AND exists(); nonexistent.onnx + # is_onnx_file_path checks suffix AND exists(); nonexistent.onnx # falls through to HF path since the file doesn't exist on disk assert result.exit_code == 0, f"Build failed: {result.output}" assert mock_build_api.called diff --git a/tests/unit/commands/test_build_validate_task.py b/tests/unit/commands/test_build_validate_task.py new file mode 100644 index 000000000..4448f2977 --- /dev/null +++ b/tests/unit/commands/test_build_validate_task.py @@ -0,0 +1,291 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for `_validate_task_supported_for_model` preflight in build CLI. + +This helper used to live at `loader/config.py::validate_task_supported_for_model` +but was demoted to a private helper of the build command because it is the only +caller. Tests live in a dedicated module so they bypass the autouse fixture in +`test_build.py` that mocks the helper out for CLI-plumbing tests. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from winml.modelkit.commands.build import _validate_task_supported_for_model + + +class TestValidateTaskSupportedForModel: + """Tests for `_validate_task_supported_for_model` preflight helper.""" + + def test_raises_for_task_model_mismatch(self) -> None: + """Incompatible task/model combinations raise a clear ValueError.""" + mock_config = MagicMock() + mock_config.model_type = "resnet" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["image-classification", "image-feature-extraction"], + ), + patch("winml.modelkit.loader.task.normalize_task", side_effect=lambda t: t), + pytest.raises( + ValueError, + match=r"config\.loader\.task='text-generation' is not supported", + ), + ): + _validate_task_supported_for_model( + model_id="microsoft/resnet-50", + task="text-generation", + task_field_name="config.loader.task", + ) + + def test_accepts_supported_task(self) -> None: + """A supported task should pass without raising.""" + mock_config = MagicMock() + mock_config.model_type = "resnet" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["image-classification", "image-feature-extraction"], + ), + patch("winml.modelkit.loader.task.normalize_task", side_effect=lambda t: t), + ): + _validate_task_supported_for_model( + model_id="microsoft/resnet-50", + task="image-classification", + ) + + def test_ensure_hf_models_registered_called_before_lookup(self) -> None: + """ensure_hf_models_registered() is called to populate the ONNX registry + before get_supported_tasks, so models like resnet return the correct tasks.""" + mock_config = MagicMock() + mock_config.model_type = "resnet" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered") as mock_ensure, + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["feature-extraction", "image-classification"], + ), + patch("winml.modelkit.loader.task.normalize_task", side_effect=lambda t: t), + pytest.raises(ValueError, match=r"text-generation.*is not supported"), + ): + _validate_task_supported_for_model( + model_id="microsoft/resnet-50", + task="text-generation", + task_field_name="config.loader.task", + ) + mock_ensure.assert_called_once() + + def test_defers_when_registry_still_empty_after_registration(self) -> None: + """When get_supported_tasks returns [] even after registry population, + validation defers to the downstream loader without raising.""" + mock_config = MagicMock() + mock_config.model_type = "custom-model" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch("winml.modelkit.loader.task.get_supported_tasks", return_value=[]), + ): + # Should NOT raise — defer to downstream loader + _validate_task_supported_for_model( + model_id="org/custom-model", + task="text-generation", + ) + + def test_error_message_format(self) -> None: + """Error message has task/model/architecture on line 1, Supported tasks on line 2.""" + mock_config = MagicMock() + mock_config.model_type = "resnet" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["image-classification"], + ), + patch("winml.modelkit.loader.task.normalize_task", side_effect=lambda t: t), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + pytest.raises(ValueError) as exc_info, + ): + _validate_task_supported_for_model( + model_id="microsoft/resnet-50", + task="text-generation", + task_field_name="config.loader.task", + ) + + msg = str(exc_info.value) + lines = msg.splitlines() + assert len(lines) == 2 + assert lines[0].endswith("(architecture: resnet).") + assert lines[1].startswith("Supported tasks:") + + def test_accepts_next_sentence_prediction_for_bert(self) -> None: + """``next-sentence-prediction`` is in ``TASK_SYNONYM_EXTENSIONS`` and must + be accepted, even though Optimum's per-arch supported_tasks does not list + it. Regression for pre-PR behavior, see review claim 2. + """ + mock_config = MagicMock() + mock_config.model_type = "bert" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["feature-extraction", "fill-mask", "text-classification"], + ), + ): + # Should NOT raise — short-circuited via TASK_SYNONYM_EXTENSIONS. + _validate_task_supported_for_model( + model_id="bert-base-uncased", + task="next-sentence-prediction", + ) + + def test_accepts_mask_generation_via_synonym_extensions(self) -> None: + """``mask-generation`` is preserved in ``TASK_SYNONYM_EXTENSIONS`` for SAM2. + + Optimum's ``map_from_synonym`` would normalize it to ``feature-extraction``, + which is wrong for the HF-pipeline-keyed downstream registries. + """ + mock_config = MagicMock() + mock_config.model_type = "sam" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["feature-extraction"], + ), + ): + _validate_task_supported_for_model( + model_id="facebook/sam-vit-base", + task="mask-generation", + ) + + def test_accepts_optimum_synonym_with_warning(self, caplog: pytest.LogCaptureFixture) -> None: + """Optimum-known synonyms (e.g. ``masked-lm`` -> ``fill-mask``) are accepted + but logged as a warning so users converge on the canonical spelling. + """ + mock_config = MagicMock() + mock_config.model_type = "bert" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["feature-extraction", "fill-mask"], + ), + patch( + "winml.modelkit.loader.task.normalize_task", + side_effect=lambda t: {"masked-lm": "fill-mask"}.get(t, t), + ), + caplog.at_level("WARNING", logger="winml.modelkit.commands.build"), + ): + _validate_task_supported_for_model( + model_id="bert-base-uncased", + task="masked-lm", + ) + + assert any( + "synonym" in rec.message and "fill-mask" in rec.message for rec in caplog.records + ), f"Expected canonical-name hint, got: {[r.message for r in caplog.records]}" + + def test_silently_accepts_cross_modality_feature_extraction( + self, caplog: pytest.LogCaptureFixture + ) -> None: + """Documented limitation: Optimum collapses ``image-feature-extraction`` + and ``feature-extraction``. A text-only arch with ``--task + image-feature-extraction`` is therefore accepted (with a warning) at this + gate; cross-modality routing errors must surface downstream where the + HF-pipeline-keyed registries see the un-collapsed ``loader.task``. + + See review claim 1 — this test documents the limitation rather than + asserting a fix. + """ + mock_config = MagicMock() + mock_config.model_type = "bert" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["feature-extraction", "fill-mask"], + ), + patch( + "winml.modelkit.loader.task.normalize_task", + side_effect=lambda t: "feature-extraction" + if t in {"image-feature-extraction", "feature-extraction"} + else t, + ), + caplog.at_level("WARNING", logger="winml.modelkit.commands.build"), + ): + _validate_task_supported_for_model( + model_id="bert-base-uncased", + task="image-feature-extraction", + task_field_name="config.loader.task", + ) + + # Accepted, but the warning must fire so the limitation is at least visible. + assert any("synonym" in rec.message for rec in caplog.records) + + def test_rejects_unrelated_task_after_all_fallbacks(self) -> None: + """A task that is not verbatim-supported, not in ``TASK_SYNONYM_EXTENSIONS``, + and whose Optimum-normalized form is not in the arch's supported set is + still rejected. Ensures the new branches did not turn the gate into a + no-op. + """ + mock_config = MagicMock() + mock_config.model_type = "resnet" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["image-classification", "image-feature-extraction"], + ), + patch("winml.modelkit.loader.task.normalize_task", side_effect=lambda t: t), + pytest.raises(ValueError, match=r"text-generation.*is not supported"), + ): + _validate_task_supported_for_model( + model_id="microsoft/resnet-50", + task="text-generation", + ) + + def test_verbatim_match_does_not_warn(self, caplog: pytest.LogCaptureFixture) -> None: + """When the task is the exact canonical name in the supported set, no + synonym-warning should fire (verbatim branch short-circuits before + normalization). + """ + mock_config = MagicMock() + mock_config.model_type = "vit" + + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch("winml.modelkit.export.io.ensure_hf_models_registered"), + patch( + "winml.modelkit.loader.task.get_supported_tasks", + return_value=["feature-extraction", "image-classification"], + ), + caplog.at_level("WARNING", logger="winml.modelkit.commands.build"), + ): + _validate_task_supported_for_model( + model_id="google/vit-base-patch16-224", + task="image-classification", + ) + + assert not any("synonym" in rec.message for rec in caplog.records) diff --git a/tests/unit/commands/test_config_cli.py b/tests/unit/commands/test_config_cli.py index c993b0e32..1f16fa491 100644 --- a/tests/unit/commands/test_config_cli.py +++ b/tests/unit/commands/test_config_cli.py @@ -315,7 +315,7 @@ def test_onnx_no_quant(self, runner: CliRunner, tmp_path: Path) -> None: """--no-quant should set quant=None even for ONNX configs.""" from winml.modelkit.commands.config import config - # Create a fake .onnx file so _is_onnx_file returns True + # Create a fake .onnx file so is_onnx_file_path returns True onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake") diff --git a/tests/unit/loader/test_resolve_loader_config.py b/tests/unit/loader/test_resolve_loader_config.py index a6438034a..659702bc8 100644 --- a/tests/unit/loader/test_resolve_loader_config.py +++ b/tests/unit/loader/test_resolve_loader_config.py @@ -15,7 +15,10 @@ import pytest -from winml.modelkit.loader import WinMLLoaderConfig, resolve_loader_config +from winml.modelkit.loader import ( + WinMLLoaderConfig, + resolve_loader_config, +) # =============================================================================