-
Notifications
You must be signed in to change notification settings - Fork 4
feat(keypoint-detection): enable ViTPose config/build/perf #905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,48 @@ | ||||||||
| # ------------------------------------------------------------------------- | ||||||||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||||||||
| # Licensed under the MIT License. | ||||||||
| # -------------------------------------------------------------------------- | ||||||||
| """ViTPose HuggingFace Model Configuration. | ||||||||
|
|
||||||||
| ViTPose is a top-down human pose (keypoint-detection) model: a plain ViT | ||||||||
| backbone with a lightweight decoder that regresses keypoint heatmaps inside a | ||||||||
| given person box. | ||||||||
|
|
||||||||
| This module provides: | ||||||||
| - MODEL_CLASS_MAPPING: routes keypoint-detection to VitPoseForPoseEstimation, | ||||||||
| and declares it the default task via a (vitpose, None) sentinel. | ||||||||
|
|
||||||||
| Why ViTPose needs class mapping: | ||||||||
| Optimum already registers the ONNX export config (VitPoseOnnxConfig) for the | ||||||||
| "vitpose" model type, so export works once the model is loaded. However, | ||||||||
| Optimum's TasksManager has no task-to-class entry for "keypoint-detection", | ||||||||
| and transformers' AutoModelForKeypointDetection only recognizes SuperPoint — | ||||||||
| not ViTPose. Without this mapping the resolver cannot load the model class for | ||||||||
| the keypoint-detection task. The "plus" checkpoints (MoE backbone) load through | ||||||||
| the same class; their expert index is fixed at export time by Optimum's | ||||||||
| VitPoseModelPatcher, so no extra input is needed. | ||||||||
|
|
||||||||
| Why the (vitpose, None) sentinel: | ||||||||
| TasksManager cannot infer a task from the ViTPose architecture, so without a | ||||||||
| declared default the resolver falls back to an unrelated task and config/build | ||||||||
| fail unless the user passes --task keypoint-detection. The sentinel encodes | ||||||||
| keypoint-detection as the canonical default (the resolver reverse-looks-up the | ||||||||
| task sharing the sentinel's class), making --task optional. Mirrors SAM, which | ||||||||
| declares mask-generation the same way. | ||||||||
| """ | ||||||||
|
|
||||||||
| from __future__ import annotations | ||||||||
|
|
||||||||
| from transformers import VitPoseForPoseEstimation | ||||||||
|
|
||||||||
|
|
||||||||
| # (model_type, task) -> HuggingFace model class | ||||||||
| # | ||||||||
| # The (vitpose, None) sentinel declares keypoint-detection as the default task | ||||||||
| # applied during auto-detection (when the user does not pass --task). Its value | ||||||||
| # is the default *class*; the resolver reverse-looks-up the task name from the | ||||||||
| # matching (vitpose, keypoint-detection) -> same class entry. | ||||||||
| MODEL_CLASS_MAPPING: dict[tuple[str, str | None], type] = { | ||||||||
| ("vitpose", "keypoint-detection"): VitPoseForPoseEstimation, | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Could you help try test to see this line makes command with task omitted work?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added the ("vitpose", None) sentinel and it works. It makes the resolver reverse-look-up keypoint-detection from the matching class entry, so config and build pick the task up without --task now. I ran |
||||||||
| ("vitpose", None): VitPoseForPoseEstimation, | ||||||||
| } | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| """Regression tests for `HTPExporter._get_optimum_patcher` model_kwargs handling. | ||
| Some Optimum model patchers populate a mutable ``model_kwargs`` dict to inject | ||
| constant forward arguments at export time. ViTPose's MoE patcher, for example, | ||
| sets ``model_kwargs["dataset_index"]`` when ``num_experts > 1``. Optimum's | ||
| ``patch_model_for_export`` defaults ``model_kwargs`` to ``None``, so such | ||
| patchers crash with ``TypeError: 'NoneType' object does not support item | ||
| assignment`` unless the caller passes an explicit dict. | ||
| This test pins the contract that ``_get_optimum_patcher`` passes an explicit | ||
| ``model_kwargs={}`` so those patchers can populate it. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import torch.nn as nn | ||
|
|
||
| from winml.modelkit.export.htp import HTPExporter | ||
|
|
||
|
|
||
| class _FakeConfig: | ||
| """Minimal HF-style config exposing the model_type the patcher checks.""" | ||
|
|
||
| model_type = "vitpose" | ||
|
|
||
|
|
||
| class _FakeModel(nn.Module): | ||
| def __init__(self) -> None: | ||
| super().__init__() | ||
| self.config = _FakeConfig() | ||
|
|
||
|
|
||
| class TestGetOptimumPatcherModelKwargs: | ||
| """_get_optimum_patcher must pass an explicit mutable model_kwargs dict.""" | ||
|
|
||
| def test_patch_model_for_export_receives_explicit_dict(self) -> None: | ||
| """The patcher call must pass ``model_kwargs={}`` (not the None default). | ||
| We patch the TasksManager lookup to return a fake config constructor | ||
| whose ``patch_model_for_export`` records the ``model_kwargs`` it | ||
| receives. A non-None dict lets MoE patchers populate forward arguments | ||
| without crashing. | ||
| """ | ||
| captured: dict[str, object] = {} | ||
|
|
||
| fake_onnx_config = MagicMock() | ||
|
|
||
| def record_patch(model, model_kwargs=None): | ||
| captured["model_kwargs"] = model_kwargs | ||
| return MagicMock() | ||
|
|
||
| fake_onnx_config.patch_model_for_export.side_effect = record_patch | ||
|
|
||
| def fake_ctor(*args: object, **kwargs: object): | ||
| return fake_onnx_config | ||
|
|
||
| with patch( | ||
| "optimum.exporters.tasks.TasksManager.get_exporter_config_constructor", | ||
| return_value=fake_ctor, | ||
| ): | ||
| HTPExporter._get_optimum_patcher(_FakeModel(), task="keypoint-detection") | ||
|
|
||
| assert captured.get("model_kwargs") == {}, ( | ||
| "Expected _get_optimum_patcher to pass an explicit model_kwargs={} " | ||
| f"to patch_model_for_export, got {captured.get('model_kwargs')!r}. " | ||
| "MoE patchers (e.g. ViTPose dataset_index) need a mutable dict." | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| """Tests for ViTPose keypoint-detection model-class resolution. | ||
|
|
||
| Optimum registers the ViTPose ONNX export config but has no | ||
| task-to-class entry for ``keypoint-detection``, and transformers' | ||
| ``AutoModelForKeypointDetection`` only recognises SuperPoint. The | ||
| ``("vitpose", "keypoint-detection")`` entry in ``MODEL_CLASS_MAPPING`` | ||
| bridges that gap so the resolver can load ``VitPoseForPoseEstimation``. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| from winml.modelkit.loader import resolve_task | ||
| from winml.modelkit.models.hf import MODEL_CLASS_MAPPING | ||
| from winml.modelkit.models.hf.vitpose import MODEL_CLASS_MAPPING as VITPOSE_MAPPING | ||
|
|
||
|
|
||
| class TestVitPoseMapping: | ||
| """ViTPose keypoint-detection routes to VitPoseForPoseEstimation.""" | ||
|
|
||
| def test_mapping_entry_registered(self): | ||
| """The aggregated mapping exposes the vitpose keypoint-detection entry.""" | ||
| assert ("vitpose", "keypoint-detection") in MODEL_CLASS_MAPPING | ||
| assert ( | ||
| MODEL_CLASS_MAPPING[("vitpose", "keypoint-detection")].__name__ | ||
| == "VitPoseForPoseEstimation" | ||
| ) | ||
|
|
||
| def test_module_mapping_merged_into_aggregate(self): | ||
| """The module-level mapping is included in the aggregated mapping.""" | ||
| assert VITPOSE_MAPPING.items() <= MODEL_CLASS_MAPPING.items() | ||
|
|
||
| def test_explicit_task_resolves_vitpose_class(self): | ||
| """An explicit keypoint-detection task resolves VitPoseForPoseEstimation.""" | ||
| config = MagicMock() | ||
| config.model_type = "vitpose" | ||
| config.architectures = ["VitPoseForPoseEstimation"] | ||
| config._name_or_path = "usyd-community/vitpose-base-simple" | ||
|
|
||
| resolution = resolve_task(config, task="keypoint-detection") | ||
|
|
||
| assert resolution.task == "keypoint-detection" | ||
| assert resolution.model_class.__name__ == "VitPoseForPoseEstimation" | ||
|
|
||
| def test_sentinel_resolves_default_task_without_explicit_task(self): | ||
| """With no --task, the (vitpose, None) sentinel defaults to keypoint-detection.""" | ||
| config = MagicMock() | ||
| config.model_type = "vitpose" | ||
| config.architectures = ["VitPoseForPoseEstimation"] | ||
| config._name_or_path = "usyd-community/vitpose-base-simple" | ||
|
|
||
| resolution = resolve_task(config) | ||
|
|
||
| assert resolution.task == "keypoint-detection" | ||
| assert resolution.model_class.__name__ == "VitPoseForPoseEstimation" | ||
|
|
||
| def test_sentinel_registered_in_mapping(self): | ||
| """The (vitpose, None) sentinel shares the keypoint-detection class.""" | ||
| assert ("vitpose", None) in MODEL_CLASS_MAPPING | ||
| assert ( | ||
| MODEL_CLASS_MAPPING[("vitpose", None)] | ||
| is MODEL_CLASS_MAPPING[("vitpose", "keypoint-detection")] | ||
| ) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change looks good and low-risk overall. The
model_kwargs={}fix is effectively a no-op for non-MoE patchers (Optimum'sModelPatcher.__init__already coercesNone → {}), so that one's safe.My one concern is wrapping Step 3 (
_trace_model_hierarchy) in the patcher. Models that already resolve a real Optimum patcher today — e.g. CLIP (clip_text_model/clip_vision_model), SAM, T5, SigLIP, whisper, VED — will now have their hierarchy trace run through patchedforwardfor the first time. The ONNX graph is unaffected (export already ran patched), but the traced module path can shift, which could change the hierarchy / tag coverage on those models.You verified the 6 ViTPose models, but those weren't being traced-under-patch before. Could you also run a before/after on at least one already-patched non-ViTPose model (CLIP is a good pick) and confirm the tag coverage / hierarchy stats are unchanged?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran CLIP before and after the change — it goes through a real Optimum patcher (CLIPTextModelWithProjection), so it's a fair test. The hierarchy and tagging come out the same both ways: 52 traced modules, 153 hierarchy modules, 606 nodes, 606 tagged, 100% coverage. Only the build time differs, so tracing Step 3 under the patcher doesn't change tag coverage for the already-patched models.