Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/winml/modelkit/export/htp/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,13 @@ def export(
monitor.update(ExportStep.INPUT_GEN, **input_gen_data)

# Step 3: Hierarchy Building
self._trace_model_hierarchy(model, inputs)
# Trace under the Optimum patcher so models that inject constant
# forward arguments at export time (e.g. ViTPose MoE's dataset_index)
# are traced with the same inputs they are exported with. The export
# in Step 4 re-enters the patcher; the contexts are sequential, not
# nested.
with self._get_optimum_patcher(model, task):
self._trace_model_hierarchy(model, inputs)

execution_steps = (
self._hierarchy_builder.get_execution_summary().get("execution_steps", 0)
Expand Down Expand Up @@ -487,7 +493,11 @@ def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any:
task=to_optimum_task(task),
library_name="transformers",
)
return cfg_cls(model_config).patch_model_for_export(model)
# Pass an explicit empty model_kwargs so patchers that inject extra
# forward arguments can populate it. Some patchers (e.g. ViTPose MoE,
# which sets a constant dataset_index) assume a mutable dict and crash
# on the None default from patch_model_for_export.
return cfg_cls(model_config).patch_model_for_export(model, model_kwargs={})
except KeyError:
logger.debug(
"Model type '%s' (task='%s') not in Optimum registry; "
Expand Down
2 changes: 2 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
VisionDecoderIOConfig as _VisionDecoderIOConfig, # triggers registration
)
from .vision_encoder_decoder import VisionEncoderIOConfig as _VisionEncoderIOConfig
from .vitpose import MODEL_CLASS_MAPPING as _VITPOSE_CLASS_MAPPING
from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration


Expand All @@ -97,6 +98,7 @@
**_SIGLIP_CLASS_MAPPING,
**_T5_CLASS_MAPPING,
**_VED_CLASS_MAPPING,
**_VITPOSE_CLASS_MAPPING,
}

# Registry: model_type -> WinMLBuildConfig
Expand Down
33 changes: 33 additions & 0 deletions src/winml/modelkit/models/hf/vitpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""ViTPose HuggingFace Model Configuration.

ViTPose is a top-down human pose (keypoint-detection) model: a plain ViT
backbone with a lightweight decoder that regresses keypoint heatmaps inside a
given person box.

This module provides:
- MODEL_CLASS_MAPPING: routes keypoint-detection to VitPoseForPoseEstimation.

Why ViTPose needs class mapping:
Optimum already registers the ONNX export config (VitPoseOnnxConfig) for the
"vitpose" model type, so export works once the model is loaded. However,
Optimum's TasksManager has no task-to-class entry for "keypoint-detection",
and transformers' AutoModelForKeypointDetection only recognizes SuperPoint —
not ViTPose. Without this mapping the resolver cannot load the model class for
the keypoint-detection task. The "plus" checkpoints (MoE backbone) load through
the same class; their expert index is fixed at export time by Optimum's
VitPoseModelPatcher, so no extra input is needed.
"""

from __future__ import annotations

from transformers import VitPoseForPoseEstimation


# (model_type, task) -> HuggingFace model class
MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
("vitpose", "keypoint-detection"): VitPoseForPoseEstimation,
}
73 changes: 73 additions & 0 deletions tests/unit/export/test_htp_exporter_patcher_model_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Regression tests for `HTPExporter._get_optimum_patcher` model_kwargs handling.

Some Optimum model patchers populate a mutable ``model_kwargs`` dict to inject
constant forward arguments at export time. ViTPose's MoE patcher, for example,
sets ``model_kwargs["dataset_index"]`` when ``num_experts > 1``. Optimum's
``patch_model_for_export`` defaults ``model_kwargs`` to ``None``, so such
patchers crash with ``TypeError: 'NoneType' object does not support item
assignment`` unless the caller passes an explicit dict.

This test pins the contract that ``_get_optimum_patcher`` passes an explicit
``model_kwargs={}`` so those patchers can populate it.
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import torch.nn as nn

from winml.modelkit.export.htp import HTPExporter


class _FakeConfig:
"""Minimal HF-style config exposing the model_type the patcher checks."""

model_type = "vitpose"


class _FakeModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.config = _FakeConfig()


class TestGetOptimumPatcherModelKwargs:
"""_get_optimum_patcher must pass an explicit mutable model_kwargs dict."""

def test_patch_model_for_export_receives_explicit_dict(self) -> None:
"""The patcher call must pass ``model_kwargs={}`` (not the None default).

We patch the TasksManager lookup to return a fake config constructor
whose ``patch_model_for_export`` records the ``model_kwargs`` it
receives. A non-None dict lets MoE patchers populate forward arguments
without crashing.
"""
captured: dict[str, object] = {}

fake_onnx_config = MagicMock()

def record_patch(model, model_kwargs=None):
captured["model_kwargs"] = model_kwargs
return MagicMock()

fake_onnx_config.patch_model_for_export.side_effect = record_patch

def fake_ctor(*args: object, **kwargs: object):
return fake_onnx_config

with patch(
"optimum.exporters.tasks.TasksManager.get_exporter_config_constructor",
return_value=fake_ctor,
):
HTPExporter._get_optimum_patcher(_FakeModel(), task="keypoint-detection")

assert captured.get("model_kwargs") == {}, (
"Expected _get_optimum_patcher to pass an explicit model_kwargs={} "
f"to patch_model_for_export, got {captured.get('model_kwargs')!r}. "
"MoE patchers (e.g. ViTPose dataset_index) need a mutable dict."
)
49 changes: 49 additions & 0 deletions tests/unit/models/test_vitpose_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Tests for ViTPose keypoint-detection model-class resolution.

Optimum registers the ViTPose ONNX export config but has no
task-to-class entry for ``keypoint-detection``, and transformers'
``AutoModelForKeypointDetection`` only recognises SuperPoint. The
``("vitpose", "keypoint-detection")`` entry in ``MODEL_CLASS_MAPPING``
bridges that gap so the resolver can load ``VitPoseForPoseEstimation``.
"""

from __future__ import annotations

from unittest.mock import MagicMock

from winml.modelkit.loader import resolve_task
from winml.modelkit.models.hf import MODEL_CLASS_MAPPING
from winml.modelkit.models.hf.vitpose import MODEL_CLASS_MAPPING as VITPOSE_MAPPING


class TestVitPoseMapping:
"""ViTPose keypoint-detection routes to VitPoseForPoseEstimation."""

def test_mapping_entry_registered(self):
"""The aggregated mapping exposes the vitpose keypoint-detection entry."""
assert ("vitpose", "keypoint-detection") in MODEL_CLASS_MAPPING
assert (
MODEL_CLASS_MAPPING[("vitpose", "keypoint-detection")].__name__
== "VitPoseForPoseEstimation"
)

def test_module_mapping_merged_into_aggregate(self):
"""The module-level mapping is included in the aggregated mapping."""
assert VITPOSE_MAPPING.items() <= MODEL_CLASS_MAPPING.items()

def test_explicit_task_resolves_vitpose_class(self):
"""An explicit keypoint-detection task resolves VitPoseForPoseEstimation."""
config = MagicMock()
config.model_type = "vitpose"
config.architectures = ["VitPoseForPoseEstimation"]
config._name_or_path = "usyd-community/vitpose-base-simple"

resolution = resolve_task(config, task="keypoint-detection")

assert resolution.task == "keypoint-detection"
assert resolution.model_class.__name__ == "VitPoseForPoseEstimation"

Loading