diff --git a/src/winml/modelkit/commands/inspect.py b/src/winml/modelkit/commands/inspect.py index 1eec26467..f0e94f0bf 100644 --- a/src/winml/modelkit/commands/inspect.py +++ b/src/winml/modelkit/commands/inspect.py @@ -40,9 +40,7 @@ _LOCAL_FILE_EXTS = frozenset({".onnx", ".pt", ".pth", ".safetensors", ".bin"}) -def _validate_task( - ctx: click.Context, param: click.Parameter, value: str | None -) -> str | None: +def _validate_task(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None: """Click-time validation for --task against the hand-coded KNOWN_TASKS set. Imports only ..loader.task to keep validation cheap — going through optimum @@ -426,11 +424,13 @@ def _inspect_model_v2( import optimum.exporters.onnx.model_configs # noqa: F401 from optimum.exporters.tasks import TasksManager + from ..loader import resolve_optimum_library + onnx_config_cls = TasksManager.get_exporter_config_constructor( exporter="onnx", model_type=model_type, task=task, - library_name="transformers", + library_name=resolve_optimum_library(model_type), ) if onnx_config_cls: config_name = ( diff --git a/src/winml/modelkit/export/io.py b/src/winml/modelkit/export/io.py index 99d9ac58f..c4f382409 100644 --- a/src/winml/modelkit/export/io.py +++ b/src/winml/modelkit/export/io.py @@ -210,11 +210,19 @@ def _get_onnx_config( normalized_task = _map_task_synonym(task) + # Route model_types whose Optimum OnnxConfig is registered under another + # library (e.g. timm via "timm_wrapper" -> "timm") so the lookup succeeds + # from every call site without an explicit --library flag. 
+ from ..loader import resolve_optimum_library + + library_name = resolve_optimum_library(model_type, library_name) + logger.debug( - "Getting OnnxConfig: model_type=%s, task=%s -> %s", + "Getting OnnxConfig: model_type=%s, task=%s -> %s (library=%s)", model_type, task, normalized_task, + library_name, ) try: diff --git a/src/winml/modelkit/inspect/resolver.py b/src/winml/modelkit/inspect/resolver.py index 51a6a9ccd..85f64cd19 100644 --- a/src/winml/modelkit/inspect/resolver.py +++ b/src/winml/modelkit/inspect/resolver.py @@ -15,8 +15,11 @@ from ..loader.task import ( HF_TASK_DEFAULTS, KNOWN_TASKS, + WRAPPED_LIBRARY_MODEL_TYPES, + _detect_task_and_class_from_config, _detect_task_from_config, _get_custom_model_class, + resolve_optimum_library, ) from ..models import ( HF_MODEL_CLASS_MAPPING, @@ -46,6 +49,11 @@ logger = logging.getLogger(__name__) +# Task-detection provenance label returned by detect_task() for wrapped-library +# model types (e.g. timm via "timm_wrapper"). Surfaced in `inspect` output as +# "Task (via wrapped-library)" and in the JSON `task_source` field. +WRAPPED_LIBRARY_SOURCE = "wrapped-library" + # Mapping from pipeline stage verbs to the filenames build_hf_model() produces. # "export" is omitted because its stage name equals its filename — the # .get(stage, stage) fallback handles it. Used only in the legacy @@ -121,6 +129,16 @@ def detect_task(config: PretrainedConfig) -> tuple[str, str]: if mt == model_type_normalized: return task, "HF_MODEL_CLASS_MAPPING" + # Wrapped-library model types (e.g. timm via "timm_wrapper") carry no + # `architectures`; reuse the loader's resolution to derive the real task + # instead of falling through to the HF_TASK_DEFAULTS mislabel below.
+ if model_type in WRAPPED_LIBRARY_MODEL_TYPES and not getattr(config, "architectures", None): + try: + task, _ = _detect_task_and_class_from_config(config) + return task, WRAPPED_LIBRARY_SOURCE + except Exception: + logger.debug("wrapped-library task detection failed for %s", model_type, exc_info=True) + # Use TasksManager detection try: task = _detect_task_from_config(config) @@ -343,7 +361,7 @@ def resolve_exporter( exporter="onnx", model_type=model_type, task=task, - library_name="transformers", + library_name=resolve_optimum_library(model_type), ) if onnx_config_cls: # Handle functools.partial returned by TasksManager diff --git a/src/winml/modelkit/loader/__init__.py b/src/winml/modelkit/loader/__init__.py index 1a6b9b2b7..89a8a0f55 100644 --- a/src/winml/modelkit/loader/__init__.py +++ b/src/winml/modelkit/loader/__init__.py @@ -32,6 +32,7 @@ get_supported_tasks, get_task_abbrev, normalize_task, + resolve_optimum_library, resolve_task_and_model_class, ) @@ -46,6 +47,7 @@ "normalize_task", "resolve_hf_model_class", "resolve_loader_config", + "resolve_optimum_library", "resolve_task_and_model_class", ] diff --git a/src/winml/modelkit/loader/task.py b/src/winml/modelkit/loader/task.py index 9aea05fce..08eb141aa 100644 --- a/src/winml/modelkit/loader/task.py +++ b/src/winml/modelkit/loader/task.py @@ -8,6 +8,7 @@ Public API: resolve_task_and_model_class - Main orchestrator (3 resolution cases) + resolve_optimum_library - Route a model_type to the Optimum export library normalize_task - Map task aliases to canonical names get_task_abbrev - Abbreviated task name for cache keys get_supported_tasks - List ONNX-exportable tasks for a model type @@ -154,6 +155,40 @@ ("prajjwal1/bert-tiny", None): "feature-extraction", } +# Some transformers model_types are generic wrappers that expose an entire other +# library through a single type (e.g. timm via "timm_wrapper"). 
Such configs +# carry no `architectures` field, and their Optimum ONNX export config is +# registered under the wrapped library, not "transformers". This is a +# library-routing concern handled at the common resolution layer (the loader +# below and export.io._get_onnx_config), not a per-model OnnxConfig. +# +# Only the library is recorded here -- it is the irreducible Optimum-taxonomy +# fact. The export task is derived from Optimum's task list for that library +# (get_supported_tasks), not hardcoded. +# model_type -> optimum_library +WRAPPED_LIBRARY_MODEL_TYPES: dict[str, str] = { + "timm_wrapper": "timm", +} + + +def resolve_optimum_library(model_type: str | None, library_name: str = "transformers") -> str: + """Route a transformers model_type to the Optimum library that owns its export. + + Most models export under the library they were requested with. A few + transformers model_types are thin wrappers whose Optimum OnnxConfig lives in + another library (see :data:`WRAPPED_LIBRARY_MODEL_TYPES`); route those so the + OnnxConfig lookup succeeds without an explicit ``--library`` flag. + + Only the ``"transformers"`` library is rerouted, so an explicit + non-``"transformers"`` library is returned unchanged. (An explicit + ``--library transformers`` is indistinguishable from the default and is + still rerouted for wrapped types -- harmless, since those types have no + OnnxConfig registered under transformers anyway.) + """ + if library_name == "transformers" and model_type in WRAPPED_LIBRARY_MODEL_TYPES: + return WRAPPED_LIBRARY_MODEL_TYPES[model_type] + return library_name + # ============================================================================= # Internal Helpers @@ -314,8 +349,49 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t return resolve_task_and_model_class(config, task=override_task) # [1] Resolve architecture class from config. 
- # If config.architectures is missing/empty, this raises ValueError and the - # caller should provide task explicitly. + # Some model_types (e.g. timm via "timm_wrapper") are generic library + # wrappers that carry no `architectures` field. Resolve those through their + # wrapped library: the task comes from Optimum's task list for that library + # (not hardcoded), and the class from get_model_class_for_task (a generic + # Auto* class that transformers dispatches to the wrapper at load). + if not getattr(config, "architectures", None): + model_type = getattr(config, "model_type", None) + library = WRAPPED_LIBRARY_MODEL_TYPES.get(model_type) if model_type else None + if library is not None: + # Populate Optimum's exporter registry (incl. the wrapped library's + # task list) before querying it; scoped to this rare branch so normal + # model loading never pays for the import. + import optimum.exporters.onnx.model_configs # noqa: F401 + + supported = get_supported_tasks(model_type, library_name=library) + if supported: + # A wrapped library exposes a single ONNX export task today + # (timm -> "image-classification"), so supported[0] is the right + # default. If one ever exposes multiple, supported[0] is an + # arbitrary pick -- warn (listing the tasks) but still proceed; + # pass --task to choose a different one. + task = supported[0] + if len(supported) > 1: + logger.warning( + "config has no 'architectures' and the %s library exposes " + "multiple export tasks for %s %s; defaulting to %r " + "(pass --task to choose another).", + library, + model_type, + supported, + task, + ) + model_class = TasksManager.get_model_class_for_task(task, framework="pt") + logger.info( + "config has no 'architectures'; resolved %s via %s library (task=%s, class=%s)", + model_type, + library, + task, + model_class.__name__, + ) + return task, model_class + # If config.architectures is still missing/empty, this raises ValueError and + # the caller should provide task explicitly. 
arch_model_class = _resolve_model_class_from_config(config) arch_name = arch_model_class.__name__ diff --git a/tests/unit/export/test_timm_library_routing.py b/tests/unit/export/test_timm_library_routing.py new file mode 100644 index 000000000..993ec85d0 --- /dev/null +++ b/tests/unit/export/test_timm_library_routing.py @@ -0,0 +1,56 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for timm library routing during OnnxConfig resolution. + +timm checkpoints load through transformers' TimmWrapper (model_type= +"timm_wrapper"), but Optimum registers their OnnxConfig (TimmDefaultOnnxConfig) +only under library_name="timm". ``resolve_optimum_library`` reroutes the lookup +so ``resolve_io_specs`` / ``_get_onnx_config`` resolve it under the default +"transformers" library, with no --library flag. See loader/task.py and +export/io.py. 
+""" + +from __future__ import annotations + +import pytest + +# Trigger OnnxConfig registration with TasksManager +import winml.modelkit.models # noqa: F401 +from winml.modelkit.export import resolve_io_specs +from winml.modelkit.export.io import _get_onnx_config # internal: routing under test + + +@pytest.fixture(scope="module") +def timm_wrapper_config(): + """Minimal offline TimmWrapperConfig (no hub download).""" + from transformers import TimmWrapperConfig + + return TimmWrapperConfig(num_labels=10) + + +class TestTimmLibraryRouting: + """timm_wrapper resolves to Optimum's TimmDefaultOnnxConfig via library routing.""" + + def test_get_onnx_config_routes_to_timm_default(self, timm_wrapper_config) -> None: + """A default (transformers) lookup reroutes to Optimum's TimmDefaultOnnxConfig.""" + from optimum.exporters.onnx.model_configs import TimmDefaultOnnxConfig + + onnx_config = _get_onnx_config("timm_wrapper", "image-classification", timm_wrapper_config) + assert isinstance(onnx_config, TimmDefaultOnnxConfig), ( + "timm_wrapper did not route to Optimum's timm OnnxConfig; " + "resolve_optimum_library routing may be inactive." 
+ ) + + def test_io_specs_pixel_values_to_logits(self, timm_wrapper_config) -> None: + """resolve_io_specs yields the timm image-classifier I/O without a --library flag.""" + specs = resolve_io_specs("timm_wrapper", "image-classification", timm_wrapper_config) + assert specs["input_names"] == ["pixel_values"] + assert "logits" in specs["output_names"] + + def test_pixel_values_is_4d_nchw(self, timm_wrapper_config) -> None: + specs = resolve_io_specs("timm_wrapper", "image-classification", timm_wrapper_config) + shape = specs["input_shapes"][0] + assert len(shape) == 4, f"pixel_values should be 4D NCHW, got {shape}" + assert shape[1] == 3, f"expected 3 channels, got {shape[1]}" diff --git a/tests/unit/inspect/test_resolver_timm.py b/tests/unit/inspect/test_resolver_timm.py new file mode 100644 index 000000000..bb9f006b2 --- /dev/null +++ b/tests/unit/inspect/test_resolver_timm.py @@ -0,0 +1,48 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""timm (wrapped-library) resolution in the public `inspect` path. + +`inspect_model` resolves task/exporter via `resolver.detect_task` + +`resolver.resolve_exporter` — a separate path from the CLI's `_inspect_model_v2`. +timm checkpoints load as `TimmWrapperConfig` (model_type="timm_wrapper", +architectures=None). Without wrapped-library handling, `detect_task` mislabels +the task (HF_TASK_DEFAULTS fallback) and `resolve_exporter` hardcodes +library_name="transformers" so the OnnxConfig lookup fails (UNSUPPORTED). + +These cover the fix that routes both through the timm library, matching the CLI. 
+""" + +from __future__ import annotations + +import pytest + +from winml.modelkit.inspect import SupportLevel, detect_task, resolve_exporter + + +@pytest.fixture(scope="module") +def timm_wrapper_config(): + """Minimal offline TimmWrapperConfig (no hub download).""" + from transformers import TimmWrapperConfig + + return TimmWrapperConfig(num_labels=10) + + +class TestDetectTaskTimm: + def test_timm_detects_image_classification(self, timm_wrapper_config) -> None: + """timm_wrapper (no architectures) resolves to image-classification, not a fallback.""" + task, source = detect_task(timm_wrapper_config) + assert task == "image-classification", f"got task={task!r} source={source!r}" + + +class TestResolveExporterTimm: + def test_timm_resolves_optimum_onnx_config(self, timm_wrapper_config) -> None: + """resolve_exporter routes timm_wrapper to Optimum's timm OnnxConfig + real I/O.""" + info = resolve_exporter( + "timm_wrapper", "image-classification", hf_config=timm_wrapper_config + ) + assert info.onnx_config_class == "TimmDefaultOnnxConfig", info.onnx_config_class + assert info.support_level is not SupportLevel.UNSUPPORTED + names = [t.name for t in info.input_tensors] + assert "pixel_values" in names, names diff --git a/tests/unit/loader/test_detect_task_and_class.py b/tests/unit/loader/test_detect_task_and_class.py index b2c1381ba..ec181a9ca 100644 --- a/tests/unit/loader/test_detect_task_and_class.py +++ b/tests/unit/loader/test_detect_task_and_class.py @@ -16,7 +16,11 @@ import pytest -from winml.modelkit.loader.task import _detect_task_and_class_from_config +from winml.modelkit.loader.task import ( + WRAPPED_LIBRARY_MODEL_TYPES, + _detect_task_and_class_from_config, + resolve_optimum_library, +) class TestDetectTaskAndClassFromConfig: @@ -148,3 +152,55 @@ def test_no_override_for_unrelated_model(self): assert task == "image-classification" # TasksManager returns AutoModelForImageClassification, not the arch class assert resolved_class is not 
ResNetForImageClassification or task == "image-classification" + + +class TestResolveOptimumLibrary: + """Unit tests for the resolve_optimum_library wrapped-library router.""" + + def test_timm_wrapper_routes_to_timm(self): + """timm_wrapper under the default library routes to Optimum's 'timm'.""" + assert resolve_optimum_library("timm_wrapper", "transformers") == "timm" + + def test_unmapped_model_type_unchanged(self): + """A normal transformers model_type is not rerouted.""" + assert resolve_optimum_library("bert", "transformers") == "transformers" + + def test_none_model_type_unchanged(self): + assert resolve_optimum_library(None, "transformers") == "transformers" + + def test_explicit_library_is_respected(self): + """An explicit (non-default) library always wins over the wrapper routing.""" + assert resolve_optimum_library("timm_wrapper", "timm") == "timm" + assert resolve_optimum_library("timm_wrapper", "diffusers") == "diffusers" + + +class TestWrappedLibraryArchitecturesFallback: + """Auto-detection for wrapper model_types that carry no `architectures`. + + timm checkpoints load through transformers' TimmWrapper as TimmWrapperConfig + (architectures=None); the loader resolves them via WRAPPED_LIBRARY_MODEL_TYPES + instead of raising. + """ + + def test_timm_wrapper_resolves_without_architectures(self): + config = MagicMock() + config.architectures = None + config.model_type = "timm_wrapper" + config._name_or_path = "" + + task, resolved_class = _detect_task_and_class_from_config(config) + + # Task is derived from Optimum's task list for the timm library, not hardcoded. + assert WRAPPED_LIBRARY_MODEL_TYPES["timm_wrapper"] == "timm" + assert task == "image-classification" + # A generic Auto* class is used; it dispatches to TimmWrapper at load time. 
+ assert resolved_class.__name__ == "AutoModelForImageClassification" + + def test_missing_architectures_without_wrapper_still_raises(self): + config = MagicMock() + config.architectures = None + config.model_type = "totally-unknown-model-xyz" + config._name_or_path = "" + + with pytest.raises(ValueError, match="no 'architectures' field"): + _detect_task_and_class_from_config(config)