Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _validate_task_supported_for_model(
# [2] HF-pipeline-only task names that Optimum's TasksManager does not
# know but the rest of the CLI accepts (e.g. ``next-sentence-prediction``
# handled via HF_TASK_DEFAULTS, ``mask-generation`` preserved for SAM2).
# These are routed downstream by export/io.py::_map_task_synonym, so
# These are routed downstream by export/io.py::map_task_synonym, so
# rejecting here would break invocations that ``winml config`` and
# ``winml export`` accept.
if task in TASK_SYNONYM_EXTENSIONS:
Expand Down
4 changes: 3 additions & 1 deletion src/winml/modelkit/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,14 @@ def _inspect_model_v2(
import optimum.exporters.onnx.model_configs # noqa: F401
from optimum.exporters.tasks import TasksManager

# TasksManager expects normalized task names
from ..export.io import map_task_synonym
from ..loader import resolve_optimum_library

onnx_config_cls = TasksManager.get_exporter_config_constructor(
exporter="onnx",
model_type=model_type,
task=task,
task=map_task_synonym(task),
library_name=resolve_optimum_library(model_type),
)
if onnx_config_cls:
Expand Down
21 changes: 19 additions & 2 deletions src/winml/modelkit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"object-detection": ObjectDetectionDataset,
"text-classification": TextDataset,
"text-feature-extraction": TextDataset,
"feature-extraction": TextDataset,
"feature-extraction": {"input_ids": TextDataset, "pixel_values": ImageDataset},
"sentence-similarity": TextDataset,
"next-sentence-prediction": TextDataset,
"fill-mask": TextDataset,
Expand All @@ -51,6 +51,23 @@
}

Comment thread
zhenchaoni marked this conversation as resolved.

def _resolve_dataset_class(task: str, io_config: dict | None) -> tuple[type, str]:
    """Pick the calibration dataset class registered for ``task``.

    An entry in ``TASK_DATASET_MAPPING`` is either a dataset class, which is
    returned as-is, or a dict keyed by input name. For a dict entry, the
    names found while iterating ``io_config`` decide which class is used.

    Args:
        task: Key into ``TASK_DATASET_MAPPING``; an unknown key propagates
            the lookup error to the caller.
        io_config: Iterated for names to match against the dispatch dict
            (presumably the model's ONNX input names -- confirm with
            callers). ``None`` or an empty value matches nothing.

    Returns:
        ``(dataset_class, task_label)``. The label equals ``task`` except
        when a dict entry cannot be narrowed to exactly one class; then a
        warning is logged and ``(RandomDataset, "random")`` is returned.
    """
    entry = TASK_DATASET_MAPPING[task]

    if isinstance(entry, dict):
        # Keep only the names that the dispatch dict knows about.
        matched = [key for key in (io_config or {}) if key in entry]
        if len(matched) == 1:
            (only_match,) = matched
            return entry[only_match], task

        # Zero matches (including a missing io_config) or several matches:
        # the modality is undecidable, so use random data instead.
        # NOTE(review): callers that never pass io_config also land here --
        # confirm that is intended for text-only feature-extraction models.
        logger.warning(
            "Task '%s' is not supported for the model, falling back to RandomDataset",
            task,
        )
        return RandomDataset, "random"

    return entry, task


def universal_calib_dataset(
model_name: str,
task: str,
Expand Down Expand Up @@ -98,7 +115,7 @@ def universal_calib_dataset(

# Create dataset with error handling
try:
dataset_class = TASK_DATASET_MAPPING[task]
dataset_class, task = _resolve_dataset_class(task, kwargs.get("io_config"))

# Craft kwargs - only add optional parameters if provided
dataset_kwargs = {
Expand Down
46 changes: 42 additions & 4 deletions src/winml/modelkit/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,62 @@ def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel:
)


# The evaluator relies on the HF pipeline and the ``evaluate`` library,
# which have their own task naming conventions.
# Outer key: a task name that needs translating for those libraries.
# Inner dict: ONNX input name (from the model's IO config) -> the HF task
# name to use when that input is present, so an ambiguous task can be
# resolved by modality.
HF_TASK_NAME_MAPPING: dict[str, dict[str, str]] = {
"feature-extraction": {
"input_ids": "feature-extraction",
"pixel_values": "image-feature-extraction",
},
}


def to_hf_pipeline_task(task: str, model_id: str | None) -> str:
    """Translate ``task`` into a name the HF pipeline recognizes.

    Only tasks listed in ``HF_TASK_NAME_MAPPING`` are translated. For those,
    the model's ONNX config inputs are probed and the HF task mapped to the
    input name that is present is returned.

    Args:
        task: Task name to translate.
        model_id: Identifier handed to ``AutoConfig.from_pretrained``; when
            ``None`` no probing is possible and ``task`` is returned.

    Returns:
        The mapped HF task name when exactly one mapped input name appears
        in the ONNX config inputs; otherwise ``task`` unchanged (no mapping
        entry, no ``model_id``, probe failure, or zero/multiple matches).
    """
    by_input_name = HF_TASK_NAME_MAPPING.get(task)
    if by_input_name is None or model_id is None:
        return task

    try:
        from transformers import AutoConfig

        from ..export.io import _get_onnx_config

        hf_config = AutoConfig.from_pretrained(model_id)
        onnx_config = _get_onnx_config(hf_config.model_type, task, hf_config)
        onnx_inputs = onnx_config.inputs
    except Exception as e:
        # Best-effort probe: any failure leaves the task name untouched.
        logger.debug("Static OnnxConfig probe failed for task %r: %s", task, e)
        return task

    candidates = [
        hf_task
        for input_name, hf_task in by_input_name.items()
        if input_name in onnx_inputs
    ]
    # Translate only when the inputs point at a single modality.
    return candidates[0] if len(candidates) == 1 else task


def _resolve_task(config: WinMLEvaluationConfig) -> str:
"""Resolve task from config or model's HF config, and validate it is supported."""
console = Console()
console.print("[bold]Resolving task...[/bold]")

if config.task is not None:
task = config.task
else:
if config.model_id is None:
raise ValueError("Cannot infer task without model_id. Provide --task.")

console = Console()
console.print("[bold]Detecting model task...[/bold]")

from transformers import AutoConfig

from ..loader.task import _detect_task_from_config

hf_config = AutoConfig.from_pretrained(config.model_id)
task = _detect_task_from_config(hf_config)
console.print(f"[dim]Detected task:[/dim] {task}")

# Convert to an HF-pipeline-recognizable task before evaluator lookup.
task = to_hf_pipeline_task(task, config.model_id)
console.print(f"[dim]Use[/dim] {task} [dim]to evaluate[/dim]")

if task not in _EVALUATOR_REGISTRY:
supported = ", ".join(sorted(_EVALUATOR_REGISTRY))
Expand Down
5 changes: 4 additions & 1 deletion src/winml/modelkit/export/htp/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,14 @@ def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any:
logger.debug("Model has no config.model_type; skipping Optimum patcher.")
return contextlib.nullcontext()

# TasksManager expects normalized task names
from ..io import map_task_synonym

try:
cfg_cls = TasksManager.get_exporter_config_constructor(
"onnx",
model_type=model_type,
task=task,
task=map_task_synonym(task),
library_name="transformers",
)
return cfg_cls(model_config).patch_model_for_export(model)
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/export/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def ensure_hf_models_registered() -> None:
}


def _map_task_synonym(task: str) -> str:
def map_task_synonym(task: str) -> str:
"""Map task name to canonical form, extending Optimum's synonym mapping.

Our extensions take priority over Optimum's built-in synonym map.
Expand Down Expand Up @@ -208,7 +208,7 @@ def _get_onnx_config(
"""
ensure_hf_models_registered()

normalized_task = _map_task_synonym(task)
normalized_task = map_task_synonym(task)

# Route model_types whose Optimum OnnxConfig is registered under another
# library (e.g. timm via "timm_wrapper" -> "timm") so the lookup succeeds
Expand Down
5 changes: 4 additions & 1 deletion src/winml/modelkit/inspect/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,15 @@ def resolve_exporter(
import optimum.exporters.onnx.model_configs # noqa: F401
from optimum.exporters.tasks import TasksManager

# TasksManager expects normalized task names
from ..export.io import map_task_synonym

# TasksManager uses underscores (sam2_video), not hyphens (sam2-video)
# Use original model_type for TasksManager lookup
onnx_config_cls = TasksManager.get_exporter_config_constructor(
exporter="onnx",
model_type=model_type,
task=task,
task=map_task_synonym(task),
library_name=resolve_optimum_library(model_type),
)
if onnx_config_cls:
Expand Down
14 changes: 12 additions & 2 deletions tests/e2e/test_eval_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,25 @@ def test_sentence_similarity(self, runner: CliRunner, tmp_path: Path) -> None:
if is_host("qnn"):
_assert_in_range(data["metrics"], "cosine_spearman", 40.0, 100.0)

@pytest.mark.parametrize(
"task",
["image-feature-extraction", "feature-extraction"],
)
def test_image_feature_extraction(
self, runner: CliRunner, tmp_path: Path,
self, runner: CliRunner, tmp_path: Path, task: str,
) -> None:
# kNN accuracies reported as percentages 0..100.
# --streaming avoids caching mini-imagenet.
# Parameterized over both task names accepted on the CLI:
# - "image-feature-extraction" is the HF pipeline task name
# and dispatches to the image evaluator directly.
# - "feature-extraction" is bimodal; for a vision model it is
# mapped internally to the HF pipeline name so the image
# dataset and evaluator are selected.
out = tmp_path / "result.json"
_invoke(runner, [
"-m", "facebook/dinov2-small",
"--task", "image-feature-extraction",
"--task", task,
"--streaming",
"--samples", SAMPLES,
"-o", str(out),
Expand Down
26 changes: 26 additions & 0 deletions tests/e2e/test_export_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,32 @@ def test_minimal_resnet50(self, tmp_path: Path):
_assert_all_nodes_have(model, "winml.hierarchy.depth")


class TestExportDinoV2:
    """End-to-end export checks for the checkpoint named in ``MODEL``."""

    MODEL = "facebook/dinov2-base"

    def test_image_feature_extraction(self, tmp_path: Path):
        """``-t image-feature-extraction`` must produce a valid ONNX export."""
        destination = tmp_path / "model.onnx"
        cli_args = [
            "-m", self.MODEL,
            "-o", str(destination),
            "-t", "image-feature-extraction",
        ]

        outcome = _invoke(cli_args)
        assert outcome.exit_code == 0, (
            f"export failed (exit {outcome.exit_code}):\n{outcome.output}"
        )
        assert destination.exists(), f"ONNX model not found at {destination}"

        exported = onnx.load(str(destination))
        # The Optimum-driven OnnxConfig for dinov2/feature-extraction yields
        # only last_hidden_state. Had the patcher fallen back to nullcontext,
        # the trace-inferred names (last_hidden_state, pooler_output) would
        # show up here instead.
        assert _output_names(exported) == ["last_hidden_state"], (
            f"expected outputs ['last_hidden_state'], got {_output_names(exported)} "
            "— Optimum patcher likely fell back to nullcontext because the "
            "task wasn't normalised before TasksManager lookup."
        )


# ===========================================================================
# Required-option failures
# ===========================================================================
Expand Down
28 changes: 28 additions & 0 deletions tests/e2e/test_inspect_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,31 @@ def test_auto_detect_object_detection(self):
assert data["model_id"] == self.MODEL
assert data["model_type"] == "detr"
assert data["task"] == "object-detection"


@pytest.mark.network
class TestInspectDinoV2:
    """Inspect checks for ``MODEL`` when the task is passed explicitly."""

    MODEL = "facebook/dinov2-base"

    def test_image_feature_extraction_override(self):
        """HF synonym 'image-feature-extraction' must resolve via TasksManager."""
        requested_task = "image-feature-extraction"
        report = _run_network(self.MODEL, task=requested_task)

        _assert_common_structure(report, self.MODEL, requested_task)
        assert report["model_type"] == "dinov2"

        exporter_info = report["exporter"]
        assert exporter_info["onnx_config_class"] == "Dinov2OnnxConfig", (
            f"expected Dinov2OnnxConfig, got {exporter_info['onnx_config_class']!r} "
            "— task likely wasn't normalised before TasksManager lookup."
        )
        assert exporter_info["onnx_config_source"] == "TasksManager"
        assert exporter_info["support_level"] != "unsupported"

    def test_feature_extraction_override(self):
        """'feature-extraction' (the Optimum task) must also resolve (control)."""
        requested_task = "feature-extraction"
        report = _run_network(self.MODEL, task=requested_task)

        _assert_common_structure(report, self.MODEL, requested_task)
        assert report["model_type"] == "dinov2"

        exporter_info = report["exporter"]
        assert exporter_info["onnx_config_class"] == "Dinov2OnnxConfig"
        assert exporter_info["onnx_config_source"] == "TasksManager"
31 changes: 31 additions & 0 deletions tests/e2e/test_quantize_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def onnx_imgseg() -> Path:
)


@pytest.fixture(scope="session")
def onnx_dinov2() -> Path:
    """Session-scoped fixture for ``facebook/dinov2-small``.

    Delegates to ``_export_hf_to_onnx`` with the ``image-feature-extraction``
    task and returns its result.
    """
    model_id = "facebook/dinov2-small"
    export_task = "image-feature-extraction"
    # Third positional argument of the helper; presumably a name for the
    # exported artifact -- confirm against _export_hf_to_onnx.
    label = "dinov2_small"
    return _export_hf_to_onnx(model_id, export_task, label)


# ---------------------------------------------------------------------------
# Standard assertions
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -564,6 +571,30 @@ def test_unsupported_task_falls_back_to_random_dataset(
f"fallback warning not emitted in CLI output:\n{r.output}"
)

@pytest.mark.network
def test_feature_extraction_with_pixel_values_uses_image_dataset(
    self, runner: CliRunner, onnx_dinov2: Path, tmp_path: Path
):
    """A vision model quantized with ``--task feature-extraction`` calibrates on ImageDataset."""
    # 'feature-extraction' is bimodal: for a vision model the dataset is
    # chosen from the model's ONNX inputs (pixel_values). The task label in
    # the log stays 'feature-extraction' because only the dataset class is
    # swapped by the resolver.
    quantized = tmp_path / "d7.onnx"
    cli_args = [
        "-m", str(onnx_dinov2),
        "-o", str(quantized),
        "--task", "feature-extraction",
        "--model-name", "facebook/dinov2-small",
        "--samples", "4",
        "-v",
    ]

    result = _invoke(runner, cli_args)

    _assert_quantized_output(
        input_onnx=onnx_dinov2,
        output_onnx=quantized,
        stdout=result.output,
    )
    expected_log = "Creating feature-extraction dataset with ImageDataset"
    assert expected_log in result.output, result.output


# ===========================================================================
# Output behavior
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/datasets/test_random_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,18 @@ class TestTaskDatasetMapping:
"""Verify all supported tasks map to correct dataset classes."""

def test_all_tasks_have_mappings(self) -> None:
"""Every task in TASK_DATASET_MAPPING maps to a callable dataset class."""
"""Every task maps to either a dataset class or an input-name dispatch dict."""
from winml.modelkit.datasets import TASK_DATASET_MAPPING

for task, cls in TASK_DATASET_MAPPING.items():
assert callable(cls), f"Task {task!r} maps to non-callable {cls}"
for task, entry in TASK_DATASET_MAPPING.items():
if isinstance(entry, dict):
assert entry, f"Task {task!r} maps to empty dict"
for input_name, cls in entry.items():
assert callable(cls), (
f"Task {task!r}[{input_name!r}] maps to non-callable {cls}"
)
else:
assert callable(entry), f"Task {task!r} maps to non-callable {entry}"

@pytest.mark.parametrize(
("task", "module_path", "class_name"),
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/eval/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ def test_infer_from_model_id(self):
):
assert _resolve_task(config) == "image-classification"

def test_feature_extraction_mapped_to_hf_image_feature_extraction_for_vision_model(self):
    """``--task feature-extraction`` on a vision model resolves to the HF
    pipeline task ``image-feature-extraction``, so the evaluator registry
    lookup succeeds."""
    from winml.modelkit.eval.evaluate import _resolve_task

    # Stand-in HF config reporting a vision model type.
    hf_config_stub = MagicMock()
    hf_config_stub.model_type = "dinov2"

    # Stand-in ONNX config whose only input is the image tensor.
    onnx_config_stub = MagicMock()
    onnx_config_stub.inputs = {"pixel_values": object()}

    eval_config = WinMLEvaluationConfig(
        model_id="facebook/dinov2-base", task="feature-extraction"
    )

    auto_config_patch = patch(
        "transformers.AutoConfig.from_pretrained",
        return_value=hf_config_stub,
    )
    onnx_config_patch = patch(
        "winml.modelkit.export.io._get_onnx_config",
        return_value=onnx_config_stub,
    )
    with auto_config_patch, onnx_config_patch:
        assert _resolve_task(eval_config) == "image-feature-extraction"


class TestGetEvaluatorClass:
"""Tests for get_evaluator_class registry lookup."""
Expand Down
Loading
Loading