Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/winml/modelkit/export/htp/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,13 @@ def export(
monitor.update(ExportStep.INPUT_GEN, **input_gen_data)

# Step 3: Hierarchy Building
self._trace_model_hierarchy(model, inputs)
# Trace under the Optimum patcher so models that inject constant
# forward arguments at export time (e.g. ViTPose MoE's dataset_index)
# are traced with the same inputs they are exported with. The export
# in Step 4 re-enters the patcher; the contexts are sequential, not
# nested.
with self._get_optimum_patcher(model, task):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change looks good and low-risk overall. The model_kwargs={} fix is effectively a no-op for non-MoE patchers (Optimum's ModelPatcher.__init__ already coerces None → {}), so that one's safe.

My one concern is wrapping Step 3 (_trace_model_hierarchy) in the patcher. Models that already resolve a real Optimum patcher today — e.g. CLIP (clip_text_model / clip_vision_model), SAM, T5, SigLIP, whisper, VED — will now have their hierarchy trace run through patched forward for the first time. The ONNX graph is unaffected (export already ran patched), but the traced module path can shift, which could change the hierarchy / tag coverage on those models.

You verified the 6 ViTPose models, but those weren't being traced-under-patch before. Could you also run a before/after on at least one already-patched non-ViTPose model (CLIP is a good pick) and confirm the tag coverage / hierarchy stats are unchanged?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran CLIP before and after the change — it goes through a real Optimum patcher (CLIPTextModelWithProjection), so it's a fair test. The hierarchy and tagging come out the same both ways: 52 traced modules, 153 hierarchy modules, 606 nodes, 606 tagged, 100% coverage. Only the build time differs, so tracing Step 3 under the patcher doesn't change tag coverage for the already-patched models.

self._trace_model_hierarchy(model, inputs)

execution_steps = (
self._hierarchy_builder.get_execution_summary().get("execution_steps", 0)
Expand Down Expand Up @@ -487,7 +493,11 @@ def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any:
task=to_optimum_task(task),
library_name="transformers",
)
return cfg_cls(model_config).patch_model_for_export(model)
# Pass an explicit empty model_kwargs so patchers that inject extra
# forward arguments can populate it. Some patchers (e.g. ViTPose MoE,
# which sets a constant dataset_index) assume a mutable dict and crash
# on the None default from patch_model_for_export.
return cfg_cls(model_config).patch_model_for_export(model, model_kwargs={})
except KeyError:
logger.debug(
"Model type '%s' (task='%s') not in Optimum registry; "
Expand Down
2 changes: 2 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
VisionDecoderIOConfig as _VisionDecoderIOConfig, # triggers registration
)
from .vision_encoder_decoder import VisionEncoderIOConfig as _VisionEncoderIOConfig
from .vitpose import MODEL_CLASS_MAPPING as _VITPOSE_CLASS_MAPPING
from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration


Expand All @@ -97,6 +98,7 @@
**_SIGLIP_CLASS_MAPPING,
**_T5_CLASS_MAPPING,
**_VED_CLASS_MAPPING,
**_VITPOSE_CLASS_MAPPING,
}

# Registry: model_type -> WinMLBuildConfig
Expand Down
48 changes: 48 additions & 0 deletions src/winml/modelkit/models/hf/vitpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""ViTPose HuggingFace Model Configuration.

ViTPose is a top-down human pose (keypoint-detection) model: a plain ViT
backbone with a lightweight decoder that regresses keypoint heatmaps inside a
given person box.

This module provides:
- MODEL_CLASS_MAPPING: routes keypoint-detection to VitPoseForPoseEstimation,
and declares it the default task via a (vitpose, None) sentinel.

Why ViTPose needs class mapping:
Optimum already registers the ONNX export config (VitPoseOnnxConfig) for the
"vitpose" model type, so export works once the model is loaded. However,
Optimum's TasksManager has no task-to-class entry for "keypoint-detection",
and transformers' AutoModelForKeypointDetection only recognizes SuperPoint —
not ViTPose. Without this mapping the resolver cannot load the model class for
the keypoint-detection task. The "plus" checkpoints (MoE backbone) load through
the same class; their expert index is fixed at export time by Optimum's
VitPoseModelPatcher, so no extra input is needed.

Why the (vitpose, None) sentinel:
TasksManager cannot infer a task from the ViTPose architecture, so without a
declared default the resolver falls back to an unrelated task and config/build
fail unless the user passes --task keypoint-detection. The sentinel encodes
keypoint-detection as the canonical default (the resolver reverse-looks-up the
task sharing the sentinel's class), making --task optional. Mirrors SAM, which
declares mask-generation the same way.
"""

from __future__ import annotations

from transformers import VitPoseForPoseEstimation


# (model_type, task) -> HuggingFace model class
#
# The (vitpose, None) sentinel declares keypoint-detection as the default task
# applied during auto-detection (when the user does not pass --task). Its value
# is the default *class*; the resolver reverse-looks-up the task name from the
# matching (vitpose, keypoint-detection) -> same class entry.
MODEL_CLASS_MAPPING: dict[tuple[str, str | None], type] = {
("vitpose", "keypoint-detection"): VitPoseForPoseEstimation,

@vortex-captain vortex-captain Jun 24, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
("vitpose", "keypoint-detection"): VitPoseForPoseEstimation,
("vitpose", "keypoint-detection"): VitPoseForPoseEstimation,
("vitpose", None): VitPoseForPoseEstimation,

Could you test whether adding this line makes the command work when the task is omitted?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the ("vitpose", None) sentinel and it works. It makes the resolver reverse-look-up keypoint-detection from the matching class entry, so config and build pick the task up without --task now. I ran winml config -m usyd-community/vitpose-base-simple with no task flag and it resolves keypoint-detection on its own, and added unit tests for it. Same approach SAM uses for mask-generation.

("vitpose", None): VitPoseForPoseEstimation,
}
73 changes: 73 additions & 0 deletions tests/unit/export/test_htp_exporter_patcher_model_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Regression tests for `HTPExporter._get_optimum_patcher` model_kwargs handling.
Some Optimum model patchers populate a mutable ``model_kwargs`` dict to inject
constant forward arguments at export time. ViTPose's MoE patcher, for example,
sets ``model_kwargs["dataset_index"]`` when ``num_experts > 1``. Optimum's
``patch_model_for_export`` defaults ``model_kwargs`` to ``None``, so such
patchers crash with ``TypeError: 'NoneType' object does not support item
assignment`` unless the caller passes an explicit dict.
This test pins the contract that ``_get_optimum_patcher`` passes an explicit
``model_kwargs={}`` so those patchers can populate it.
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import torch.nn as nn

from winml.modelkit.export.htp import HTPExporter


class _FakeConfig:
"""Minimal HF-style config exposing the model_type the patcher checks."""

model_type = "vitpose"


class _FakeModel(nn.Module):
    """Parameter-free ``nn.Module`` that carries only a ``config`` attribute.

    The module defines no layers and no ``forward``; it exists so that code
    reading ``model.config.model_type`` sees ``"vitpose"``.
    """

    def __init__(self) -> None:
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()
        self.config = _FakeConfig()


class TestGetOptimumPatcherModelKwargs:
    """_get_optimum_patcher must pass an explicit mutable model_kwargs dict."""

    def test_patch_model_for_export_receives_explicit_dict(self) -> None:
        """The patcher call must pass ``model_kwargs={}`` (not the None default).

        We patch the TasksManager lookup to return a fake config constructor
        whose ``patch_model_for_export`` records the ``model_kwargs`` it
        receives. A non-None dict lets MoE patchers populate forward arguments
        without crashing.
        """
        captured: dict[str, object] = {}

        fake_onnx_config = MagicMock()

        def record_patch(model, model_kwargs=None):
            # Defaulting to None here means an omitted argument is recorded as
            # None rather than being silently replaced by a dict.
            captured["model_kwargs"] = model_kwargs
            return MagicMock()

        fake_onnx_config.patch_model_for_export.side_effect = record_patch

        def fake_ctor(*args: object, **kwargs: object):
            # Stand-in for the config constructor: ignores its arguments and
            # always hands back the recording fake.
            return fake_onnx_config

        with patch(
            "optimum.exporters.tasks.TasksManager.get_exporter_config_constructor",
            return_value=fake_ctor,
        ):
            HTPExporter._get_optimum_patcher(_FakeModel(), task="keypoint-detection")

        # Separate "never called" from "called with the wrong value" so a
        # failure message points at the real cause.
        assert "model_kwargs" in captured, (
            "Expected _get_optimum_patcher to call patch_model_for_export, "
            "but the fake patcher was never invoked."
        )
        assert captured.get("model_kwargs") == {}, (
            "Expected _get_optimum_patcher to pass an explicit model_kwargs={} "
            f"to patch_model_for_export, got {captured.get('model_kwargs')!r}. "
            "MoE patchers (e.g. ViTPose dataset_index) need a mutable dict."
        )
69 changes: 69 additions & 0 deletions tests/unit/models/test_vitpose_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Tests for ViTPose keypoint-detection model-class resolution.

Optimum registers the ViTPose ONNX export config but has no
task-to-class entry for ``keypoint-detection``, and transformers'
``AutoModelForKeypointDetection`` only recognises SuperPoint. The
``("vitpose", "keypoint-detection")`` entry in ``MODEL_CLASS_MAPPING``
bridges that gap so the resolver can load ``VitPoseForPoseEstimation``.
"""

from __future__ import annotations

from unittest.mock import MagicMock

from winml.modelkit.loader import resolve_task
from winml.modelkit.models.hf import MODEL_CLASS_MAPPING
from winml.modelkit.models.hf.vitpose import MODEL_CLASS_MAPPING as VITPOSE_MAPPING


class TestVitPoseMapping:
    """ViTPose keypoint-detection routes to VitPoseForPoseEstimation."""

    @staticmethod
    def _vitpose_config() -> MagicMock:
        """Build a stand-in HF config describing a ViTPose checkpoint.

        Shared by the resolver tests so both the explicit-task and the
        default-task paths are exercised against identical input.
        """
        config = MagicMock()
        config.model_type = "vitpose"
        config.architectures = ["VitPoseForPoseEstimation"]
        config._name_or_path = "usyd-community/vitpose-base-simple"
        return config

    def test_mapping_entry_registered(self):
        """The aggregated mapping exposes the vitpose keypoint-detection entry."""
        assert ("vitpose", "keypoint-detection") in MODEL_CLASS_MAPPING
        assert (
            MODEL_CLASS_MAPPING[("vitpose", "keypoint-detection")].__name__
            == "VitPoseForPoseEstimation"
        )

    def test_module_mapping_merged_into_aggregate(self):
        """The module-level mapping is included in the aggregated mapping."""
        # dict item views support subset comparison with ``<=``.
        assert VITPOSE_MAPPING.items() <= MODEL_CLASS_MAPPING.items()

    def test_explicit_task_resolves_vitpose_class(self):
        """An explicit keypoint-detection task resolves VitPoseForPoseEstimation."""
        resolution = resolve_task(self._vitpose_config(), task="keypoint-detection")

        assert resolution.task == "keypoint-detection"
        assert resolution.model_class.__name__ == "VitPoseForPoseEstimation"

    def test_sentinel_resolves_default_task_without_explicit_task(self):
        """With no --task, the (vitpose, None) sentinel defaults to keypoint-detection."""
        resolution = resolve_task(self._vitpose_config())

        assert resolution.task == "keypoint-detection"
        assert resolution.model_class.__name__ == "VitPoseForPoseEstimation"

    def test_sentinel_registered_in_mapping(self):
        """The (vitpose, None) sentinel shares the keypoint-detection class."""
        assert ("vitpose", None) in MODEL_CLASS_MAPPING
        assert (
            MODEL_CLASS_MAPPING[("vitpose", None)]
            is MODEL_CLASS_MAPPING[("vitpose", "keypoint-detection")]
        )

Loading