diff --git a/src/winml/modelkit/export/htp/exporter.py b/src/winml/modelkit/export/htp/exporter.py index 91abd5828..1fde10a4a 100644 --- a/src/winml/modelkit/export/htp/exporter.py +++ b/src/winml/modelkit/export/htp/exporter.py @@ -229,7 +229,13 @@ def export( monitor.update(ExportStep.INPUT_GEN, **input_gen_data) # Step 3: Hierarchy Building - self._trace_model_hierarchy(model, inputs) + # Trace under the Optimum patcher so models that inject constant + # forward arguments at export time (e.g. ViTPose MoE's dataset_index) + # are traced with the same inputs they are exported with. The export + # in Step 4 re-enters the patcher; the contexts are sequential, not + # nested. + with self._get_optimum_patcher(model, task): + self._trace_model_hierarchy(model, inputs) execution_steps = ( self._hierarchy_builder.get_execution_summary().get("execution_steps", 0) @@ -487,7 +493,11 @@ def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any: task=to_optimum_task(task), library_name="transformers", ) - return cfg_cls(model_config).patch_model_for_export(model) + # Pass an explicit empty model_kwargs so patchers that inject extra + # forward arguments can populate it. Some patchers (e.g. ViTPose MoE, + # which sets a constant dataset_index) assume a mutable dict and crash + # on the None default from patch_model_for_export. + return cfg_cls(model_config).patch_model_for_export(model, model_kwargs={}) except KeyError: logger.debug( "Model type '%s' (task='%s') not in Optimum registry; " diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index c6f4c9520..8b8c9cdf6 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -75,6 +75,7 @@ VisionDecoderIOConfig as _VisionDecoderIOConfig, # triggers registration ) from .vision_encoder_decoder import VisionEncoderIOConfig as _VisionEncoderIOConfig +from .vitpose import MODEL_CLASS_MAPPING as _VITPOSE_CLASS_MAPPING from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration @@ -97,6 +98,7 @@ **_SIGLIP_CLASS_MAPPING, **_T5_CLASS_MAPPING, **_VED_CLASS_MAPPING, + **_VITPOSE_CLASS_MAPPING, } # Registry: model_type -> WinMLBuildConfig diff --git a/src/winml/modelkit/models/hf/vitpose.py b/src/winml/modelkit/models/hf/vitpose.py new file mode 100644 index 000000000..7b25ded6c --- /dev/null +++ b/src/winml/modelkit/models/hf/vitpose.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""ViTPose HuggingFace Model Configuration. + +ViTPose is a top-down human pose (keypoint-detection) model: a plain ViT +backbone with a lightweight decoder that regresses keypoint heatmaps inside a +given person box. + +This module provides: +- MODEL_CLASS_MAPPING: routes keypoint-detection to VitPoseForPoseEstimation. + +Why ViTPose needs class mapping: +Optimum already registers the ONNX export config (VitPoseOnnxConfig) for the +"vitpose" model type, so export works once the model is loaded. However, +Optimum's TasksManager has no task-to-class entry for "keypoint-detection", +and transformers' AutoModelForKeypointDetection only recognizes SuperPoint — +not ViTPose. Without this mapping the resolver cannot load the model class for +the keypoint-detection task. The "plus" checkpoints (MoE backbone) load through +the same class; their expert index is fixed at export time by Optimum's +VitPoseModelPatcher, so no extra input is needed. +""" + +from __future__ import annotations + +from transformers import VitPoseForPoseEstimation + + +# (model_type, task) -> HuggingFace model class +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("vitpose", "keypoint-detection"): VitPoseForPoseEstimation, +} diff --git a/tests/unit/export/test_htp_exporter_patcher_model_kwargs.py b/tests/unit/export/test_htp_exporter_patcher_model_kwargs.py new file mode 100644 index 000000000..dd3ab8156 --- /dev/null +++ b/tests/unit/export/test_htp_exporter_patcher_model_kwargs.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Regression tests for `HTPExporter._get_optimum_patcher` model_kwargs handling. + +Some Optimum model patchers populate a mutable ``model_kwargs`` dict to inject +constant forward arguments at export time. ViTPose's MoE patcher, for example, +sets ``model_kwargs["dataset_index"]`` when ``num_experts > 1``. Optimum's +``patch_model_for_export`` defaults ``model_kwargs`` to ``None``, so such +patchers crash with ``TypeError: 'NoneType' object does not support item +assignment`` unless the caller passes an explicit dict. + +This test pins the contract that ``_get_optimum_patcher`` passes an explicit +``model_kwargs={}`` so those patchers can populate it. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import torch.nn as nn + +from winml.modelkit.export.htp import HTPExporter + + +class _FakeConfig: + """Minimal HF-style config exposing the model_type the patcher checks.""" + + model_type = "vitpose" + + +class _FakeModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.config = _FakeConfig() + + +class TestGetOptimumPatcherModelKwargs: + """_get_optimum_patcher must pass an explicit mutable model_kwargs dict.""" + + def test_patch_model_for_export_receives_explicit_dict(self) -> None: + """The patcher call must pass ``model_kwargs={}`` (not the None default). + + We patch the TasksManager lookup to return a fake config constructor + whose ``patch_model_for_export`` records the ``model_kwargs`` it + receives. A non-None dict lets MoE patchers populate forward arguments + without crashing. + """ + captured: dict[str, object] = {} + + fake_onnx_config = MagicMock() + + def record_patch(model, model_kwargs=None): + captured["model_kwargs"] = model_kwargs + return MagicMock() + + fake_onnx_config.patch_model_for_export.side_effect = record_patch + + def fake_ctor(*args: object, **kwargs: object): + return fake_onnx_config + + with patch( + "optimum.exporters.tasks.TasksManager.get_exporter_config_constructor", + return_value=fake_ctor, + ): + HTPExporter._get_optimum_patcher(_FakeModel(), task="keypoint-detection") + + assert captured.get("model_kwargs") == {}, ( + "Expected _get_optimum_patcher to pass an explicit model_kwargs={} " + f"to patch_model_for_export, got {captured.get('model_kwargs')!r}. " + "MoE patchers (e.g. ViTPose dataset_index) need a mutable dict." + ) diff --git a/tests/unit/models/test_vitpose_mapping.py b/tests/unit/models/test_vitpose_mapping.py new file mode 100644 index 000000000..3e70ecd00 --- /dev/null +++ b/tests/unit/models/test_vitpose_mapping.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for ViTPose keypoint-detection model-class resolution. + +Optimum registers the ViTPose ONNX export config but has no +task-to-class entry for ``keypoint-detection``, and transformers' +``AutoModelForKeypointDetection`` only recognises SuperPoint. The +``("vitpose", "keypoint-detection")`` entry in ``MODEL_CLASS_MAPPING`` +bridges that gap so the resolver can load ``VitPoseForPoseEstimation``. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from winml.modelkit.loader import resolve_task +from winml.modelkit.models.hf import MODEL_CLASS_MAPPING +from winml.modelkit.models.hf.vitpose import MODEL_CLASS_MAPPING as VITPOSE_MAPPING + + +class TestVitPoseMapping: + """ViTPose keypoint-detection routes to VitPoseForPoseEstimation.""" + + def test_mapping_entry_registered(self): + """The aggregated mapping exposes the vitpose keypoint-detection entry.""" + assert ("vitpose", "keypoint-detection") in MODEL_CLASS_MAPPING + assert ( + MODEL_CLASS_MAPPING[("vitpose", "keypoint-detection")].__name__ + == "VitPoseForPoseEstimation" + ) + + def test_module_mapping_merged_into_aggregate(self): + """The module-level mapping is included in the aggregated mapping.""" + assert VITPOSE_MAPPING.items() <= MODEL_CLASS_MAPPING.items() + + def test_explicit_task_resolves_vitpose_class(self): + """An explicit keypoint-detection task resolves VitPoseForPoseEstimation.""" + config = MagicMock() + config.model_type = "vitpose" + config.architectures = ["VitPoseForPoseEstimation"] + config._name_or_path = "usyd-community/vitpose-base-simple" + + resolution = resolve_task(config, task="keypoint-detection") + + assert resolution.task == "keypoint-detection" + assert resolution.model_class.__name__ == "VitPoseForPoseEstimation" +