Skip to content
Merged
8 changes: 4 additions & 4 deletions src/winml/modelkit/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@
_LOCAL_FILE_EXTS = frozenset({".onnx", ".pt", ".pth", ".safetensors", ".bin"})


def _validate_task(
ctx: click.Context, param: click.Parameter, value: str | None
) -> str | None:
def _validate_task(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
"""Click-time validation for --task against the hand-coded KNOWN_TASKS set.

Imports only ..loader.task to keep validation cheap — going through optimum
Expand Down Expand Up @@ -426,11 +424,13 @@ def _inspect_model_v2(
import optimum.exporters.onnx.model_configs # noqa: F401
from optimum.exporters.tasks import TasksManager

from ..loader import resolve_optimum_library
Comment thread
vortex-captain marked this conversation as resolved.

onnx_config_cls = TasksManager.get_exporter_config_constructor(
exporter="onnx",
model_type=model_type,
task=task,
library_name="transformers",
library_name=resolve_optimum_library(model_type),
Comment thread
vortex-captain marked this conversation as resolved.
)
if onnx_config_cls:
config_name = (
Expand Down
10 changes: 9 additions & 1 deletion src/winml/modelkit/export/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,19 @@ def _get_onnx_config(

normalized_task = _map_task_synonym(task)

# Route model_types whose Optimum OnnxConfig is registered under another
# library (e.g. timm via "timm_wrapper" -> "timm") so the lookup succeeds
# from every call site without an explicit --library flag.
from ..loader import resolve_optimum_library
Comment thread
vortex-captain marked this conversation as resolved.

library_name = resolve_optimum_library(model_type, library_name)
Comment thread
vortex-captain marked this conversation as resolved.

logger.debug(
"Getting OnnxConfig: model_type=%s, task=%s -> %s",
"Getting OnnxConfig: model_type=%s, task=%s -> %s (library=%s)",
model_type,
task,
normalized_task,
library_name,
)

try:
Expand Down
20 changes: 19 additions & 1 deletion src/winml/modelkit/inspect/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
from ..loader.task import (
HF_TASK_DEFAULTS,
KNOWN_TASKS,
WRAPPED_LIBRARY_MODEL_TYPES,
_detect_task_and_class_from_config,
Comment thread
vortex-captain marked this conversation as resolved.
_detect_task_from_config,
_get_custom_model_class,
resolve_optimum_library,
)
from ..models import (
HF_MODEL_CLASS_MAPPING,
Expand Down Expand Up @@ -46,6 +49,11 @@

logger = logging.getLogger(__name__)

# Task-detection provenance label returned by detect_task() for wrapped-library
# model types (e.g. timm via "timm_wrapper"). Surfaced in `inspect` output as
# "Task <task> (via <source>)" and in the JSON `task_source` field.
WRAPPED_LIBRARY_SOURCE = "wrapped-library"

# Mapping from pipeline stage verbs to the filenames build_hf_model() produces.
# "export" is omitted because its stage name equals its filename — the
# .get(stage, stage) fallback handles it. Used only in the legacy
Expand Down Expand Up @@ -121,6 +129,16 @@ def detect_task(config: PretrainedConfig) -> tuple[str, str]:
if mt == model_type_normalized:
return task, "HF_MODEL_CLASS_MAPPING"

# Wrapped-library model types (e.g. timm via "timm_wrapper") carry no
# `architectures`; reuse the loader's resolution to derive the real task
# instead of falling through to the HF_TASK_DEFAULTS mislabel below.
if model_type in WRAPPED_LIBRARY_MODEL_TYPES and not getattr(config, "architectures", None):
try:
task, _ = _detect_task_and_class_from_config(config)
return task, WRAPPED_LIBRARY_SOURCE
except Exception:
logger.debug("wrapped-library task detection failed for %s", model_type, exc_info=True)

# Use TasksManager detection
try:
task = _detect_task_from_config(config)
Expand Down Expand Up @@ -343,7 +361,7 @@ def resolve_exporter(
exporter="onnx",
model_type=model_type,
task=task,
library_name="transformers",
library_name=resolve_optimum_library(model_type),
)
if onnx_config_cls:
# Handle functools.partial returned by TasksManager
Expand Down
2 changes: 2 additions & 0 deletions src/winml/modelkit/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_supported_tasks,
get_task_abbrev,
normalize_task,
resolve_optimum_library,
resolve_task_and_model_class,
)

Expand All @@ -46,6 +47,7 @@
"normalize_task",
"resolve_hf_model_class",
"resolve_loader_config",
"resolve_optimum_library",
"resolve_task_and_model_class",
]

Expand Down
80 changes: 78 additions & 2 deletions src/winml/modelkit/loader/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

Public API:
resolve_task_and_model_class - Main orchestrator (3 resolution cases)
resolve_optimum_library - Route a model_type to the Optimum export library
normalize_task - Map task aliases to canonical names
get_task_abbrev - Abbreviated task name for cache keys
get_supported_tasks - List ONNX-exportable tasks for a model type
Expand Down Expand Up @@ -154,6 +155,40 @@
("prajjwal1/bert-tiny", None): "feature-extraction",
}

# Some transformers model_types are generic wrappers that expose an entire other
# library through a single type (e.g. timm via "timm_wrapper"). Such configs
# carry no `architectures` field, and their Optimum ONNX export config is
# registered under the wrapped library, not "transformers". This is a
# library-routing concern handled at the common resolution layer (the loader
# below and export.io._get_onnx_config), not a per-model OnnxConfig.
#
# Only the library is recorded here -- it is the irreducible Optimum-taxonomy
# fact. The export task is derived from Optimum's task list for that library
# (get_supported_tasks), not hardcoded.
# model_type -> optimum_library
Comment thread
vortex-captain marked this conversation as resolved.
WRAPPED_LIBRARY_MODEL_TYPES: dict[str, str] = {
"timm_wrapper": "timm",
}


def resolve_optimum_library(model_type: str | None, library_name: str = "transformers") -> str:
Comment thread
vortex-captain marked this conversation as resolved.
"""Route a transformers model_type to the Optimum library that owns its export.

Comment thread
vortex-captain marked this conversation as resolved.
Comment thread
vortex-captain marked this conversation as resolved.
Most models export under the library they were requested with. A few
transformers model_types are thin wrappers whose Optimum OnnxConfig lives in
another library (see :data:`WRAPPED_LIBRARY_MODEL_TYPES`); route those so the
OnnxConfig lookup succeeds without an explicit ``--library`` flag.

Only the ``"transformers"`` library is rerouted, so an explicit
non-``"transformers"`` library is returned unchanged. (An explicit
``--library transformers`` is indistinguishable from the default and is
still rerouted for wrapped types -- harmless, since those types have no
OnnxConfig registered under transformers anyway.)
"""
if library_name == "transformers" and model_type in WRAPPED_LIBRARY_MODEL_TYPES:
return WRAPPED_LIBRARY_MODEL_TYPES[model_type]
return library_name


# =============================================================================
# Internal Helpers
Expand Down Expand Up @@ -314,8 +349,49 @@ def _detect_task_and_class_from_config(config: PretrainedConfig) -> tuple[str, t
return resolve_task_and_model_class(config, task=override_task)

# [1] Resolve architecture class from config.
# If config.architectures is missing/empty, this raises ValueError and the
# caller should provide task explicitly.
# Some model_types (e.g. timm via "timm_wrapper") are generic library
Comment thread
vortex-captain marked this conversation as resolved.
# wrappers that carry no `architectures` field. Resolve those through their
# wrapped library: the task comes from Optimum's task list for that library
# (not hardcoded), and the class from get_model_class_for_task (a generic
# Auto* class that transformers dispatches to the wrapper at load).
if not getattr(config, "architectures", None):
model_type = getattr(config, "model_type", None)
library = WRAPPED_LIBRARY_MODEL_TYPES.get(model_type) if model_type else None
if library is not None:
# Populate Optimum's exporter registry (incl. the wrapped library's
# task list) before querying it; scoped to this rare branch so normal
# model loading never pays for the import.
import optimum.exporters.onnx.model_configs # noqa: F401

supported = get_supported_tasks(model_type, library_name=library)
if supported:
# A wrapped library exposes a single ONNX export task today
# (timm -> "image-classification"), so supported[0] is the right
# default. If one ever exposes multiple, supported[0] is an
# arbitrary pick -- warn (listing the tasks) but still proceed;
# pass --task to choose a different one.
task = supported[0]
Comment thread
vortex-captain marked this conversation as resolved.
Comment thread
vortex-captain marked this conversation as resolved.
if len(supported) > 1:
logger.warning(
"config has no 'architectures' and the %s library exposes "
"multiple export tasks for %s %s; defaulting to %r "
"(pass --task to choose another).",
library,
model_type,
supported,
task,
)
model_class = TasksManager.get_model_class_for_task(task, framework="pt")
Comment thread
vortex-captain marked this conversation as resolved.
logger.info(
"config has no 'architectures'; resolved %s via %s library (task=%s, class=%s)",
model_type,
library,
task,
model_class.__name__,
)
Comment thread
vortex-captain marked this conversation as resolved.
Comment thread
vortex-captain marked this conversation as resolved.
return task, model_class
# If config.architectures is still missing/empty, this raises ValueError and
# the caller should provide task explicitly.
Comment thread
vortex-captain marked this conversation as resolved.
arch_model_class = _resolve_model_class_from_config(config)
arch_name = arch_model_class.__name__

Expand Down
56 changes: 56 additions & 0 deletions tests/unit/export/test_timm_library_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Tests for timm library routing during OnnxConfig resolution.

timm checkpoints load through transformers' TimmWrapper (model_type=
"timm_wrapper"), but Optimum registers their OnnxConfig (TimmDefaultOnnxConfig)
only under library_name="timm". ``resolve_optimum_library`` reroutes the lookup
so ``resolve_io_specs`` / ``_get_onnx_config`` resolve it under the default
"transformers" library, with no --library flag. See loader/task.py and
export/io.py.
"""

from __future__ import annotations

import pytest

# Trigger OnnxConfig registration with TasksManager
import winml.modelkit.models # noqa: F401
Comment thread
vortex-captain marked this conversation as resolved.
Dismissed
from winml.modelkit.export import resolve_io_specs
from winml.modelkit.export.io import _get_onnx_config # internal: routing under test


@pytest.fixture(scope="module")
def timm_wrapper_config():
"""Minimal offline TimmWrapperConfig (no hub download)."""
from transformers import TimmWrapperConfig

return TimmWrapperConfig(num_labels=10)


class TestTimmLibraryRouting:
"""timm_wrapper resolves to Optimum's TimmDefaultOnnxConfig via library routing."""

def test_get_onnx_config_routes_to_timm_default(self, timm_wrapper_config) -> None:
"""A default (transformers) lookup reroutes to Optimum's TimmDefaultOnnxConfig."""
from optimum.exporters.onnx.model_configs import TimmDefaultOnnxConfig

onnx_config = _get_onnx_config("timm_wrapper", "image-classification", timm_wrapper_config)
assert isinstance(onnx_config, TimmDefaultOnnxConfig), (
"timm_wrapper did not route to Optimum's timm OnnxConfig; "
"resolve_optimum_library routing may be inactive."
)

def test_io_specs_pixel_values_to_logits(self, timm_wrapper_config) -> None:
"""resolve_io_specs yields the timm image-classifier I/O without a --library flag."""
specs = resolve_io_specs("timm_wrapper", "image-classification", timm_wrapper_config)
assert specs["input_names"] == ["pixel_values"]
assert "logits" in specs["output_names"]

def test_pixel_values_is_4d_nchw(self, timm_wrapper_config) -> None:
specs = resolve_io_specs("timm_wrapper", "image-classification", timm_wrapper_config)
shape = specs["input_shapes"][0]
assert len(shape) == 4, f"pixel_values should be 4D NCHW, got {shape}"
assert shape[1] == 3, f"expected 3 channels, got {shape[1]}"
Comment thread
vortex-captain marked this conversation as resolved.
48 changes: 48 additions & 0 deletions tests/unit/inspect/test_resolver_timm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""timm (wrapped-library) resolution in the public `inspect` path.

`inspect_model` resolves task/exporter via `resolver.detect_task` +
`resolver.resolve_exporter` — a separate path from the CLI's `_inspect_model_v2`.
timm checkpoints load as `TimmWrapperConfig` (model_type="timm_wrapper",
architectures=None). Without wrapped-library handling, `detect_task` mislabels
the task (HF_TASK_DEFAULTS fallback) and `resolve_exporter` hardcodes
library_name="transformers" so the OnnxConfig lookup fails (UNSUPPORTED).

These cover the fix that routes both through the timm library, matching the CLI.
"""

from __future__ import annotations

import pytest

from winml.modelkit.inspect import SupportLevel, detect_task, resolve_exporter


@pytest.fixture(scope="module")
def timm_wrapper_config():
"""Minimal offline TimmWrapperConfig (no hub download)."""
from transformers import TimmWrapperConfig

return TimmWrapperConfig(num_labels=10)


class TestDetectTaskTimm:
def test_timm_detects_image_classification(self, timm_wrapper_config) -> None:
"""timm_wrapper (no architectures) resolves to image-classification, not a fallback."""
task, source = detect_task(timm_wrapper_config)
assert task == "image-classification", f"got task={task!r} source={source!r}"


class TestResolveExporterTimm:
def test_timm_resolves_optimum_onnx_config(self, timm_wrapper_config) -> None:
"""resolve_exporter routes timm_wrapper to Optimum's timm OnnxConfig + real I/O."""
info = resolve_exporter(
"timm_wrapper", "image-classification", hf_config=timm_wrapper_config
)
assert info.onnx_config_class == "TimmDefaultOnnxConfig", info.onnx_config_class
assert info.support_level is not SupportLevel.UNSUPPORTED
names = [t.name for t in info.input_tensors]
assert "pixel_values" in names, names
58 changes: 57 additions & 1 deletion tests/unit/loader/test_detect_task_and_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

import pytest

from winml.modelkit.loader.task import _detect_task_and_class_from_config
from winml.modelkit.loader.task import (
WRAPPED_LIBRARY_MODEL_TYPES,
_detect_task_and_class_from_config,
resolve_optimum_library,
)


class TestDetectTaskAndClassFromConfig:
Expand Down Expand Up @@ -148,3 +152,55 @@ def test_no_override_for_unrelated_model(self):
assert task == "image-classification"
# TasksManager returns AutoModelForImageClassification, not the arch class
assert resolved_class is not ResNetForImageClassification or task == "image-classification"


class TestResolveOptimumLibrary:
"""Unit tests for the resolve_optimum_library wrapped-library router."""

def test_timm_wrapper_routes_to_timm(self):
"""timm_wrapper under the default library routes to Optimum's 'timm'."""
assert resolve_optimum_library("timm_wrapper", "transformers") == "timm"

def test_unmapped_model_type_unchanged(self):
"""A normal transformers model_type is not rerouted."""
assert resolve_optimum_library("bert", "transformers") == "transformers"

def test_none_model_type_unchanged(self):
assert resolve_optimum_library(None, "transformers") == "transformers"

def test_explicit_library_is_respected(self):
"""An explicit (non-default) library always wins over the wrapper routing."""
assert resolve_optimum_library("timm_wrapper", "timm") == "timm"
assert resolve_optimum_library("timm_wrapper", "diffusers") == "diffusers"


class TestWrappedLibraryArchitecturesFallback:
"""Auto-detection for wrapper model_types that carry no `architectures`.

timm checkpoints load through transformers' TimmWrapper as TimmWrapperConfig
(architectures=None); the loader resolves them via WRAPPED_LIBRARY_MODEL_TYPES
instead of raising.
"""

def test_timm_wrapper_resolves_without_architectures(self):
config = MagicMock()
config.architectures = None
config.model_type = "timm_wrapper"
config._name_or_path = ""

task, resolved_class = _detect_task_and_class_from_config(config)

# Task is derived from Optimum's task list for the timm library, not hardcoded.
assert WRAPPED_LIBRARY_MODEL_TYPES["timm_wrapper"] == "timm"
assert task == "image-classification"
Comment thread
vortex-captain marked this conversation as resolved.
# A generic Auto* class is used; it dispatches to TimmWrapper at load time.
assert resolved_class.__name__ == "AutoModelForImageClassification"

def test_missing_architectures_without_wrapper_still_raises(self):
config = MagicMock()
config.architectures = None
config.model_type = "totally-unknown-model-xyz"
config._name_or_path = ""

with pytest.raises(ValueError, match="no 'architectures' field"):
_detect_task_and_class_from_config(config)
Loading