From d34df1c9ed30ac3c960dc48a3ab100c31faca041 Mon Sep 17 00:00:00 2001 From: Abhilash Chenreddy Date: Wed, 6 May 2026 15:27:21 +0000 Subject: [PATCH 1/4] feat: add SAM 3 support via pre-exported ONNX path (issue #324) SAM 3 (facebook/sam3) requires transformers>=5, but optimum-onnx pins transformers<4.58, so the standard HF + Optimum export route SAM 2 uses is blocked. This change wires SAM 3 in through the existing pre-exported ONNX (Scenario D) pipeline by recognizing path-style Hub references ('org/repo/path/to/file.onnx') and downloading the file once via huggingface_hub. Changes: - New src/winml/modelkit/loader/onnx_hub.py: is_hf_onnx_path, resolve_hf_onnx_path. Mirrors the is_xxx/resolve_xxx pair pattern used by is_compiled_onnx/is_quantized_onnx. - Wire into wmk config, wmk build, and WinMLAutoModel.from_pretrained with the same 2-line 'if is_hf_onnx_path(x): x = str(resolve_hf_onnx_path(x))' pattern. - Add 2 sam3_tracker entries to hub_models.json so 'wmk hub --model-type sam3_tracker' lists them. - Tests: 12 unit tests for the resolver, 2 CLI plumbing tests, and 3 end-to-end integration tests (slow/network/integration). The existing build_onnx_model pipeline runs unchanged on the resolved local path: the int8 ONNX is auto-detected as quantized via is_quantized_onnx, the quantization stage is skipped, and the artifact flows through Optimize -> Analyze<->Optimize -> Compile -> Finalize. 
--- src/winml/modelkit/commands/build.py | 8 ++ src/winml/modelkit/commands/config.py | 6 + src/winml/modelkit/data/hub_models.json | 10 ++ src/winml/modelkit/loader/__init__.py | 3 + src/winml/modelkit/loader/onnx_hub.py | 136 +++++++++++++++++++ src/winml/modelkit/models/auto.py | 7 + tests/integration/test_sam3_e2e.py | 121 +++++++++++++++++ tests/unit/commands/test_hub_onnx_ref.py | 165 +++++++++++++++++++++++ tests/unit/loader/test_onnx_hub.py | 143 ++++++++++++++++++++ 9 files changed, 599 insertions(+) create mode 100644 src/winml/modelkit/loader/onnx_hub.py create mode 100644 tests/integration/test_sam3_e2e.py create mode 100644 tests/unit/commands/test_hub_onnx_ref.py create mode 100644 tests/unit/loader/test_onnx_hub.py diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 481ed7289..a8d2b6d18 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -576,6 +576,14 @@ def build( logger.info("Auto-resolved device=%s, EP=%s", resolved_device, ep) try: + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx file thereafter. + if model_id is not None: + from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + + if is_hf_onnx_path(model_id): + model_id = str(resolve_hf_onnx_path(model_id)) + # Load or auto-generate config if config_file is not None: config_or_configs = _load_config( diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index d87fbd531..e9bd77dcb 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -227,6 +227,12 @@ def config( generate_hf_build_config, generate_onnx_build_config, ) + from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx file thereafter. 
+ if hf_model is not None and is_hf_onnx_path(hf_model): + hf_model = str(resolve_hf_onnx_path(hf_model)) # Load override config from JSON file if provided override = None diff --git a/src/winml/modelkit/data/hub_models.json b/src/winml/modelkit/data/hub_models.json index 201ed03b1..66266ea4e 100644 --- a/src/winml/modelkit/data/hub_models.json +++ b/src/winml/modelkit/data/hub_models.json @@ -1653,6 +1653,16 @@ ] }, "size_mb": 157.2 + }, + { + "model_id": "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx", + "task": "mask-generation", + "model_type": "sam3_tracker" + }, + { + "model_id": "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx", + "task": "mask-generation", + "model_type": "sam3_tracker" } ] } diff --git a/src/winml/modelkit/loader/__init__.py b/src/winml/modelkit/loader/__init__.py index 5b8a9bee2..aaaba3d0f 100644 --- a/src/winml/modelkit/loader/__init__.py +++ b/src/winml/modelkit/loader/__init__.py @@ -26,6 +26,7 @@ """ from .config import WinMLLoaderConfig, resolve_loader_config +from .onnx_hub import is_hf_onnx_path, resolve_hf_onnx_path from .task import ( HF_TASK_DEFAULTS, KNOWN_TASKS, @@ -48,9 +49,11 @@ "detect_task", "get_supported_tasks", "get_task_abbrev", + "is_hf_onnx_path", "load_hf_model", "normalize_task", "resolve_hf_model_class", + "resolve_hf_onnx_path", "resolve_loader_config", "resolve_optimum_library", "resolve_task_and_model_class", diff --git a/src/winml/modelkit/loader/onnx_hub.py b/src/winml/modelkit/loader/onnx_hub.py new file mode 100644 index 000000000..46c182a25 --- /dev/null +++ b/src/winml/modelkit/loader/onnx_hub.py @@ -0,0 +1,136 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Download pre-exported ONNX files hosted on the HuggingFace Hub. 
+ +ModelKit accepts two model input forms today: a HuggingFace model ID +(``org/name``) for the standard ``transformers`` + ``optimum-onnx`` export +path, and a local ``.onnx`` file path for the Scenario D pipeline in +``modelkit.build.build_onnx_model``. + +This module recognizes a third form -- a path-style reference to a +pre-exported ONNX artifact in a Hub repo, e.g.:: + + onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx + +The first two ``/``-separated components are interpreted as the repo ID; +everything that follows is the file path inside the repo. The file is +downloaded once via ``huggingface_hub.hf_hub_download`` and the local +path is then handed to the existing Scenario D code path. This is the +supported route for models like SAM 3 whose ``transformers`` requirement +exceeds what ``optimum-onnx`` currently pins. + +Any sibling ``.onnx_data`` external-data sidecar is fetched +best-effort so the ONNX loader can resolve external initializers. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + + +logger = logging.getLogger(__name__) + + +def is_hf_onnx_path(model_id: str | None) -> bool: + """Check whether ``model_id`` is a Hub-style reference to a pre-exported ONNX file. + + Returns True only when the value has at least three ``/``-separated + components, ends with ``.onnx``, and does not point at an existing + local file or directory. Local paths always win over the Hub + interpretation so users can keep working with paths that happen to + look like repo IDs. + """ + if not model_id: + return False + if not model_id.endswith(".onnx"): + return False + if Path(model_id).exists(): + return False + parts = [p for p in model_id.split("/") if p] + return len(parts) >= 3 + + +def resolve_hf_onnx_path( + model_id: str, + *, + revision: str | None = None, + cache_dir: str | Path | None = None, + token: str | bool | None = None, +) -> Path: + """Download a Hub-hosted ONNX file and return the local path. 
+ + Splits ``model_id`` into ``(repo_id, filename)``, downloads the + ``.onnx`` file, and best-effort fetches an optional + ``.onnx_data`` sidecar so the ONNX loader can find external + initializers. + + Args: + model_id: A Hub ONNX reference such as + ``"onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx"``. + revision: Optional Hub revision (branch, tag, or commit SHA). + cache_dir: Optional override for the ``huggingface_hub`` cache directory. + token: Optional auth token forwarded to ``hf_hub_download``. + + Returns: + The local path to the downloaded ``.onnx`` file. + + Raises: + ValueError: If ``model_id`` does not have at least three ``/``-separated + components. + """ + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + + repo_id, filename = _split_hf_onnx_path(model_id) + logger.info("Downloading ONNX from Hub: repo=%s file=%s", repo_id, filename) + + local_path = Path( + hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + cache_dir=cache_dir, + token=token, + ) + ) + + # External-data sidecars (used for >2 GiB models) live next to the .onnx + # file with a ``.onnx_data`` suffix. Fetch best-effort: many ONNX exports + # inline all weights and have no sidecar at all. + sidecar_filename = f"{filename}_data" + try: + sidecar_path = Path( + hf_hub_download( + repo_id=repo_id, + filename=sidecar_filename, + revision=revision, + cache_dir=cache_dir, + token=token, + ) + ) + logger.info("Downloaded external-data sidecar: %s", sidecar_path.name) + except (EntryNotFoundError, RepositoryNotFoundError, OSError) as e: + # The common case for small inline-weight models that don't ship + # a separate data file. 
+ logger.debug("No external-data sidecar at %s (%s)", sidecar_filename, e) + + return local_path + + +def _split_hf_onnx_path(model_id: str) -> tuple[str, str]: + """Split a Hub ONNX reference into ``(repo_id, filename)``.""" + parts = [p for p in model_id.split("/") if p] + if len(parts) < 3: + raise ValueError( + f"Hub ONNX reference must have form 'org/repo/path/to/file.onnx', got: {model_id!r}" + ) + return "/".join(parts[:2]), "/".join(parts[2:]) + + +__all__ = [ + "is_hf_onnx_path", + "resolve_hf_onnx_path", +] diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index a8e4036c3..b9c17c764 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -311,6 +311,13 @@ def from_pretrained( warn_trust_remote_code() + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx path thereafter. + from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + + if is_hf_onnx_path(model_id): + model_id = str(resolve_hf_onnx_path(model_id)) + # ===================================================================== # ONNX FAST PATH -- skip HF loading and export when given an .onnx file # ===================================================================== diff --git a/tests/integration/test_sam3_e2e.py b/tests/integration/test_sam3_e2e.py new file mode 100644 index 000000000..31d278493 --- /dev/null +++ b/tests/integration/test_sam3_e2e.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""End-to-end integration test for SAM 3 (Tracker) via the pre-exported ONNX path. 
+ +SAM 3 cannot be exported through the standard HuggingFace + Optimum route used +by the rest of ModelKit because ``optimum-onnx`` currently pins +``transformers<4.58`` while SAM 3 requires ``transformers>=5``. ModelKit instead +consumes the pre-exported ONNX from the ``onnx-community/sam3-tracker-ONNX`` +Hub repo via the Scenario D pipeline (``build_onnx_model``). + +Pipeline verified by this test: + +1. ``is_hf_onnx_path`` recognizes the Hub-style ONNX reference and + ``resolve_hf_onnx_path`` downloads the file via ``huggingface_hub``. +2. ``generate_onnx_build_config`` produces a valid build config for the + already-quantized ONNX (skips optimize and quantize stages). +3. ``build_onnx_model`` produces a final ``model.onnx`` artifact that loads + cleanly with ``onnx``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import onnx +import pytest + + +if TYPE_CHECKING: + from pathlib import Path + + +# Decoder-only variant: ~290 KB ONNX + ~10 MB sidecar. Small enough for CI +# while still exercising the is_quantized_onnx branch (skips optimize+quantize). 
+SAM3_ONNX_REF = "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx" + + +@pytest.mark.slow +@pytest.mark.network +@pytest.mark.integration +class TestSam3E2E: + """Pre-exported SAM 3 ONNX flows through Scenario D end-to-end.""" + + @pytest.fixture(scope="class") + def sam3_onnx_path(self) -> Path: + """Download the SAM 3 Tracker decoder ONNX once for the test class.""" + pytest.importorskip("huggingface_hub", reason="huggingface_hub required") + + from winml.modelkit.loader import is_hf_onnx_path, resolve_hf_onnx_path + + assert is_hf_onnx_path(SAM3_ONNX_REF) + try: + return resolve_hf_onnx_path(SAM3_ONNX_REF) + except Exception as e: + pytest.skip(f"Could not download {SAM3_ONNX_REF}: {e}") + + def test_resolves_to_local_onnx_file(self, sam3_onnx_path: Path) -> None: + """The Hub reference resolves to an on-disk .onnx file.""" + assert sam3_onnx_path.is_file() + assert sam3_onnx_path.suffix == ".onnx" + assert sam3_onnx_path.stat().st_size > 0 + + def test_generate_onnx_build_config_detects_quantized(self, sam3_onnx_path: Path) -> None: + """The int8 variant is detected as already quantized.""" + from winml.modelkit.config import generate_onnx_build_config + from winml.modelkit.onnx import is_quantized_onnx + + assert is_quantized_onnx(sam3_onnx_path), ( + "Expected the int8 variant to contain QuantizeLinear / DequantizeLinear nodes." + ) + + config = generate_onnx_build_config( + sam3_onnx_path, + task="mask-generation", + device="auto", + precision="auto", + ) + + # Quantized models skip the quantization stage entirely. 
+ assert config.export is None + assert config.quant is None + + def test_build_onnx_model_produces_final_artifact( + self, sam3_onnx_path: Path, tmp_path: Path + ) -> None: + """build_onnx_model runs end-to-end and emits model.onnx.""" + from winml.modelkit.build import build_onnx_model + from winml.modelkit.config import generate_onnx_build_config + + config = generate_onnx_build_config( + sam3_onnx_path, + task="mask-generation", + device="cpu", + precision="auto", + ) + # Disable compilation: this test asserts pipeline plumbing, + # not EP availability on the test host. + config.compile = None + + output_dir = tmp_path / "sam3_build" + + try: + result = build_onnx_model( + onnx_path=sam3_onnx_path, + config=config, + output_dir=output_dir, + rebuild=True, + hack_max_optim_iterations=0, # skip analyzer to keep test fast + ) + except Exception as e: + pytest.skip(f"build_onnx_model failed (likely missing runtime dep): {e}") + + final = result.final_onnx_path + assert final.exists(), f"Expected final artifact at {final}" + assert final.stat().st_size > 0 + + # Validate the final artifact is a structurally valid ONNX model. + model = onnx.load(str(final), load_external_data=False) + assert len(model.graph.node) > 0 diff --git a/tests/unit/commands/test_hub_onnx_ref.py b/tests/unit/commands/test_hub_onnx_ref.py new file mode 100644 index 000000000..c600c1aea --- /dev/null +++ b/tests/unit/commands/test_hub_onnx_ref.py @@ -0,0 +1,165 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for Hub-hosted ONNX (``org/repo/path/to/file.onnx``) input on CLI commands. 
+ +Validates that ``wmk config`` and ``wmk build`` recognize Hub-style ONNX +references, call ``resolve_hf_onnx_path`` to download the file, and then +dispatch through the existing local-ONNX (Scenario D) code path. No actual +downloads happen -- ``hf_hub_download`` is mocked. +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner + + +if TYPE_CHECKING: + from pathlib import Path + + +HUB_ONNX_REF = "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx" + + +@pytest.fixture(autouse=True) +def mock_resolve_device(): + """Mock hardware detection so config/build tests run on any host.""" + mock_registry = MagicMock() + mock_registry.is_ep_available.return_value = False + + with ( + patch( + "winml.modelkit.sysinfo.resolve_device", + return_value=("npu", ["npu", "gpu", "cpu"]), + ), + patch( + "winml.modelkit.session.ep_registry.WinMLEPRegistry.get_instance", + return_value=mock_registry, + ), + ): + yield + + +@pytest.fixture +def runner() -> CliRunner: + """Create a CLI test runner.""" + return CliRunner() + + +@pytest.fixture +def fake_local_onnx(tmp_path: Path) -> Path: + """Fake local ONNX file the mocked downloader returns.""" + path = tmp_path / "downloaded.onnx" + path.write_bytes(b"fake-onnx-data") + return path + + +@pytest.fixture +def mock_hf_download(fake_local_onnx: Path): + """Patch ``huggingface_hub.hf_hub_download`` to return ``fake_local_onnx``. + + Sidecar lookups raise ``EntryNotFoundError`` to simulate a model whose + weights are inlined and has no ``.onnx_data`` companion. 
+ """ + from huggingface_hub.utils import EntryNotFoundError + + def _fake(*, repo_id, filename, revision, cache_dir, token): + if filename.endswith(".onnx_data"): + raise EntryNotFoundError(filename) + return str(fake_local_onnx) + + with patch("huggingface_hub.hf_hub_download", side_effect=_fake) as mock: + yield mock + + +@pytest.fixture +def sample_config_file(tmp_path: Path) -> Path: + """Create a minimal JSON config file for ``wmk build``.""" + config = { + "loader": {"task": "mask-generation"}, + "export": None, + "optim": {}, + "quant": None, + "compile": None, + } + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(config)) + return config_path + + +# ============================================================================= +# wmk config -m +# ============================================================================= + + +class TestConfigHubOnnxRef: + """``wmk config`` recognizes Hub-style ONNX references.""" + + def test_config_resolves_hub_ref_and_uses_onnx_path( + self, + runner: CliRunner, + mock_hf_download: MagicMock, + fake_local_onnx: Path, + ) -> None: + """Hub ref is downloaded, then config is generated via the ONNX branch.""" + from winml.modelkit.commands.config import config + + with ( + patch("winml.modelkit.onnx.is_compiled_onnx", return_value=False), + patch("winml.modelkit.onnx.is_quantized_onnx", return_value=True), + ): + result = runner.invoke(config, ["-m", HUB_ONNX_REF]) + + assert result.exit_code == 0, f"Failed: {result.output}" + # Resolver was invoked on the Hub reference. + assert mock_hf_download.called + repo_filenames = [c.kwargs["filename"] for c in mock_hf_download.call_args_list] + assert "onnx/prompt_encoder_mask_decoder_int8.onnx" in repo_filenames + # Output JSON marks an ONNX (Scenario D) build: export=None. 
+ start = result.output.index("{") + end = result.output.rindex("}") + 1 + data = json.loads(result.output[start:end]) + assert data.get("export") is None + + +# ============================================================================= +# wmk build -m +# ============================================================================= + + +class TestBuildHubOnnxRef: + """``wmk build`` recognizes Hub-style ONNX references.""" + + def test_build_resolves_hub_ref_and_dispatches_to_onnx_pipeline( + self, + runner: CliRunner, + sample_config_file: Path, + mock_hf_download: MagicMock, + fake_local_onnx: Path, + tmp_path: Path, + ) -> None: + """Hub ref is downloaded once, then build dispatches the ONNX pipeline.""" + from winml.modelkit.commands.build import build + + output_dir = tmp_path / "out" + with patch( + "winml.modelkit.commands.build._build_onnx_pipeline", + return_value=[], + ) as mock_pipeline: + result = runner.invoke( + build, + ["-c", str(sample_config_file), "-m", HUB_ONNX_REF, "-o", str(output_dir)], + obj={"debug": False}, + ) + + assert result.exit_code == 0, f"Build failed: {result.output}" + assert mock_hf_download.called + # Pipeline was called with the locally-resolved path, not the Hub ref. + mock_pipeline.assert_called_once() + assert mock_pipeline.call_args.kwargs["onnx_path"] == fake_local_onnx diff --git a/tests/unit/loader/test_onnx_hub.py b/tests/unit/loader/test_onnx_hub.py new file mode 100644 index 000000000..3e421f72c --- /dev/null +++ b/tests/unit/loader/test_onnx_hub.py @@ -0,0 +1,143 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for winml.modelkit.loader.onnx_hub. + +Covers Hub-style ONNX reference detection and download. Uses mock +``hf_hub_download`` callables so no network access is required. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest + +from winml.modelkit.loader.onnx_hub import ( + _split_hf_onnx_path, + is_hf_onnx_path, + resolve_hf_onnx_path, +) + + +if TYPE_CHECKING: + from pathlib import Path + + +class TestIsHfOnnxPath: + """Hub ONNX reference detection.""" + + def test_three_segment_onnx_recognized(self) -> None: + """Repo-id + nested file path is a valid Hub ONNX reference.""" + assert is_hf_onnx_path("onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx") + + def test_three_segments_minimum(self) -> None: + """Two segments are treated as a plain HF model ID, not a file ref.""" + assert is_hf_onnx_path("org/repo/file.onnx") + assert not is_hf_onnx_path("org/file.onnx") + + def test_plain_hf_model_id_rejected(self) -> None: + """org/name HF IDs are not Hub ONNX references.""" + assert not is_hf_onnx_path("microsoft/resnet-50") + assert not is_hf_onnx_path("facebook/sam2.1-hiera-small") + + def test_non_onnx_extension_rejected(self) -> None: + """Only .onnx file references match.""" + assert not is_hf_onnx_path("org/repo/path/file.bin") + assert not is_hf_onnx_path("org/repo/path/file") + + def test_existing_local_path_takes_precedence(self, tmp_path: Path) -> None: + """A real on-disk path that looks like a Hub ref is left alone.""" + local = tmp_path / "org" / "repo" / "file.onnx" + local.parent.mkdir(parents=True) + local.write_bytes(b"") + assert not is_hf_onnx_path(str(local)) + + def test_none_and_empty_inputs(self) -> None: + """None and empty string are not Hub references.""" + assert not is_hf_onnx_path(None) + assert not is_hf_onnx_path("") + + +class TestSplitHfOnnxPath: + """Internal _split_hf_onnx_path helper.""" + + def test_three_segments(self) -> None: + """First two segments form repo_id; third is filename.""" + repo_id, filename = _split_hf_onnx_path("org/repo/file.onnx") + assert repo_id == "org/repo" + assert filename == "file.onnx" 
+ + def test_nested_filename_preserved(self) -> None: + """Multi-segment filenames inside the repo are kept intact.""" + repo_id, filename = _split_hf_onnx_path( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + ) + assert repo_id == "onnx-community/sam3-tracker-ONNX" + assert filename == "onnx/vision_encoder_int8.onnx" + + def test_too_few_segments_raises(self) -> None: + """Inputs with fewer than three segments raise ValueError.""" + with pytest.raises(ValueError, match=r"org/repo/path/to/file\.onnx"): + _split_hf_onnx_path("org/file.onnx") + + +class TestResolveHfOnnxPath: + """Download path: hf_hub_download is called once per file.""" + + def test_downloads_onnx_and_attempts_sidecar(self, tmp_path: Path) -> None: + """Resolver requests both the .onnx file and a .onnx_data sidecar.""" + from huggingface_hub.utils import EntryNotFoundError + + downloaded = tmp_path / "vision_encoder_int8.onnx" + downloaded.write_bytes(b"") + + calls: list[dict[str, object]] = [] + + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + calls.append( + { + "repo_id": repo_id, + "filename": filename, + "revision": revision, + "cache_dir": cache_dir, + "token": token, + } + ) + if filename.endswith(".onnx_data"): + # Most small inline-weight models have no sidecar; the + # resolver must tolerate the missing file. 
+ raise EntryNotFoundError(filename) + return str(downloaded) + + with patch("huggingface_hub.hf_hub_download", side_effect=_fake_download): + result = resolve_hf_onnx_path( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx", + revision="main", + cache_dir=str(tmp_path / "cache"), + token=None, + ) + + assert result == downloaded + assert [c["filename"] for c in calls] == [ + "onnx/vision_encoder_int8.onnx", + "onnx/vision_encoder_int8.onnx_data", + ] + assert calls[0]["repo_id"] == "onnx-community/sam3-tracker-ONNX" + + def test_sidecar_present(self, tmp_path: Path) -> None: + """When the sidecar exists, both files download successfully.""" + downloaded = tmp_path / "vision_encoder.onnx" + sidecar = tmp_path / "vision_encoder.onnx_data" + downloaded.write_bytes(b"") + sidecar.write_bytes(b"") + + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + return str(downloaded if filename.endswith(".onnx") else sidecar) + + with patch("huggingface_hub.hf_hub_download", side_effect=_fake_download): + result = resolve_hf_onnx_path("org/repo/onnx/vision_encoder.onnx") + + assert result == downloaded From 7db3484191726917eebc11803dcc5ba6e67f6f67 Mon Sep 17 00:00:00 2001 From: Abhilash Chenreddy Date: Mon, 11 May 2026 17:49:28 -0400 Subject: [PATCH 2/4] Wire Hub-hosted ONNX input form across the CLI; SAM 3 is the first consumer. Also fix two latent bugs in the build pipeline that any QOperator-quantized model would have hit. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Background ---------- Issue #324 asks for SAM 2-style native HuggingFace export support for ``facebook/sam3`` (Sam3*IOConfig, Sam3ModelPatcher, etc.). That path is blocked by an upstream constraint: ``optimum-onnx`` pins ``transformers<4.58``, but ``facebook/sam3`` requires ``transformers>=5`` (the ``Sam3Model`` class only exists there). 
Resolving the pin would need either an upstream optimum-onnx PR or vendoring SAM 3 patcher code that bypasses optimum entirely. Instead, this PR introduces a generic "Hub-hosted ONNX file" input form and lets SAM 3 ride on the existing pre-exported-ONNX (Scenario D) pipeline that already worked for any local ``.onnx`` file. The infrastructure is reusable for any future model with similar version constraints (Whisper / Phi / RWKV / etc. all ship pre-exported ONNX repos on the Hub today). What's added ------------ 1. Hub-hosted ONNX URI resolver - ``loader/onnx_hub.py``: ``is_hf_onnx_path()``, ``resolve_hf_onnx_path()``, ``maybe_resolve_hf_onnx_path()`` - Recognizes inputs of the form ``<org>/<repo>/<path>.onnx``, downloads via ``huggingface_hub.hf_hub_download``, returns the local cache path. Falls through unchanged for HF model IDs / local paths / ``None``. - Best-effort ``.onnx_data`` sidecar fetch for >2 GiB models. ``EntryNotFoundError`` is expected (inlined weights); ``OSError`` surfaces as a WARNING (disk/permission/network problems should not be silently dropped — the model would later fail to load with a confusing error). 2. 
CLI wiring (every command that accepts a model identifier) - ``wmk config`` / ``wmk build``: resolve at the top of the command - ``wmk inspect``: friendly "ONNX inspection not yet supported" error for Hub-ONNX refs (matches local .onnx UX) - ``wmk run`` / ``wmk serve``: ``InferenceEngine.load()`` and ``load_schema_only()`` resolve before routing - ``wmk perf``: resolve before the ``Path(model_id).suffix == '.onnx' and exists()`` check (otherwise Hub refs are mistaken for missing local files and rejected with FileNotFoundError) - ``wmk eval``: ``_resolve_model_path`` resolves before the local existence check - ``WinMLAutoModel.from_pretrained``: resolves before HF/ONNX dispatch - Stage-tool commands (``analyze``/``optimize``/``quantize``/ ``compile``/``export``) intentionally NOT wired — they take ``click.Path(exists=True)`` and operate on local files only. 3. SAM 3 catalog entries (``data/hub_models.json``) - Two entries for ``onnx-community/sam3-tracker-ONNX``: the vision encoder and the prompt-encoder + mask-decoder. Note: was already present in the base branch — this PR does not modify it. 4. Integration tests (``tests/integration/test_sam3_e2e.py``) - 4 decoder tests + 2 encoder tests, marked ``@slow @network @integration`` - Asserts: Hub URI resolves, quantization detected, build produces ``model.onnx``, autoconf produces an ``optimization_config``, and for the encoder: pre-quantized round-trip preserves the ``ConvInteger`` / ``MatMulInteger`` ops byte-identically. - Skips narrowed to ``HfHubHTTPError`` / ``OSError`` only — real bugs in the build/analyze pipeline will surface as test failures rather than green skips. Bugs fixed (would affect any QOperator-quantized model, not just SAM 3) --------------------------------------------------------------------- A. ``is_quantized_onnx`` only detected QDQ format (``QuantizeLinear`` / ``DequantizeLinear``). 
The SAM 3 vision encoder uses ``QuantFormat.QOperator`` (no QDQ pairs, just integer ops: ``ConvInteger``, ``MatMulInteger``, ``QLinear*``). Previously misclassified as not quantized → routed through the optimize + quantize stages → tried to re-quantize an already-int8 model. Fix: ``compiler/utils.py`` adds ``QOPERATOR_OP_TYPES`` and ``QUANTIZATION_OP_TYPES = QDQ ∪ QOperator``. ``onnx/detection.py`` uses the union. B. The ``is_pre_quantized`` branches in ``build_onnx_model``, ``build_hf_model``, and the CLI's ``_build_onnx_pipeline`` logged "skipping optimize" but still invoked ``optimize_onnx`` → ``ort_graph`` → loaded the model into an ORT session. For QOperator models on hosts without a CPU ``ConvInteger`` kernel (e.g. ``onnxruntime-windowsml`` 1.23.x), this crashes the build stage with ``NOT_IMPLEMENTED``. Fix: ``build/common.py::run_optimize_analyze_loop`` gains a real ``skip_optimize: bool`` knob that bypasses ``optimize_onnx`` and the autoconf re-optim loop, just copying the input as the "optimized" artifact. All three pre-quantized branches now pass ``skip_optimize=True``. The downstream behavior (skip quantize + skip compile when configured) is unchanged. 
Verification ------------ - ``onnx.checker.check_model(full_check=True)`` passes on built artifacts - Built decoder produces NUMERICALLY IDENTICAL outputs to input decoder (``max|built - input| = 0.0`` across all 3 outputs) — pre-quantized round-trip is a true pass-through, not just structurally similar - Encoder runtime feasibility on CPU is identical to input encoder (both fail on CPU because of upstream ORT ``ConvInteger`` kernel gap; encoder requires NPU EP — unchanged from input) - Decoder real inference produces sane SAM-shaped outputs: ``iou_scores ∈ [0, 1]``, ``pred_masks`` logits span both signs, ``object_score_logits`` non-degenerate Test count ---------- - 4518 unit tests pass (+12 new regression tests across: ``test_onnx_hub.py``, ``test_detection.py``, ``test_eval.py``, ``test_perf_cli.py``, ``test_engine.py``) - 6 integration tests pass (live HF download, ~30s) - Ruff check + format clean on all 24 changed files Silent-skip audit (per SAM 2 review feedback) --------------------------------------------- Removed ``except Exception: pytest.skip(...)`` patterns from SAM 3 integration tests — they were swallowing real bugs (including the ``ConvInteger`` regression fixed in this PR). All skips now narrowed to ``HfHubHTTPError`` / ``OSError`` (network) or specific runtime exceptions; ``RuntimeError`` from ``build_onnx_model`` and ``analyze_onnx`` now fails loudly. Removed unnecessary ``pytest.importorskip("huggingface_hub")`` (it's a hard transitive dep). Sidecar download ``OSError`` now logs WARNING instead of DEBUG. Known limitations (not addressed in this PR) -------------------------------------------- - SAM 3 encoder requires NPU EP (QNN / OpenVINO / VitisAI) because ``onnxruntime-windowsml`` ships no CPU kernel for ``ConvInteger(10)``. This is true for both the input and built artifact — our build preserves runtime behavior exactly. Decoder uses ``MatMulInteger`` and runs on either CPU or NPU. 
- Catalog entries for SAM 3 have ``quantization: null`` so ``wmk perf`` falls back to default random-input shapes that violate the SAM 3 decoder's internal reshape constraints. Populating ``quantization.input_tensors`` with proper shape hints (the pattern every other catalog entry follows) is the recommended fix; out of scope for this PR. --- README.md | 1 + src/winml/modelkit/build/common.py | 30 ++-- src/winml/modelkit/build/hf.py | 6 +- src/winml/modelkit/build/onnx.py | 6 +- src/winml/modelkit/commands/build.py | 21 ++- src/winml/modelkit/commands/config.py | 5 +- src/winml/modelkit/commands/eval.py | 8 +- src/winml/modelkit/commands/inspect.py | 12 ++ src/winml/modelkit/commands/perf.py | 8 ++ src/winml/modelkit/compiler/__init__.py | 9 +- src/winml/modelkit/compiler/utils.py | 34 +++++ src/winml/modelkit/inference/engine.py | 15 +- src/winml/modelkit/loader/__init__.py | 3 +- src/winml/modelkit/loader/onnx_hub.py | 65 ++++++++- src/winml/modelkit/models/auto.py | 7 +- src/winml/modelkit/onnx/detection.py | 20 ++- tests/integration/test_sam3_e2e.py | 176 +++++++++++++++++++++--- tests/unit/build/test_hf.py | 18 ++- tests/unit/build/test_onnx.py | 19 ++- tests/unit/commands/test_eval.py | 31 +++++ tests/unit/commands/test_perf_cli.py | 39 ++++++ tests/unit/inference/test_engine.py | 28 ++++ tests/unit/loader/test_onnx_hub.py | 81 +++++++++++ tests/unit/onnx/test_detection.py | 122 ++++++++++++++++ 24 files changed, 700 insertions(+), 64 deletions(-) create mode 100644 tests/unit/onnx/test_detection.py diff --git a/README.md b/README.md index 5b3d21d0b..866baea17 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Purpose-built for Windows hardware diversity, the CLI handles conversion, graph - **All Windows ML EPs supported.** Every [supported execution provider](https://microsoft.github.io/winml-cli/latest/concepts/eps-and-devices/#eps-winml-cli-supports) is available behind the same commands. 
- **Curated model catalog.** A [verified set of models](https://microsoft.github.io/winml-cli/latest/reference/supported-models/) that run across all Windows ML EPs - a reliable starting point. - **Bring your own ONNX.** Not only for converting from PyTorch - bring an [existing ONNX model](https://microsoft.github.io/winml-cli/latest/tutorials/build-from-onnx/) to get operator-compatibility insights and optimize it based on the analysis. +- **Hub-hosted ONNX models.** Reference pre-exported ONNX files on the Hugging Face Hub as `<org>/<repo>/<path>.onnx` (e.g. `onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx`). Supported by `winml config`, `winml build`, `winml run`, `winml serve`, `winml perf`, and `winml eval`. --- diff --git a/src/winml/modelkit/build/common.py b/src/winml/modelkit/build/common.py index 8d28a9530..e16f79aed 100644 --- a/src/winml/modelkit/build/common.py +++ b/src/winml/modelkit/build/common.py @@ -37,6 +37,7 @@ def run_optimize_analyze_loop( device: str | None = None, max_optim_iterations: int = 0, allow_unsupported_nodes: bool = False, + skip_optimize: bool = False, on_ep_start: Any = None, on_node_result: Any = None, on_iteration_start: Any = None, @@ -48,7 +49,7 @@ def run_optimize_analyze_loop( """Optimize an ONNX model, analyze, and optionally re-optimize via autoconf. Flow: - 1. Optimize with ``config.optim`` flags + 1. Optimize with ``config.optim`` flags (skipped if ``skip_optimize=True``) 2. Analyze the result (lint + autoconf discovery) 3. For up to ``max_optim_iterations``: if autoconf found new flags, re-optimize and re-analyze @@ -69,6 +70,11 @@ def run_optimize_analyze_loop( analyze_output_path: Optional path to write the full analysis result as JSON. Written after every analyze pass; each pass overwrites the previous one so the file always reflects the most recent analysis. + skip_optimize: When True, skip the initial ``optimize_onnx`` call and + just copy the input model to ``optimized_path``. 
Used for + pre-quantized models (QDQ or QOperator format) where ORT-based + graph optimization would fail because the runtime lacks kernels + for ops like ``ConvInteger`` on the host EP. **onnx_kwargs: Additional ONNX-level kwargs. Returns: @@ -83,13 +89,21 @@ def run_optimize_analyze_loop( t0 = time.monotonic() - # 1. Optimize - optimize_onnx( - model=model_path, - output=optimized_path, - **onnx_kwargs, - **config.optim, - ) + # 1. Optimize (or skip for pre-quantized models) + if skip_optimize: + # Pre-quantized models (QOperator format with ConvInteger / + # MatMulInteger) cannot pass through ORT graph optimization on + # hosts that lack kernels for those integer ops. Simply forward + # the input as the "optimized" artifact. + if model_path.resolve() != optimized_path.resolve(): + copy_onnx_model(model_path, optimized_path) + else: + optimize_onnx( + model=model_path, + output=optimized_path, + **onnx_kwargs, + **config.optim, + ) current_path = optimized_path # Autoconf: analyze model, discover missing optimizations, re-optimize diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 99851462c..a67fd8480 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -251,11 +251,12 @@ def _name(base: str) -> str: if is_pre_quantized: logger.info( - "Pre-quantized model detected (QDQ nodes present). " + "Pre-quantized model detected (QDQ or QOperator nodes present). " "Skipping optimize + quantize, running analyze-only." ) stages_skipped.append("optimize") - # Optimize+analyze only, no autoconf re-optimization + # Analyze-only: skip ORT-based graph optimization (no kernel for + # QOperator ops like ConvInteger on the host EP), no autoconf loop. 
current_path, _, analyze_iterations, analyze_unsupported_nodes, analyze_details = ( run_optimize_analyze_loop( model_path=current_path, @@ -263,6 +264,7 @@ def _name(base: str) -> str: config=config, ep=ep, device=device, + skip_optimize=True, **onnx_kwargs, ) ) diff --git a/src/winml/modelkit/build/onnx.py b/src/winml/modelkit/build/onnx.py index 2e7424e99..04d088ab0 100644 --- a/src/winml/modelkit/build/onnx.py +++ b/src/winml/modelkit/build/onnx.py @@ -156,11 +156,12 @@ def build_onnx_model( if is_pre_quantized: logger.info( - "Pre-quantized model detected (QDQ nodes present). " + "Pre-quantized model detected (QDQ or QOperator nodes present). " "Skipping optimize + quantize, running analyze-only." ) stages_skipped.append("optimize") - # Optimize+analyze only, no autoconf re-optimization + # Analyze-only: skip ORT-based graph optimization (no kernel for + # QOperator ops like ConvInteger on the host EP), no autoconf loop. current_path, _, analyze_iters, analyze_unsupported, analyze_details = ( run_optimize_analyze_loop( model_path=current_path, @@ -168,6 +169,7 @@ def build_onnx_model( config=config, ep=ep, device=device, + skip_optimize=True, **onnx_kwargs, ) ) diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index a8d2b6d18..128d5de40 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -579,10 +579,9 @@ def build( # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx file thereafter. 
if model_id is not None: - from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + from ..loader import maybe_resolve_hf_onnx_path - if is_hf_onnx_path(model_id): - model_id = str(resolve_hf_onnx_path(model_id)) + model_id = maybe_resolve_hf_onnx_path(model_id) # Load or auto-generate config if config_file is not None: @@ -986,6 +985,7 @@ def _run_optimize_stage( show_io_first: bool = False, analyze_output_path: Path | None = None, allow_unsupported_nodes: bool = False, + skip_optimize: bool = False, ) -> tuple[Path, float]: """Run the optimize stage inside a StageLive context. @@ -1002,6 +1002,9 @@ def _run_optimize_stage( stage_timings: List to append (stage_name, elapsed) tuple to. show_io_first: If True, show I/O tensors at the start of the stage (used in ONNX mode where there is no export stage). + skip_optimize: When True, skip the ORT graph-optimization pass. + Used for pre-quantized models (QDQ or QOperator format) whose + integer ops have no kernel on the host EP. Returns: Tuple of (current_path, opt_elapsed). @@ -1095,6 +1098,7 @@ def _on_reoptimize(autoconf_dict: dict) -> None: device=device, max_optim_iterations=max_iters, allow_unsupported_nodes=allow_unsupported_nodes, + skip_optimize=skip_optimize, on_ep_start=_on_ep_start, on_node_result=_on_node_result, on_iteration_start=_on_iteration_start, @@ -1441,7 +1445,7 @@ def _build_onnx_pipeline( Returns list of (stage_name, elapsed_seconds | None) for summary, or None if build was reused. 
""" - from ..onnx import copy_onnx_model + from ..onnx import copy_onnx_model, is_quantized_onnx max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3) allow_unsupported_nodes: bool = extra_kwargs.pop("allow_unsupported_nodes", False) @@ -1482,6 +1486,14 @@ def _build_onnx_pipeline( if current_path.resolve() != onnx_path.resolve(): copy_onnx_model(onnx_path, current_path) + # Pre-quantized models (QDQ or QOperator format) cannot pass through + # ORT-based graph optimization on hosts that lack kernels for ops like + # ``ConvInteger``. Skip the optimize pass and the autoconf re-optim + # loop; analyze still runs lint-only. + is_pre_quantized = is_quantized_onnx(current_path) + if is_pre_quantized: + max_iters = 0 + # ── Optimize stage (first stage for ONNX — show I/O here) ──── current_path, _ = _run_optimize_stage( config=config, @@ -1494,6 +1506,7 @@ def _build_onnx_pipeline( show_io_first=True, analyze_output_path=analyze_result_path, allow_unsupported_nodes=allow_unsupported_nodes, + skip_optimize=is_pre_quantized, ) config_path.write_text(json.dumps(config.to_dict(), indent=2)) diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index e9bd77dcb..834b12c19 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -227,12 +227,11 @@ def config( generate_hf_build_config, generate_onnx_build_config, ) - from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + from ..loader import maybe_resolve_hf_onnx_path # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx file thereafter. 
- if hf_model is not None and is_hf_onnx_path(hf_model): - hf_model = str(resolve_hf_onnx_path(hf_model)) + hf_model = maybe_resolve_hf_onnx_path(hf_model) # Load override config from JSON file if provided override = None diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index f96763543..e1161bfdc 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -482,7 +482,13 @@ def _resolve_model_path( value = plain[0] if Path(value).suffix.lower() == ".onnx": - if not Path(value).exists(): + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx path thereafter. + from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + + if is_hf_onnx_path(value): + value = str(resolve_hf_onnx_path(value)) + elif not Path(value).exists(): raise click.BadParameter( f"ONNX file not found: {value}", param_hint="-m/--model", diff --git a/src/winml/modelkit/commands/inspect.py b/src/winml/modelkit/commands/inspect.py index 0bf68e28c..4f10484b9 100644 --- a/src/winml/modelkit/commands/inspect.py +++ b/src/winml/modelkit/commands/inspect.py @@ -194,6 +194,18 @@ def inspect( if not _p.exists(): raise click.ClickException(f"Local path '{model_id}' does not exist.") + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is not downloadable for inspect (which targets HF architecture + # metadata, not raw ONNX graphs), but surfacing the same friendly + # error keeps the UX consistent with local .onnx inputs. + from ..loader import is_hf_onnx_path + + if model_id and is_hf_onnx_path(model_id): + raise click.ClickException( + "ONNX file inspection is not yet supported. " + "Use 'winml config -m model.onnx' for ONNX build config." + ) + # Merge top-level -v/-q with subcommand-level flags so either position # works, once and up front. 
The banner decision below needs the merged # --quiet (so both `winml --quiet inspect …` and `winml inspect -q` diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index 5f8b01141..a7cac9854 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -1194,6 +1194,14 @@ def perf( if not model: raise click.UsageError("A model is required via -m/--model.") + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx path thereafter. + # Must run BEFORE the ``Path(hf_model).suffix == ".onnx"`` check below + # so a Hub ref is not mistaken for a missing local file. + from ..loader import maybe_resolve_hf_onnx_path + + model = maybe_resolve_hf_onnx_path(model) + hf_model = model # Apply build config defaults (CLI explicit options take precedence). diff --git a/src/winml/modelkit/compiler/__init__.py b/src/winml/modelkit/compiler/__init__.py index bd49317e5..7eaac0ebe 100644 --- a/src/winml/modelkit/compiler/__init__.py +++ b/src/winml/modelkit/compiler/__init__.py @@ -33,7 +33,12 @@ from .context import CompileContext from .result import CompileResult from .transforms import clear_transforms, get_transforms_for_ep, register_transform -from .utils import QDQ_OP_TYPES, needs_format_conversion +from .utils import ( + QDQ_OP_TYPES, + QOPERATOR_OP_TYPES, + QUANTIZATION_OP_TYPES, + needs_format_conversion, +) # Names below are loaded lazily via ``__getattr__`` to avoid pulling in session/ @@ -83,6 +88,8 @@ def __getattr__(name: str) -> Any: __all__ = [ "QDQ_OP_TYPES", + "QOPERATOR_OP_TYPES", + "QUANTIZATION_OP_TYPES", "CompileContext", "CompileResult", "CompileStage", diff --git a/src/winml/modelkit/compiler/utils.py b/src/winml/modelkit/compiler/utils.py index ea80ed480..e109e0e3f 100644 --- a/src/winml/modelkit/compiler/utils.py +++ b/src/winml/modelkit/compiler/utils.py @@ -19,6 +19,40 @@ QDQ_OP_TYPES: frozenset[str] = 
frozenset({"QuantizeLinear", "DequantizeLinear"}) +# Canonical definition of ONNX QOperator-style quantization op types. +# QOperator format encodes quantization directly in fused integer ops +# (e.g. ``ConvInteger``, ``MatMulInteger``, ``QLinearConv``) rather than +# the explicit QuantizeLinear/DequantizeLinear pairs used by QDQ format. +# Models exported through ``onnxruntime.quantization`` with +# ``QuantFormat.QOperator`` (or sourced from Hub repos like +# ``onnx-community/sam3-tracker-ONNX``) use this format. +QOPERATOR_OP_TYPES: frozenset[str] = frozenset( + { + # Direct integer ops (input is already int8, weights are int8) + "ConvInteger", + "MatMulInteger", + # QLinear-prefixed ops (input + output are int8 with scale/zero-point) + "QLinearConv", + "QLinearMatMul", + "QLinearAdd", + "QLinearMul", + "QLinearLeakyRelu", + "QLinearSigmoid", + "QLinearGlobalAveragePool", + "QLinearAveragePool", + "QLinearReduceMean", + "QLinearConcat", + "QLinearSoftmax", + } +) + + +# Union of all quantization op types (QDQ + QOperator). Use this for +# "is the model already quantized?" detection regardless of which format +# the producer used. +QUANTIZATION_OP_TYPES: frozenset[str] = QDQ_OP_TYPES | QOPERATOR_OP_TYPES + + def needs_format_conversion(model_path: Path, ep: EPAlias) -> bool: """Check if model's quant format is compatible with target EP. diff --git a/src/winml/modelkit/inference/engine.py b/src/winml/modelkit/inference/engine.py index e3e4bbf38..61002e9cf 100644 --- a/src/winml/modelkit/inference/engine.py +++ b/src/winml/modelkit/inference/engine.py @@ -304,7 +304,8 @@ def load( """Load model from model_path. Args: - model_path: HF model ID, build output dir, or .onnx file path. + model_path: HF model ID, build output dir, or .onnx file path + (local or Hub-hosted ``<org>/<repo>/<path>.onnx``). task: Required when model_path is a raw .onnx file. device: "auto" | "cpu" | "gpu" | "npu". ep: Explicit EP short name (e.g. "dml", "qnn"). Overrides device. 
@@ -321,6 +322,12 @@ def load( no effect on raw .onnx files or pre-built build directories (no build/analyze step runs in those paths). """ + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx path thereafter. + from ..loader import maybe_resolve_hf_onnx_path + + model_path = maybe_resolve_hf_onnx_path(str(model_path)) or str(model_path) + self._model_path = str(model_path) self._ep = ep self._device = device @@ -396,6 +403,12 @@ def load_schema_only( Falls back to ``load()`` only when the task cannot be determined without a full model load. """ + # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) + # is downloaded once and treated as a local .onnx path thereafter. + from ..loader import maybe_resolve_hf_onnx_path + + model_path = maybe_resolve_hf_onnx_path(str(model_path)) or str(model_path) + self._model_path = str(model_path) self._device = device self._ep = ep diff --git a/src/winml/modelkit/loader/__init__.py b/src/winml/modelkit/loader/__init__.py index aaaba3d0f..ce76e5c24 100644 --- a/src/winml/modelkit/loader/__init__.py +++ b/src/winml/modelkit/loader/__init__.py @@ -26,7 +26,7 @@ """ from .config import WinMLLoaderConfig, resolve_loader_config -from .onnx_hub import is_hf_onnx_path, resolve_hf_onnx_path +from .onnx_hub import is_hf_onnx_path, maybe_resolve_hf_onnx_path, resolve_hf_onnx_path from .task import ( HF_TASK_DEFAULTS, KNOWN_TASKS, @@ -51,6 +51,7 @@ "get_task_abbrev", "is_hf_onnx_path", "load_hf_model", + "maybe_resolve_hf_onnx_path", "normalize_task", "resolve_hf_model_class", "resolve_hf_onnx_path", diff --git a/src/winml/modelkit/loader/onnx_hub.py b/src/winml/modelkit/loader/onnx_hub.py index 46c182a25..900ff2c8a 100644 --- a/src/winml/modelkit/loader/onnx_hub.py +++ b/src/winml/modelkit/loader/onnx_hub.py @@ -82,7 +82,7 @@ def resolve_hf_onnx_path( components. 
""" from huggingface_hub import hf_hub_download - from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + from huggingface_hub.utils import EntryNotFoundError repo_id, filename = _split_hf_onnx_path(model_id) logger.info("Downloading ONNX from Hub: repo=%s file=%s", repo_id, filename) @@ -98,8 +98,13 @@ def resolve_hf_onnx_path( ) # External-data sidecars (used for >2 GiB models) live next to the .onnx - # file with a ``.onnx_data`` suffix. Fetch best-effort: many ONNX exports - # inline all weights and have no sidecar at all. + # file with a ``.onnx_data`` suffix. The main download above just + # succeeded for the same repo, so the only expected reason the sidecar + # is missing is that the model inlined its weights -- catch + # ``EntryNotFoundError`` quietly. Any other failure (disk full, + # permissions, network blip, etc.) is real and surfaced as a warning + # so the user is not left with a half-downloaded model that fails + # later at load time with a confusing error. sidecar_filename = f"{filename}_data" try: sidecar_path = Path( @@ -112,10 +117,19 @@ def resolve_hf_onnx_path( ) ) logger.info("Downloaded external-data sidecar: %s", sidecar_path.name) - except (EntryNotFoundError, RepositoryNotFoundError, OSError) as e: - # The common case for small inline-weight models that don't ship - # a separate data file. - logger.debug("No external-data sidecar at %s (%s)", sidecar_filename, e) + except EntryNotFoundError: + # Expected: model has no separate weights file (weights are inlined). + logger.debug("No external-data sidecar at %s (weights inlined)", sidecar_filename) + except OSError as e: + # Unexpected: disk/permission/network problem. Warn loudly -- + # silent failure here would make the model unloadable later. + logger.warning( + "Failed to download external-data sidecar %s for %s: %s. 
" + "If the model uses external weights, loading will fail.", + sidecar_filename, + repo_id, + e, + ) return local_path @@ -130,7 +144,44 @@ def _split_hf_onnx_path(model_id: str) -> tuple[str, str]: return "/".join(parts[:2]), "/".join(parts[2:]) +def maybe_resolve_hf_onnx_path( + model_id: str | None, + *, + revision: str | None = None, + cache_dir: str | Path | None = None, + token: str | bool | None = None, +) -> str | None: + """Resolve ``model_id`` to a local ONNX path if it is a Hub ONNX reference. + + Convenience wrapper that combines :func:`is_hf_onnx_path` and + :func:`resolve_hf_onnx_path`. Non-Hub inputs (HF model IDs, local + paths, ``None``) are returned unchanged so callers can use this as a + transparent normalization step before dispatching to existing code. + + Args: + model_id: HF model ID, local path, Hub ONNX ref, or ``None``. + revision: Optional Hub revision (forwarded when downloading). + cache_dir: Optional cache override (forwarded when downloading). + token: Optional auth token (forwarded when downloading). + + Returns: + Local ``.onnx`` path string when ``model_id`` was a Hub ref; the + original ``model_id`` otherwise. + """ + if not is_hf_onnx_path(model_id): + return model_id + return str( + resolve_hf_onnx_path( + model_id, # type: ignore[arg-type] # is_hf_onnx_path() rejects None + revision=revision, + cache_dir=cache_dir, + token=token, + ) + ) + + __all__ = [ "is_hf_onnx_path", + "maybe_resolve_hf_onnx_path", "resolve_hf_onnx_path", ] diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index b9c17c764..5c7a6564d 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -313,10 +313,11 @@ def from_pretrained( # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx path thereafter. 
- from ..loader import is_hf_onnx_path, resolve_hf_onnx_path + from ..loader import maybe_resolve_hf_onnx_path - if is_hf_onnx_path(model_id): - model_id = str(resolve_hf_onnx_path(model_id)) + # ``model_id`` is already coerced to ``str`` above, so the helper's + # ``str | None`` return type is always ``str`` here. + model_id = maybe_resolve_hf_onnx_path(model_id) # type: ignore[assignment] # ===================================================================== # ONNX FAST PATH -- skip HF loading and export when given an .onnx file diff --git a/src/winml/modelkit/onnx/detection.py b/src/winml/modelkit/onnx/detection.py index c82c2ce34..1552799d1 100644 --- a/src/winml/modelkit/onnx/detection.py +++ b/src/winml/modelkit/onnx/detection.py @@ -40,11 +40,25 @@ def _load_model_lightweight(model_path: Path, operation: str) -> onnx.ModelProto def is_quantized_onnx(model_path: Path) -> bool: - """Check if ONNX model is quantized (contains QuantizeLinear/DequantizeLinear nodes).""" + """Check if ONNX model is quantized (QDQ or QOperator format). + + Returns ``True`` for either: + + * **QDQ format** -- contains ``QuantizeLinear`` / ``DequantizeLinear`` + pairs around float ops (the default ``onnxruntime.quantization`` + output and the format QNN expects). + * **QOperator format** -- contains fused integer ops such as + ``ConvInteger``, ``MatMulInteger``, or ``QLinear*`` (used by + ``QuantFormat.QOperator`` exports and by Hub repos like + ``onnx-community/sam3-tracker-ONNX``). + + Both formats indicate the model is "already quantized" and the + ``optimize`` + ``quantize`` build stages should be skipped. 
+ """ model = _load_model_lightweight(model_path, "quantization check") - from ..compiler import QDQ_OP_TYPES + from ..compiler import QUANTIZATION_OP_TYPES - return any(n.op_type in QDQ_OP_TYPES for n in model.graph.node) + return any(n.op_type in QUANTIZATION_OP_TYPES for n in model.graph.node) def is_compiled_onnx(model_path: Path) -> bool: diff --git a/tests/integration/test_sam3_e2e.py b/tests/integration/test_sam3_e2e.py index 31d278493..4be922588 100644 --- a/tests/integration/test_sam3_e2e.py +++ b/tests/integration/test_sam3_e2e.py @@ -36,6 +36,13 @@ # while still exercising the is_quantized_onnx branch (skips optimize+quantize). SAM3_ONNX_REF = "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx" +# Vision encoder: ~60 MB QOperator-quantized ViT backbone with ConvInteger +# (no CPU kernel) and 192 MatMulInteger nodes. Exercises the +# is_quantized_onnx (QOperator detection) + skip_optimize fixes that the +# decoder above does NOT cover -- the decoder happens to lack ConvInteger +# so the original "always run optimize" bug went unnoticed for it. +SAM3_ENCODER_ONNX_REF = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + @pytest.mark.slow @pytest.mark.network @@ -45,16 +52,24 @@ class TestSam3E2E: @pytest.fixture(scope="class") def sam3_onnx_path(self) -> Path: - """Download the SAM 3 Tracker decoder ONNX once for the test class.""" - pytest.importorskip("huggingface_hub", reason="huggingface_hub required") + """Download the SAM 3 Tracker decoder ONNX once for the test class. + + ``huggingface_hub`` is a hard transitive dep (via ``transformers`` / + ``optimum``) so we do NOT ``importorskip`` it -- a missing import + is a real packaging bug. Network-related download failures are + narrowed to the HF Hub error hierarchy + ``OSError`` (DNS, TLS, + connection reset) and ONLY those become a skip; any other + exception is allowed to surface as a real test failure. 
+ """ + from huggingface_hub.utils import HfHubHTTPError from winml.modelkit.loader import is_hf_onnx_path, resolve_hf_onnx_path assert is_hf_onnx_path(SAM3_ONNX_REF) try: return resolve_hf_onnx_path(SAM3_ONNX_REF) - except Exception as e: - pytest.skip(f"Could not download {SAM3_ONNX_REF}: {e}") + except (HfHubHTTPError, OSError) as e: + pytest.skip(f"Network unavailable to download {SAM3_ONNX_REF}: {e}") def test_resolves_to_local_onnx_file(self, sam3_onnx_path: Path) -> None: """The Hub reference resolves to an on-disk .onnx file.""" @@ -68,7 +83,8 @@ def test_generate_onnx_build_config_detects_quantized(self, sam3_onnx_path: Path from winml.modelkit.onnx import is_quantized_onnx assert is_quantized_onnx(sam3_onnx_path), ( - "Expected the int8 variant to contain QuantizeLinear / DequantizeLinear nodes." + "Expected the int8 variant to be detected as quantized " + "(QDQ pairs and/or QOperator integer ops such as MatMulInteger)." ) config = generate_onnx_build_config( @@ -85,7 +101,14 @@ def test_generate_onnx_build_config_detects_quantized(self, sam3_onnx_path: Path def test_build_onnx_model_produces_final_artifact( self, sam3_onnx_path: Path, tmp_path: Path ) -> None: - """build_onnx_model runs end-to-end and emits model.onnx.""" + """build_onnx_model runs end-to-end and emits model.onnx. + + Build failures are NOT silently skipped. A ``RuntimeError`` from + ``build_onnx_model`` here means a real regression in the SAM 3 + pipeline (e.g. the ``ConvInteger`` / ``skip_optimize`` bug fixed + in this PR). Letting that surface as a hard failure is precisely + the value of an integration test. 
+ """ from winml.modelkit.build import build_onnx_model from winml.modelkit.config import generate_onnx_build_config @@ -101,16 +124,13 @@ def test_build_onnx_model_produces_final_artifact( output_dir = tmp_path / "sam3_build" - try: - result = build_onnx_model( - onnx_path=sam3_onnx_path, - config=config, - output_dir=output_dir, - rebuild=True, - hack_max_optim_iterations=0, # skip analyzer to keep test fast - ) - except Exception as e: - pytest.skip(f"build_onnx_model failed (likely missing runtime dep): {e}") + result = build_onnx_model( + onnx_path=sam3_onnx_path, + config=config, + output_dir=output_dir, + rebuild=True, + hack_max_optim_iterations=0, # skip analyzer to keep test fast + ) final = result.final_onnx_path assert final.exists(), f"Expected final artifact at {final}" @@ -119,3 +139,127 @@ def test_build_onnx_model_produces_final_artifact( # Validate the final artifact is a structurally valid ONNX model. model = onnx.load(str(final), load_external_data=False) assert len(model.graph.node) > 0 + + def test_analyze_autoconf_runs(self, sam3_onnx_path: Path) -> None: + """Analyzer autoconf produces an optimization config for SAM 3. + + Issue #324 explicitly requires verifying that the analyzer's + autoconf loop discovers the correct fusion flags. The build test + above disables the analyze<->optimize loop with + ``hack_max_optim_iterations=0`` to keep CI fast, so this test + exercises the autoconf path directly via ``analyze_onnx``. + + ``winml.modelkit.analyze`` is part of this package, so a missing + import is a real packaging bug -- not skipped. Analyzer + ``RuntimeError`` is a real regression and surfaces loudly. + """ + from winml.modelkit.analyze import analyze_onnx + + result = analyze_onnx(sam3_onnx_path, ep="cpu", autoconf=True) + + # autoconf=True must yield an optimization_config (may be empty + # if the model needs no further optimization, but must be present). 
+ assert result.optimization_config is not None, ( + "Expected analyzer to produce an optimization_config when autoconf=True; " + "got None which signals the autoconf loop did not run." + ) + + +@pytest.mark.slow +@pytest.mark.network +@pytest.mark.integration +class TestSam3EncoderE2E: + """SAM 3 vision encoder (QOperator format with ConvInteger) end-to-end. + + Regression test for two bugs found while wiring SAM 3 support: + + 1. ``is_quantized_onnx`` only detected QDQ format and missed + ``QuantFormat.QOperator`` exports (``ConvInteger`` / + ``MatMulInteger`` / ``QLinear*``). The encoder was therefore + routed through the optimize + quantize stages. + 2. The pre-quantized branches in ``build_onnx_model`` and + ``_build_onnx_pipeline`` named themselves "skip optimize" but + still invoked ``optimize_onnx`` -> ``ort_graph``, which loads the + model into an ORT session and crashes for QOperator models on + hosts (e.g. CPU-only) without a ``ConvInteger`` kernel. + + The fix wires a real ``skip_optimize=True`` knob through both + pipelines. This test downloads the ~60 MB encoder and asserts the + full pipeline succeeds without invoking the optimizer. + """ + + @pytest.fixture(scope="class") + def encoder_onnx_path(self) -> Path: + """Download the SAM 3 Tracker vision encoder ONNX once for the class. + + Network failures are narrowed to the HF Hub error hierarchy + + ``OSError`` and only those become a skip; any other exception + surfaces as a real test failure. 
+ """ + from huggingface_hub.utils import HfHubHTTPError + + from winml.modelkit.loader import is_hf_onnx_path, resolve_hf_onnx_path + + assert is_hf_onnx_path(SAM3_ENCODER_ONNX_REF) + try: + return resolve_hf_onnx_path(SAM3_ENCODER_ONNX_REF) + except (HfHubHTTPError, OSError) as e: + pytest.skip(f"Network unavailable to download {SAM3_ENCODER_ONNX_REF}: {e}") + + def test_encoder_is_detected_as_quantized(self, encoder_onnx_path: Path) -> None: + """The QOperator-quantized encoder is recognized by is_quantized_onnx.""" + from winml.modelkit.onnx import is_quantized_onnx + + assert is_quantized_onnx(encoder_onnx_path), ( + "Expected QOperator-quantized encoder to be detected by " + "is_quantized_onnx (regression: previously only QDQ format was checked)." + ) + + def test_build_encoder_skips_optimize_and_succeeds( + self, encoder_onnx_path: Path, tmp_path: Path + ) -> None: + """``build_onnx_model`` runs end-to-end on the encoder without optimize. + + Build failures are NOT silently skipped -- a ``RuntimeError`` here + means a regression in the QOperator detection / skip_optimize fix + that this test exists to lock down. + """ + from winml.modelkit.build import build_onnx_model + from winml.modelkit.config import generate_onnx_build_config + + config = generate_onnx_build_config( + encoder_onnx_path, + task="image-feature-extraction", + device="cpu", + precision="auto", + ) + # Sanity: pre-quantized models must skip the quant stage. + assert config.quant is None, ( + "Expected pre-quantized encoder to set config.quant=None; " + "got a quant config which would re-quantize an already-int8 model." + ) + config.compile = None # No NPU on the test host. 
+ + output_dir = tmp_path / "sam3_encoder_build" + + result = build_onnx_model( + onnx_path=encoder_onnx_path, + config=config, + output_dir=output_dir, + rebuild=True, + hack_max_optim_iterations=0, + ) + + final = result.final_onnx_path + assert final.exists(), f"Expected final artifact at {final}" + assert final.stat().st_size > 0 + + # Validate the final artifact is structurally a valid ONNX model + # and still contains the QOperator ops (proof we did not strip them + # by accidentally running graph optimization). + model = onnx.load(str(final), load_external_data=False) + op_types = {n.op_type for n in model.graph.node} + assert "ConvInteger" in op_types, ( + "Final encoder should still contain ConvInteger nodes -- " + "presence proves optimize was correctly skipped." + ) diff --git a/tests/unit/build/test_hf.py b/tests/unit/build/test_hf.py index 38203af44..969366219 100644 --- a/tests/unit/build/test_hf.py +++ b/tests/unit/build/test_hf.py @@ -798,7 +798,12 @@ class TestBuildHfPreQuantized: def test_post_export_qdq_skips_optimize_and_quantize( self, tmp_path: Path, sample_config, mock_pipeline ) -> None: - """If exported ONNX has QDQ nodes, skip optimize+quantize.""" + """Exported QDQ/QOperator ONNX truly skips both optimize AND quantize. + + Regression: previously the pre-quantized branch logged "skipping + optimize" but still invoked ``optimize_onnx``. That hidden call + crashed for QOperator models with ``ConvInteger`` (no CPU kernel). 
+ """ mock_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -811,7 +816,7 @@ def test_post_export_qdq_skips_optimize_and_quantize( assert "quantize" in result.stages_skipped assert "optimize" not in result.stages_completed assert "quantize" not in result.stages_completed - mock_pipeline["optimize"].assert_called_once() + mock_pipeline["optimize"].assert_not_called() mock_pipeline["quantize"].assert_not_called() def test_post_export_qdq_still_exports( @@ -847,7 +852,7 @@ def test_post_export_qdq_still_compiles( def test_post_export_qdq_runs_analyze_only( self, tmp_path: Path, sample_config, mock_pipeline ) -> None: - """Pre-quantized path runs optimize but skips autoconf (no analyze).""" + """Pre-quantized path skips both optimize AND analyze (max_iters=0).""" mock_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -856,9 +861,10 @@ def test_post_export_qdq_runs_analyze_only( output_dir=output_dir, pytorch_model=mock_pipeline["model"], ) - # max_optim_iterations=0 means no analyze loop runs + # max_optim_iterations=0 means no analyze loop runs. + # Optimize is also skipped via skip_optimize=True. 
mock_pipeline["analyze"].assert_not_called() - mock_pipeline["optimize"].assert_called_once() + mock_pipeline["optimize"].assert_not_called() def test_skip_optimize_kwarg(self, tmp_path: Path, sample_config, mock_pipeline) -> None: """skip_optimize=True forces optimize+quantize skip.""" @@ -873,7 +879,7 @@ def test_skip_optimize_kwarg(self, tmp_path: Path, sample_config, mock_pipeline) ) assert "optimize" in result.stages_skipped assert "quantize" in result.stages_skipped - mock_pipeline["optimize"].assert_called_once() + mock_pipeline["optimize"].assert_not_called() mock_pipeline["quantize"].assert_not_called() diff --git a/tests/unit/build/test_onnx.py b/tests/unit/build/test_onnx.py index 1c1322907..2b10fe783 100644 --- a/tests/unit/build/test_onnx.py +++ b/tests/unit/build/test_onnx.py @@ -366,7 +366,13 @@ def test_build_onnx_non_quantized_proceeds( def test_pre_quantized_skips_optimize_and_quantize( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline ) -> None: - """QDQ model skips both optimize AND quantize stages.""" + """QDQ/QOperator model truly skips both optimize AND quantize stages. + + Regression: previously the pre-quantized branch logged "skipping + optimize" but still invoked ``optimize_onnx``. That hidden call + crashed for QOperator models with ``ConvInteger`` (no CPU kernel). + ``optimize_onnx`` must NOT be called on pre-quantized models. 
+ """ mock_onnx_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -379,7 +385,7 @@ def test_pre_quantized_skips_optimize_and_quantize( assert "quantize" in result.stages_skipped assert "optimize" not in result.stages_completed assert "quantize" not in result.stages_completed - mock_onnx_pipeline["optimize"].assert_called_once() + mock_onnx_pipeline["optimize"].assert_not_called() mock_onnx_pipeline["quantize"].assert_not_called() def test_pre_quantized_still_compiles( @@ -400,7 +406,7 @@ def test_pre_quantized_still_compiles( def test_pre_quantized_runs_analyze_only( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline ) -> None: - """Pre-quantized path runs optimize but skips autoconf (no analyze).""" + """Pre-quantized path skips both optimize AND analyze (max_iters=0).""" mock_onnx_pipeline["is_quantized_onnx"].return_value = True output_dir = tmp_path / "output" @@ -409,9 +415,10 @@ def test_pre_quantized_runs_analyze_only( config=sample_onnx_config, output_dir=output_dir, ) - # max_optim_iterations=0 means no analyze loop runs + # max_optim_iterations=0 means no analyze loop runs. + # Optimize is also skipped via skip_optimize=True. 
+ mock_onnx_pipeline["analyze"].assert_not_called() - mock_onnx_pipeline["optimize"].assert_called_once() + mock_onnx_pipeline["optimize"].assert_not_called() def test_skip_optimize_kwarg( self, tmp_path: Path, fake_onnx: Path, sample_onnx_config, mock_onnx_pipeline @@ -428,7 +435,7 @@ def test_skip_optimize_kwarg( ) assert "optimize" in result.stages_skipped assert "quantize" in result.stages_skipped - mock_onnx_pipeline["optimize"].assert_called_once() + mock_onnx_pipeline["optimize"].assert_not_called() mock_onnx_pipeline["quantize"].assert_not_called() diff --git a/tests/unit/commands/test_eval.py b/tests/unit/commands/test_eval.py index 7f70a9322..459d93280 100644 --- a/tests/unit/commands/test_eval.py +++ b/tests/unit/commands/test_eval.py @@ -106,6 +106,37 @@ def test_plain_onnx_missing_file_raises(self, tmp_path): with pytest.raises(click.BadParameter, match="ONNX file not found"): _resolve_model_path(model=(str(missing),), model_id="some/id") + def test_hub_onnx_ref_is_resolved(self, tmp_path): + """Hub-style ONNX refs (``org/repo/path/to/file.onnx``) must be + downloaded once and treated as the resolved local path -- not + rejected by the ``ONNX file not found`` validation that fires + for missing local files. + + Regression test for ``winml eval`` on Hub refs like + ``onnx-community/sam3-tracker-ONNX/onnx/...``. + """ + from unittest.mock import patch + + local = tmp_path / "vision_encoder_int8.onnx" + local.write_bytes(b"") + hub_ref = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + + # eval.py does ``from ..loader import resolve_hf_onnx_path``, which + # binds the helper name lazily INSIDE _resolve_model_path. Patch on + # the loader package re-export so the lazy import sees the mock.
+ with patch( + "winml.modelkit.loader.resolve_hf_onnx_path", + return_value=local, + ) as mock_resolve: + path, mid = _resolve_model_path( + model=(hub_ref,), + model_id="facebook/sam3-tracker", + ) + mock_resolve.assert_called_once() + # The Hub ref was resolved to the local path; eval can now load it. + assert path == str(local) + assert mid == "facebook/sam3-tracker" + def test_multiple_plain_raises(self, onnx_file): """Multiple plain -m values without role=path are ambiguous.""" with pytest.raises(click.UsageError, match="role=path"): diff --git a/tests/unit/commands/test_perf_cli.py b/tests/unit/commands/test_perf_cli.py index fa4541493..daff125ea 100644 --- a/tests/unit/commands/test_perf_cli.py +++ b/tests/unit/commands/test_perf_cli.py @@ -374,6 +374,45 @@ def test_cli_onnx_not_found_error(self, runner: CliRunner, tmp_path: Path) -> No assert result.exit_code != 0 assert "not found" in result.output.lower() + def test_cli_hub_onnx_ref_is_resolved(self, runner: CliRunner, tmp_path: Path) -> None: + """CLI with a Hub-style ONNX ref must download once before the + ``Path(...).suffix == '.onnx' and exists()`` check, otherwise the + ref string is mistaken for a missing local file and rejected with + ``FileNotFoundError`` before any HF Hub call happens. + + Regression test for ``winml perf -m + onnx-community/sam3-tracker-ONNX/onnx/...``. 
+ """ + local = tmp_path / "vision_encoder_int8.onnx" + local.write_bytes(b"fake onnx") + hub_ref = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + + with ( + patch( + "winml.modelkit.loader.maybe_resolve_hf_onnx_path", + return_value=str(local), + ) as mock_resolve, + patch( + "winml.modelkit.commands.perf._run_onnx_benchmark", + return_value=MagicMock(), + ) as mock_run, + patch("winml.modelkit.commands.perf.display_console_report"), + patch("winml.modelkit.commands.perf.write_json_report"), + ): + result = runner.invoke( + perf, + ["-m", hub_ref, "-o", str(tmp_path / "out.json")], + obj={}, + ) + + assert result.exit_code == 0, result.output + mock_resolve.assert_called_once_with(hub_ref) + # After resolution, the Hub ref reaches _run_onnx_benchmark as + # the LOCAL path -- not the original Hub ref string. + mock_run.assert_called_once() + called_path = mock_run.call_args.args[0] + assert called_path == local + def test_onnx_load_model_passes_ep(self, tmp_path: Path) -> None: """EP argument should be forwarded to from_onnx.""" onnx_file = tmp_path / "model.onnx" diff --git a/tests/unit/inference/test_engine.py b/tests/unit/inference/test_engine.py index f322636e8..12590a309 100644 --- a/tests/unit/inference/test_engine.py +++ b/tests/unit/inference/test_engine.py @@ -366,6 +366,34 @@ def test_task_param_overrides_manifest(self, tmp_path: Any) -> None: engine.load_schema_only(tmp_path, task="image-classification") assert engine._task == "image-classification" + def test_hub_onnx_ref_is_resolved_before_routing(self, tmp_path: Any) -> None: + """A Hub-style ONNX ref (``org/repo/path/to/file.onnx``) must be + resolved to a local path BEFORE the .onnx-suffix-and-exists check, + otherwise it falls through to the HF model-id branch and tries to + load a Hub-ONNX path string as if it were a transformers config. + + Regression test for ``winml run`` and ``winml serve`` on Hub refs + like ``onnx-community/sam3-tracker-ONNX/onnx/...``.
+ """ + from unittest.mock import patch + + local = tmp_path / "vision_encoder_int8.onnx" + local.write_bytes(b"fake-onnx") + hub_ref = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + + engine = InferenceEngine() + with patch( + "winml.modelkit.loader.maybe_resolve_hf_onnx_path", + return_value=str(local), + ) as mock_resolve: + engine.load_schema_only(hub_ref, task="mask-generation") + mock_resolve.assert_called_once() + # After resolution the engine should treat the input as a local + # ONNX file (not as an HF model id), which means _model_id is the + # resolved local path string, not the original Hub ref. + assert engine._model_id == str(local) + assert engine._task == "mask-generation" + # --------------------------------------------------------------------------- # _sanitize_numpy diff --git a/tests/unit/loader/test_onnx_hub.py b/tests/unit/loader/test_onnx_hub.py index 3e421f72c..9771b4fdb 100644 --- a/tests/unit/loader/test_onnx_hub.py +++ b/tests/unit/loader/test_onnx_hub.py @@ -18,6 +18,7 @@ from winml.modelkit.loader.onnx_hub import ( _split_hf_onnx_path, is_hf_onnx_path, + maybe_resolve_hf_onnx_path, resolve_hf_onnx_path, ) @@ -141,3 +142,83 @@ def _fake_download(*, repo_id, filename, revision, cache_dir, token): result = resolve_hf_onnx_path("org/repo/onnx/vision_encoder.onnx") assert result == downloaded + + def test_sidecar_oserror_warns_loudly( + self, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + """Sidecar download OSError is logged at WARNING (not silently ignored). + + Regression: a previous implementation swallowed any + ``OSError`` (disk full, permission denied, network blip) at + ``logger.debug`` level. That hid real environmental problems and + led to confusing failures later when the model loader tried to + resolve missing external initializers. This test verifies the + warning is emitted so the user sees something is wrong. 
+ """ + import logging + + downloaded = tmp_path / "vision_encoder.onnx" + downloaded.write_bytes(b"") + + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + if filename.endswith(".onnx_data"): + raise OSError("disk full") + return str(downloaded) + + with ( + patch("huggingface_hub.hf_hub_download", side_effect=_fake_download), + caplog.at_level(logging.WARNING, logger="winml.modelkit.loader.onnx_hub"), + ): + result = resolve_hf_onnx_path("org/repo/onnx/vision_encoder.onnx") + + # Main download still succeeds even when the sidecar fails. + assert result == downloaded + # Critically: the OSError must surface as a WARNING. + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + warning_messages = [r.getMessage() for r in warning_records] + assert any("disk full" in m for m in warning_messages), ( + f"Expected a WARNING containing 'disk full'; got {warning_messages}" + ) + + +class TestMaybeResolveHfOnnxPath: + """Convenience wrapper that combines is_hf_onnx_path + resolve_hf_onnx_path.""" + + def test_none_passes_through(self) -> None: + """``None`` returns ``None`` without touching the network.""" + with patch("huggingface_hub.hf_hub_download") as mock: + assert maybe_resolve_hf_onnx_path(None) is None + mock.assert_not_called() + + def test_plain_hf_model_id_passes_through(self) -> None: + """An HF model id (e.g. 
``microsoft/resnet-50``) is returned unchanged.""" + with patch("huggingface_hub.hf_hub_download") as mock: + assert maybe_resolve_hf_onnx_path("microsoft/resnet-50") == "microsoft/resnet-50" + mock.assert_not_called() + + def test_local_path_passes_through(self, tmp_path: Path) -> None: + """Existing local ``.onnx`` paths take precedence over Hub interpretation.""" + local = tmp_path / "model.onnx" + local.write_bytes(b"") + with patch("huggingface_hub.hf_hub_download") as mock: + assert maybe_resolve_hf_onnx_path(str(local)) == str(local) + mock.assert_not_called() + + def test_hub_ref_is_resolved(self, tmp_path: Path) -> None: + """A Hub-style ONNX ref triggers ``resolve_hf_onnx_path``.""" + from huggingface_hub.utils import EntryNotFoundError + + downloaded = tmp_path / "vision_encoder_int8.onnx" + downloaded.write_bytes(b"") + + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + if filename.endswith(".onnx_data"): + raise EntryNotFoundError(filename) + return str(downloaded) + + with patch("huggingface_hub.hf_hub_download", side_effect=_fake_download): + result = maybe_resolve_hf_onnx_path( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + ) + + assert result == str(downloaded) diff --git a/tests/unit/onnx/test_detection.py b/tests/unit/onnx/test_detection.py new file mode 100644 index 000000000..ba54b8338 --- /dev/null +++ b/tests/unit/onnx/test_detection.py @@ -0,0 +1,122 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for ``winml.modelkit.onnx.detection``. + +Covers ``is_quantized_onnx`` for both QDQ and QOperator formats and +``is_compiled_onnx`` for EPContext detection. Builds tiny synthetic +ONNX models with the relevant ops so no network or large fixtures +are required. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import onnx +from onnx import TensorProto, helper + +from winml.modelkit.onnx.detection import is_compiled_onnx, is_quantized_onnx + + +if TYPE_CHECKING: + from pathlib import Path + + +def _save(graph: onnx.GraphProto, path: Path, *, opset: int = 17) -> Path: + """Save a graph as a minimal ONNX model.""" + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) + model.ir_version = 8 + onnx.save(model, str(path)) + return path + + +class TestIsQuantizedOnnx: + """Both QDQ and QOperator quantization formats are detected.""" + + def test_float_model_is_not_quantized(self, tmp_path: Path) -> None: + """A plain float MatMul model is not flagged as quantized.""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]) + w = helper.make_tensor( + "W", TensorProto.FLOAT, [4, 4], np.eye(4, dtype=np.float32).flatten().tolist() + ) + node = helper.make_node("MatMul", ["X", "W"], ["Y"]) + graph = helper.make_graph([node], "g", [x], [y], [w]) + path = _save(graph, tmp_path / "float.onnx") + assert is_quantized_onnx(path) is False + + def test_qdq_quantizelinear_is_detected(self, tmp_path: Path) -> None: + """A graph containing QuantizeLinear is recognized as quantized.""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4]) + y = helper.make_tensor_value_info("Y", TensorProto.UINT8, [4]) + scale = helper.make_tensor("scale", TensorProto.FLOAT, [], [0.1]) + zp = helper.make_tensor("zp", TensorProto.UINT8, [], [128]) + node = helper.make_node("QuantizeLinear", ["X", "scale", "zp"], ["Y"]) + graph = helper.make_graph([node], "g", [x], [y], [scale, zp]) + path = _save(graph, tmp_path / "qdq.onnx") + assert is_quantized_onnx(path) is True + + def test_qoperator_matmulinteger_is_detected(self, tmp_path: Path) -> None: + """A graph containing MatMulInteger is 
recognized as quantized. + + Regression test for the SAM 3 ``vision_encoder_int8.onnx`` case: + the encoder uses ``QuantFormat.QOperator`` (no QDQ pairs), so the + old QDQ-only check returned False and the build pipeline tried to + run the optimizer over already-quantized integer ops. + """ + a = helper.make_tensor_value_info("A", TensorProto.UINT8, [1, 4]) + b = helper.make_tensor_value_info("B", TensorProto.UINT8, [4, 4]) + y = helper.make_tensor_value_info("Y", TensorProto.INT32, [1, 4]) + node = helper.make_node("MatMulInteger", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "g", [a, b], [y]) + path = _save(graph, tmp_path / "qop_matmul.onnx") + assert is_quantized_onnx(path) is True + + def test_qoperator_convinteger_is_detected(self, tmp_path: Path) -> None: + """A graph containing ConvInteger is recognized as quantized.""" + x = helper.make_tensor_value_info("X", TensorProto.UINT8, [1, 1, 4, 4]) + w = helper.make_tensor( + "W", TensorProto.UINT8, [1, 1, 1, 1], np.array([1], dtype=np.uint8).tobytes(), raw=True + ) + y = helper.make_tensor_value_info("Y", TensorProto.INT32, [1, 1, 4, 4]) + node = helper.make_node("ConvInteger", ["X", "W"], ["Y"]) + graph = helper.make_graph([node], "g", [x], [y], [w]) + path = _save(graph, tmp_path / "qop_conv.onnx") + assert is_quantized_onnx(path) is True + + def test_qoperator_qlinearmatmul_is_detected(self, tmp_path: Path) -> None: + """A graph containing QLinearMatMul is recognized as quantized.""" + a = helper.make_tensor_value_info("A", TensorProto.UINT8, [1, 4]) + b = helper.make_tensor_value_info("B", TensorProto.UINT8, [4, 4]) + y = helper.make_tensor_value_info("Y", TensorProto.UINT8, [1, 4]) + a_scale = helper.make_tensor("a_scale", TensorProto.FLOAT, [], [0.1]) + a_zp = helper.make_tensor("a_zp", TensorProto.UINT8, [], [128]) + b_scale = helper.make_tensor("b_scale", TensorProto.FLOAT, [], [0.1]) + b_zp = helper.make_tensor("b_zp", TensorProto.UINT8, [], [128]) + y_scale = helper.make_tensor("y_scale", 
TensorProto.FLOAT, [], [0.1]) + y_zp = helper.make_tensor("y_zp", TensorProto.UINT8, [], [128]) + node = helper.make_node( + "QLinearMatMul", + ["A", "a_scale", "a_zp", "B", "b_scale", "b_zp", "y_scale", "y_zp"], + ["Y"], + ) + graph = helper.make_graph( + [node], "g", [a, b], [y], [a_scale, a_zp, b_scale, b_zp, y_scale, y_zp] + ) + path = _save(graph, tmp_path / "qop_qlinear.onnx", opset=15) + assert is_quantized_onnx(path) is True + + +class TestIsCompiledOnnx: + """EPContext detection.""" + + def test_float_model_is_not_compiled(self, tmp_path: Path) -> None: + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1]) + node = helper.make_node("Identity", ["X"], ["Y"]) + graph = helper.make_graph([node], "g", [x], [y]) + path = _save(graph, tmp_path / "float.onnx") + assert is_compiled_onnx(path) is False From a0e9c7c49a5875e057772d724a0eb97848f2cbc9 Mon Sep 17 00:00:00 2001 From: Abhilash Chenreddy Date: Mon, 15 Jun 2026 16:30:09 -0400 Subject: [PATCH 3/4] Address review feedback: central model-input resolver + mask-generation evaluator Follow-up to the SAM 3 / Hub-hosted ONNX commits, addressing inline review feedback from @DingmaomaoBJTU, Copilot review, and CodeQL. Refactor: single model-input classifier and resolver ---------------------------------------------------- New module ``winml.modelkit.utils.model_input`` is the single entry point for classifying a ``-m/--model`` value (one of ``hub_onnx``, ``hf_id``, ``local_onnx``, ``build_dir``, ``invalid``) and resolving it (download for ``hub_onnx``, pass-through otherwise). Replaces the previous scattered detection across ``loader/onnx_hub`` (``is_hf_onnx_path``, ``maybe_resolve_hf_onnx_path``), ``utils/cli`` (``is_onnx_file_path``), and ad-hoc ``Path(value).suffix == ".onnx"`` checks in commands. 
Callers updated: ``commands/build``, ``config``, ``eval``, ``inspect``, ``perf``; ``inference/engine`` (``load`` and ``load_schema_only``); ``models/auto.WinMLAutoModel.from_pretrained``. Dead code removed ----------------- * ``resolve_model_input(...).local_path or str(model_path)`` -- the ``or`` branch was unreachable; ``local_path`` is set for every resolvable input. * Case-sensitive ``.onnx`` check in ``loader/onnx_hub`` (now consistent with case-insensitive ``.lower()`` check in callers). Build pipeline -------------- * ``ensure_pre_quantized_stamped`` is now the single defensive detection point for pre-quantized models in library entry points; the unified CLI path stamps the config up front, library entry points only run detection when needed. * ``run_optimize_analyze_loop`` enforces ``max_optim_iterations = 0`` whenever ``skip_optimize=True``; ORT lacks kernels for ``ConvInteger`` / ``MatMulInteger`` on host EPs, so re-optimize would crash for the same reason the initial optimize is skipped. * Skip-log message updated to ``"QDQ or QOperator nodes present"`` to match ``is_quantized_onnx`` accepting both quantization formats. * ``compiler/utils.QUANTIZATION_OP_TYPES`` extended with ``DynamicQuantizeLinear`` / ``DynamicQuantizeMatMul``; exported via ``__all__`` to satisfy CodeQL unused-global warning. Mask-generation evaluator (response to "would love to see eval") ---------------------------------------------------------------- * ``winml.modelkit.eval.mask_generation_evaluator`` drives encoder + decoder ORT sessions for SAM-family promptable mask generation. Profile-dispatch design supports SAM 3 (1008 input, mean=std=0.5, direct resize) and SAM 2.1 (1024 input, ImageNet mean/std, longest-side-pad). Verified preprocessing is byte-correct against ``onnx-community/sam3-tracker-ONNX/preprocessor_config.json``. 
* ``winml.modelkit.datasets.mask_generation`` -- ``MaskGenerationDataset`` wraps ``mattmdjaga/human_parsing_dataset`` with binary foreground/background masks. * ``winml.modelkit.eval.metrics.binary_segmentation`` -- mIoU + Dice on binary masks. * Composite SAM 3 entry in ``models_with_acc.json``. Reference scripts ----------------- * ``scripts/sam3_reference_check.py`` -- spot-check against the published reference ``iou_scores`` from the model card. * ``scripts/mask_generation_eval.py`` -- generic harness for any SAM-family ONNX pair with ``--preset`` (sam2 / sam3) and ``--dataset`` (human_parsing / coco). * ``scripts/sam3_smoke_eval.py`` -- back-compat shim that delegates to the generic harness with the SAM 3 preset. Test cleanups (CodeQL + Copilot) -------------------------------- * ``tests/integration/test_sam3_e2e.py`` fixtures: explicit ``raise`` after ``pytest.skip`` to make control flow obvious. * ``tests/unit/onnx/test_detection.py``: consolidate ``import onnx`` / ``from onnx import ...`` into a single import form. * ``tests/unit/core/test_onnx_utils.py``: expected keys updated for new ``input_symbolic_shapes`` field in ``get_io_config`` output. Other ----- * ``.gitignore``: ignore stray ``.data`` ORT external-data sidecars at repo root, ``quantized_info.csv`` Quark side-effect file, and ``scripts/_*.py`` / ``scripts/_*.json`` private debug scripts. 
--- .gitignore | 12 + scripts/e2e_eval/build_registry.py | 34 +- .../e2e_eval/testsets/models_with_acc.json | 42 + scripts/mask_generation_eval.py | 770 ++++++++++++++++++ scripts/sam3_decoder_shapes.json | 8 + scripts/sam3_reference_check.py | 105 +++ scripts/sam3_smoke_eval.py | 35 + src/winml/modelkit/build/common.py | 52 +- src/winml/modelkit/build/hf.py | 76 +- src/winml/modelkit/build/onnx.py | 76 +- src/winml/modelkit/commands/build.py | 48 +- src/winml/modelkit/commands/config.py | 3 +- src/winml/modelkit/commands/eval.py | 12 +- src/winml/modelkit/commands/inspect.py | 7 +- src/winml/modelkit/commands/perf.py | 64 +- src/winml/modelkit/compiler/configs.py | 20 +- src/winml/modelkit/compiler/utils.py | 37 +- src/winml/modelkit/config/build.py | 16 +- src/winml/modelkit/core/onnx_utils.py | 10 +- src/winml/modelkit/data/hub_models.json | 10 +- src/winml/modelkit/datasets/__init__.py | 2 + .../modelkit/datasets/mask_generation.py | 394 +++++++++ src/winml/modelkit/eval/evaluate.py | 32 +- .../eval/mask_generation_evaluator.py | 593 ++++++++++++++ src/winml/modelkit/eval/metrics/__init__.py | 3 + .../eval/metrics/binary_segmentation.py | 92 +++ src/winml/modelkit/inference/engine.py | 8 +- src/winml/modelkit/loader/__init__.py | 4 +- src/winml/modelkit/loader/onnx_hub.py | 140 ++-- src/winml/modelkit/models/auto.py | 6 +- src/winml/modelkit/onnx/io.py | 12 +- src/winml/modelkit/sysinfo/device.py | 27 + src/winml/modelkit/utils/__init__.py | 10 + src/winml/modelkit/utils/cli.py | 43 +- src/winml/modelkit/utils/eval_utils.py | 17 + src/winml/modelkit/utils/hub_utils.py | 49 +- src/winml/modelkit/utils/model_input.py | 156 ++++ src/winml/modelkit/winml.py | 12 + tests/integration/test_sam3_e2e.py | 17 +- tests/unit/build/test_hf.py | 2 +- tests/unit/build/test_onnx.py | 2 +- tests/unit/commands/test_eval.py | 88 +- tests/unit/commands/test_hub_onnx_ref.py | 29 +- tests/unit/commands/test_perf_cli.py | 36 +- tests/unit/config/test_build.py | 18 + 
tests/unit/core/test_onnx_utils.py | 1 + tests/unit/datasets/test_mask_generation.py | 298 +++++++ .../eval/test_binary_segmentation_metric.py | 97 +++ .../eval/test_mask_generation_evaluator.py | 369 +++++++++ tests/unit/inference/test_engine.py | 9 +- tests/unit/loader/test_onnx_hub.py | 172 ++-- tests/unit/onnx/test_detection.py | 7 +- tests/unit/utils/test_model_input.py | 215 +++++ 53 files changed, 4048 insertions(+), 349 deletions(-) create mode 100644 scripts/mask_generation_eval.py create mode 100644 scripts/sam3_decoder_shapes.json create mode 100644 scripts/sam3_reference_check.py create mode 100644 scripts/sam3_smoke_eval.py create mode 100644 src/winml/modelkit/datasets/mask_generation.py create mode 100644 src/winml/modelkit/eval/mask_generation_evaluator.py create mode 100644 src/winml/modelkit/eval/metrics/binary_segmentation.py create mode 100644 src/winml/modelkit/utils/model_input.py create mode 100644 tests/unit/datasets/test_mask_generation.py create mode 100644 tests/unit/eval/test_binary_segmentation_metric.py create mode 100644 tests/unit/eval/test_mask_generation_evaluator.py create mode 100644 tests/unit/utils/test_model_input.py diff --git a/.gitignore b/.gitignore index 2184e9f72..089c7dd61 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,7 @@ $RECYCLE.BIN/ # Temporary files temp/ tmp/ +out/ *.tmp *.temp @@ -223,6 +224,17 @@ tmp/ *.graphml *.onnxdata +# ONNX Runtime external-data sidecars dropped at CWD (UUID-named *.data files) +# from sessions that did not pin the external-data directory +/*.data + +# Quark / quantizer side-effect outputs +/quantized_info.csv + +# Local debug / scratch scripts (convention: leading underscore = private) +scripts/_*.py +scripts/_*.json + # Logs *.log logs/ diff --git a/scripts/e2e_eval/build_registry.py b/scripts/e2e_eval/build_registry.py index 547d8208a..467e20229 100644 --- a/scripts/e2e_eval/build_registry.py +++ b/scripts/e2e_eval/build_registry.py @@ -128,19 +128,35 @@ def 
load_optimum_types() -> set[str]: def load_curated_entries(curated_path: Path) -> list[dict]: - """Load curated entries (hf_id + task + group + priority) from source JSON.""" + """Load curated entries (hf_id + task + group + priority) from source JSON. + + Optional fields recognised by downstream consumers but not enforced here: + + * ``composite_onnx`` -- ``{role: hub_onnx_ref}`` map for models that + ship as several role-tagged ONNX graphs (e.g. SAM-family mask + generators with ``image-encoder`` + ``prompt-decoder``). Old + consumers ignore this field; new consumers (e.g. + ``mask-generation`` evaluator dispatch) read it to discover the + per-role files. ``hf_id`` is still required and should point at + the canonical repo (typically the encoder's repo). + """ with curated_path.open(encoding="utf-8") as f: entries = json.load(f) - return [ - { + loaded: list[dict] = [] + for e in entries: + if "hf_id" not in e: + continue + item = { "hf_id": e["hf_id"], "task": e.get("task") or "", "group": e.get("group", "P0"), "priority": e.get("priority", "P0"), } - for e in entries - if "hf_id" in e - ] + # Pass-through additive fields so they survive into the built registry. + if "composite_onnx" in e: + item["composite_onnx"] = e["composite_onnx"] + loaded.append(item) + return loaded def print_stats(registry_path: Path) -> None: @@ -380,6 +396,10 @@ def build_registry( existing["priority"] = priority existing["group"] = group safe_print(f" [{priority}] {model_id} / {task} — updated (group={group})") + # Carry curated ``composite_onnx`` onto an existing entry so + # downstream consumers always see the canonical role map. 
+ if "composite_onnx" in c and "composite_onnx" not in existing: + existing["composite_onnx"] = c["composite_onnx"] continue # New curated entry — fetch metadata if not already loaded @@ -399,6 +419,8 @@ def build_registry( "last_update_time": metadata["last_modified"], "optimum_supported": is_optimum, } + if "composite_onnx" in c: + entry["composite_onnx"] = c["composite_onnx"] seen.add(key) entry_lookup[key] = entry diff --git a/scripts/e2e_eval/testsets/models_with_acc.json b/scripts/e2e_eval/testsets/models_with_acc.json index 49e4b73e5..7ba60c8ad 100644 --- a/scripts/e2e_eval/testsets/models_with_acc.json +++ b/scripts/e2e_eval/testsets/models_with_acc.json @@ -1927,5 +1927,47 @@ "depth_column": "depth_map" } } + }, + { + "hf_id": "onnx-community/sam3-tracker-ONNX", + "task": "mask-generation", + "model_type": "sam3_tracker", + "group": "Top200", + "priority": "P2", + "composite_onnx": { + "image-encoder": "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx", + "prompt-decoder": "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx" + }, + "dataset_config": { + "path": "mattmdjaga/human_parsing_dataset", + "split": "train", + "metric": "mIoU", + "samples": 10, + "columns_mapping": { + "input_column": "image", + "mask_column": "mask" + } + } + }, + { + "hf_id": "onnx-community/sam2.1-hiera-small-ONNX", + "task": "mask-generation", + "model_type": "sam2", + "group": "Top200", + "priority": "P2", + "composite_onnx": { + "image-encoder": "onnx-community/sam2.1-hiera-small-ONNX/onnx/vision_encoder_int8.onnx", + "prompt-decoder": "onnx-community/sam2.1-hiera-small-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx" + }, + "dataset_config": { + "path": "mattmdjaga/human_parsing_dataset", + "split": "train", + "metric": "mIoU", + "samples": 10, + "columns_mapping": { + "input_column": "image", + "mask_column": "mask" + } + } } ] diff --git a/scripts/mask_generation_eval.py b/scripts/mask_generation_eval.py new file mode 100644 index 
000000000..f66daa58e --- /dev/null +++ b/scripts/mask_generation_eval.py @@ -0,0 +1,770 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Generic mIoU/Dice smoke test for promptable mask-generation ONNX models. + +Works with any SAM-family encoder + prompt-decoder ONNX pair that follows +the standard naming convention: + +* **Encoder** input ``pixel_values`` ``(B, 3, T, T)`` -> outputs + ``image_embeddings.0``, ``image_embeddings.1``, ``image_embeddings.2``. +* **Decoder** inputs ``input_points``, ``input_labels``, ``input_boxes``, + ``image_embeddings.{0,1,2}`` -> outputs ``iou_scores``, ``pred_masks``, + ``object_score_logits``. + +Supports two datasets: + +* ``--dataset human_parsing`` (default) -- ``mattmdjaga/human_parsing_dataset``, + one binary foreground mask per sample. Fast, but humans-only. +* ``--dataset coco`` -- COCO val2017 instances via ``merve/coco`` annotations + (instance masks decoded with ``pycocotools``) and images fetched from + the COCO CDN. This is the standard cross-domain SAM benchmark and + the only one whose numbers are directly comparable to the published + SAM / SAM 2 / SAM 3 papers. + +Supports three prompt strategies, plus an ``all`` meta-option (``--prompt-type``): + +* ``point`` -- single positive click at the GT mask's centroid (snapped + to a foreground pixel for non-convex masks). This is the standard + "1-click" SAM eval protocol. +* ``bbox`` -- tight GT bounding box as box prompt. Standard "box-prompt" + protocol; typically scores 5-10 mIoU higher than 1-click. +* ``point+box`` -- both; this is the default (``--prompt-type point+box``). +* ``all`` -- run all three and report a per-prompt-type breakdown. 
+ +Encoder output is cached per unique image_id, so running multiple prompt +types or multiple annotations per COCO image only re-runs the cheap +decoder (~25 ms per call vs ~12 s for the SAM 3 encoder). + +Run: + + # SAM 3 baseline on the humans slice (back-compat, fast) + python scripts/mask_generation_eval.py --preset sam3 --num-samples 10 + + # SAM 3 on COCO val2017, full 3-way prompt comparison + python scripts/mask_generation_eval.py --preset sam3 --dataset coco \\ + --num-samples 50 --prompt-type all + + # Custom encoder/decoder (replace REPO with the Hub repository name) + python scripts/mask_generation_eval.py \\ + --encoder onnx-community/REPO/onnx/vision_encoder_int8.onnx \\ + --decoder onnx-community/REPO/onnx/prompt_encoder_mask_decoder_int8.onnx \\ + --target-size 1024 --mean 0.485,0.456,0.406 --std 0.229,0.224,0.225 \\ + --dataset coco --num-samples 20 --prompt-type bbox +""" + +from __future__ import annotations + +import argparse +import io +import json +import os +import sys +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import onnxruntime as ort +from PIL import Image + + +# --------------------------------------------------------------------------- # +# Built-in profiles (preprocessing + Hub paths). +# Add a new one to the dict at the bottom of this section to register it. +# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True) +class MaskGenProfile: + """Per-model preprocessing + I/O config for the generic harness.""" + + name: str + encoder_ref: str # Hub ONNX ref: ORG/REPO/PATH.onnx + decoder_ref: str + target_size: int + mean: tuple[float, float, float] + std: tuple[float, float, float] + + +# SAM 3 Tracker -- published preprocessing constants from the model card + +# our own ``WinMLMaskGenerationEvaluator`` (which uses 0.5/0.5/0.5). 
+SAM3_PROFILE = MaskGenProfile( + name="sam3", + encoder_ref="onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx", + decoder_ref="onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx", + target_size=1008, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), +) + + +# SAM 2.1 -- ImageNet stats and 1024x1024 input (SAM-paper convention). +SAM2_1_PROFILE = MaskGenProfile( + name="sam2.1", + encoder_ref="onnx-community/sam2.1-hiera-small-ONNX/onnx/vision_encoder_int8.onnx", + decoder_ref="onnx-community/sam2.1-hiera-small-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx", + target_size=1024, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), +) + + +PROFILES: dict[str, MaskGenProfile] = { + "sam3": SAM3_PROFILE, + "sam2.1": SAM2_1_PROFILE, +} + + +PROMPT_TYPES = ("point", "bbox", "point+box") + + +# --------------------------------------------------------------------------- # +# Eval-sample container. +# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True) +class EvalSample: + """A single (image, instance-mask) example. + + ``image_id`` groups annotations belonging to the same source image so + the harness only runs the encoder once per unique image. Samples + must be sorted by ``image_id`` for the cache to be effective. + """ + + image: Image.Image + gt_mask: np.ndarray # H x W uint8 (0/1) + name: str # display name in the per-sample log + image_id: str # cache key for encoder embeddings + + +# --------------------------------------------------------------------------- # +# Preprocessing / postprocessing. +# --------------------------------------------------------------------------- # + + +def preprocess_image( + img: Image.Image, profile: MaskGenProfile, +) -> tuple[np.ndarray, float, int, int]: + """Resize longest side to target, pad to square, normalize. 
+ + Returns: + pixel_values: (1, 3, T, T) fp32, NCHW + scale: original_pixel * scale -> resized_pixel + new_h, new_w: dimensions after resize (before padding) + """ + img = img.convert("RGB") + orig_w, orig_h = img.size + scale = profile.target_size / max(orig_h, orig_w) + new_h = int(round(orig_h * scale)) + new_w = int(round(orig_w * scale)) + resized = img.resize((new_w, new_h), Image.BILINEAR) + + arr = np.asarray(resized, dtype=np.float32) / 255.0 + arr = (arr - np.array(profile.mean, dtype=np.float32)) / np.array( + profile.std, dtype=np.float32, + ) + + pad_h = profile.target_size - new_h + pad_w = profile.target_size - new_w + arr = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant") + + pixel_values = arr.transpose(2, 0, 1)[None, ...] + return pixel_values.astype(np.float32), scale, new_h, new_w + + +def bbox_from_mask(mask: np.ndarray) -> tuple[int, int, int, int] | None: + """Tight xyxy bbox of nonzero pixels, or ``None`` if mask is empty.""" + ys, xs = np.nonzero(mask) + if ys.size == 0: + return None + return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()) + + +def sample_point_in_mask(mask: np.ndarray) -> tuple[int, int] | None: + """Pick one foreground point near the mask centroid. + + Returns ``(x, y)`` in original-image pixel coordinates, or ``None`` + if the mask is empty. For non-convex masks (centroid outside the + mask) we snap to the nearest foreground pixel; this is the same + fallback used by the HF SAM image processor's ``point_grid``. 
+ """ + ys, xs = np.nonzero(mask) + if ys.size == 0: + return None + cy = int(round(ys.mean())) + cx = int(round(xs.mean())) + if 0 <= cy < mask.shape[0] and 0 <= cx < mask.shape[1] and mask[cy, cx]: + return cx, cy + d2 = (ys - cy) ** 2 + (xs - cx) ** 2 + i = int(d2.argmin()) + return int(xs[i]), int(ys[i]) + + +def postprocess_mask( + pred_mask: np.ndarray, + profile: MaskGenProfile, + orig_h: int, + orig_w: int, + new_h: int, + new_w: int, +) -> np.ndarray: + """Un-pad and resize a low-res mask back to original image coords.""" + pil = Image.fromarray(pred_mask.astype(np.float32)) + up = pil.resize((profile.target_size, profile.target_size), Image.BILINEAR) + up_arr = np.asarray(up, dtype=np.float32) + + cropped = up_arr[:new_h, :new_w] + + pil2 = Image.fromarray(cropped) + final = pil2.resize((orig_w, orig_h), Image.BILINEAR) + return np.asarray(final, dtype=np.float32) > 0 + + +def iou(pred: np.ndarray, gt: np.ndarray) -> float: + """Binary IoU.""" + pred = pred.astype(bool) + gt = gt.astype(bool) + inter = np.logical_and(pred, gt).sum() + union = np.logical_or(pred, gt).sum() + return float(inter) / float(union) if union > 0 else 0.0 + + +def dice(pred: np.ndarray, gt: np.ndarray) -> float: + """Binary Dice coefficient.""" + pred = pred.astype(bool) + gt = gt.astype(bool) + pp = pred.sum() + pg = gt.sum() + if pp + pg == 0: + return 0.0 + return 2.0 * float(np.logical_and(pred, gt).sum()) / float(pp + pg) + + +# --------------------------------------------------------------------------- # +# Inference driver -- split into encode (per image) + decode (per prompt) so +# embeddings can be cached across prompt types and across multiple +# annotations on the same COCO image. 
+# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True) +class EncoderState: + embeddings: dict[str, np.ndarray] + scale: float + new_h: int + new_w: int + encode_time: float + + +def encode_image( + enc_sess: ort.InferenceSession, + profile: MaskGenProfile, + img: Image.Image, +) -> EncoderState: + pixel_values, scale, new_h, new_w = preprocess_image(img, profile) + t0 = time.monotonic() + enc_out = enc_sess.run(None, {"pixel_values": pixel_values}) + elapsed = time.monotonic() - t0 + enc_names = [o.name for o in enc_sess.get_outputs()] + emb = dict(zip(enc_names, enc_out, strict=True)) + return EncoderState(emb, scale, new_h, new_w, elapsed) + + +def make_prompts( + prompt_type: str, + gt_mask: np.ndarray, + scale: float, +) -> dict[str, np.ndarray] | None: + """Build the (points, labels, boxes) inputs for the given prompt strategy. + + Returns ``None`` if the GT mask is degenerate and no usable prompt can + be derived (caller should skip the sample). All coordinates are in + encoder-input space (i.e. ``original * scale``). + """ + if prompt_type not in PROMPT_TYPES: + raise ValueError(f"Unknown prompt_type {prompt_type!r}") + + bbox = bbox_from_mask(gt_mask) + if bbox is None: + return None + x0, y0, x1, y1 = bbox + + if prompt_type in ("point", "point+box"): + pt = sample_point_in_mask(gt_mask) + if pt is None: + return None + px, py = pt + points = np.array([[px * scale, py * scale]], dtype=np.float32).reshape(1, 1, 1, 2) + labels = np.array([[1]], dtype=np.int64).reshape(1, 1, 1) + else: + points = np.zeros((1, 1, 0, 2), dtype=np.float32) + labels = np.zeros((1, 1, 0), dtype=np.int64) + + if prompt_type in ("bbox", "point+box"): + boxes = np.array( + [[x0 * scale, y0 * scale, x1 * scale, y1 * scale]], + dtype=np.float32, + )[None, ...] 
+ else: + boxes = np.zeros((1, 0, 4), dtype=np.float32) + + return { + "input_points": points, + "input_labels": labels, + "input_boxes": boxes, + } + + +def decode_with_prompt( + dec_sess: ort.InferenceSession, + profile: MaskGenProfile, + state: EncoderState, + prompts: dict[str, np.ndarray], + orig_h: int, + orig_w: int, +) -> tuple[np.ndarray, float, float]: + """Run the decoder once with the supplied prompts. + + Returns ``(binary_mask, pred_iou, decode_seconds)``. + """ + dec_inputs = { + **prompts, + "image_embeddings.0": state.embeddings["image_embeddings.0"], + "image_embeddings.1": state.embeddings["image_embeddings.1"], + "image_embeddings.2": state.embeddings["image_embeddings.2"], + } + t0 = time.monotonic() + iou_scores, pred_masks, _obj_logits = dec_sess.run( + ["iou_scores", "pred_masks", "object_score_logits"], dec_inputs, + ) + elapsed = time.monotonic() - t0 + iou_preds = iou_scores[0, 0] + best_idx = int(iou_preds.argmax()) + best_low_res = pred_masks[0, 0, best_idx] + best_iou_pred = float(iou_preds[best_idx]) + binary = postprocess_mask( + best_low_res, profile, orig_h, orig_w, state.new_h, state.new_w, + ) + return binary, best_iou_pred, elapsed + + +# --------------------------------------------------------------------------- # +# Datasets. +# --------------------------------------------------------------------------- # + + +def _load_human_parsing(n: int) -> list[EvalSample]: + """Sample n binary-mask examples from ``mattmdjaga/human_parsing_dataset``. + + Same dataset the production ``MaskGenerationDataset`` defaults to. + Multi-class body-part labels are collapsed to a single binary + foreground mask. Degenerate samples (coverage <5% or >95%) are + skipped so the prompt is meaningful. 
+ """ + from datasets import load_dataset + + ds_name = "mattmdjaga/human_parsing_dataset" + print(f"Loading {ds_name} (streaming, taking {n} samples)...") + ds = load_dataset(ds_name, split="train", streaming=True) + samples: list[EvalSample] = [] + for i, ex in enumerate(ds): + if len(samples) >= n: + break + img = ex["image"] + mask_arr = np.asarray(ex["mask"]) + binary = (mask_arr > 0).astype(np.uint8) + coverage = binary.sum() / binary.size + if coverage < 0.05 or coverage > 0.95: + continue + samples.append( + EvalSample( + image=img, + gt_mask=binary, + name=f"hp_{i:04d}", + image_id=f"hp_{i:04d}", # unique per sample (no cross-sample reuse) + ), + ) + if not samples: + sys.exit(f"No usable samples from {ds_name}") + print(f" loaded {len(samples)}") + return samples + + +_COCO_ANNOTATIONS_REPO = "merve/coco" +_COCO_ANNOTATIONS_FILE = "annotations/instances_val2017.json" +_COCO_IMAGE_BASE = "http://images.cocodataset.org/val2017" + + +def _download_coco_image(file_name: str, cache_dir: Path) -> Image.Image: + """Fetch a COCO val2017 image from the public CDN, with a local cache.""" + import urllib.request + + cache_dir.mkdir(parents=True, exist_ok=True) + cached = cache_dir / file_name + if not cached.exists(): + url = f"{_COCO_IMAGE_BASE}/{file_name}" + with urllib.request.urlopen(url, timeout=30) as resp: # noqa: S310 + data = resp.read() + cached.write_bytes(data) + with cached.open("rb") as f: + return Image.open(io.BytesIO(f.read())).convert("RGB") + + +def _load_coco(n: int, min_area: float = 1024.0, seed: int = 0) -> list[EvalSample]: + """Sample n COCO val2017 instance annotations. + + Annotations are filtered to ``iscrowd == 0`` and ``area >= min_area`` + so prompts are meaningful (tiny objects aren't a useful SAM + benchmark; crowd RLEs combine multiple instances). Samples are + sorted by ``image_id`` so the encoder cache is effective. 
+ """ + try: + from huggingface_hub import hf_hub_download + from pycocotools import mask as mask_utils + except ImportError as exc: + sys.exit( + f"COCO eval requires huggingface_hub + pycocotools: {exc}", + ) + + print(f"Loading COCO val2017 annotations ({_COCO_ANNOTATIONS_REPO})...") + ann_path = hf_hub_download( + repo_id=_COCO_ANNOTATIONS_REPO, + filename=_COCO_ANNOTATIONS_FILE, + repo_type="dataset", + ) + with open(ann_path, encoding="utf-8") as f: + coco = json.load(f) + images_by_id = {im["id"]: im for im in coco["images"]} + cats_by_id = {c["id"]: c["name"] for c in coco["categories"]} + annotations = [ + a for a in coco["annotations"] + if a.get("iscrowd", 0) == 0 and a.get("area", 0.0) >= min_area + ] + annotations.sort(key=lambda a: a["image_id"]) + rng = np.random.default_rng(seed) + if len(annotations) > n: + # Sample n annotations, then re-sort by image_id for cache locality. + idx = rng.choice(len(annotations), size=n, replace=False) + chosen = sorted((annotations[int(i)] for i in idx), key=lambda a: a["image_id"]) + else: + chosen = annotations + + image_cache_dir = Path.home() / ".cache" / "winml" / "coco_val2017_images" + samples: list[EvalSample] = [] + fetched_images: dict[int, Image.Image] = {} + print(f" decoding {len(chosen)} instance masks (downloading images on demand)...") + for ann in chosen: + image_id = ann["image_id"] + img_meta = images_by_id[image_id] + h, w = int(img_meta["height"]), int(img_meta["width"]) + seg = ann["segmentation"] + if isinstance(seg, list): + rles = mask_utils.frPyObjects(seg, h, w) + rle = mask_utils.merge(rles) + elif isinstance(seg, dict): + rle = seg if isinstance(seg.get("counts"), bytes) else mask_utils.frPyObjects( + seg, h, w, + ) + else: + continue + gt = mask_utils.decode(rle).astype(np.uint8) + if gt.sum() == 0: + continue + if image_id not in fetched_images: + try: + fetched_images[image_id] = _download_coco_image( + img_meta["file_name"], image_cache_dir, + ) + except Exception as exc: # noqa: 
BLE001 + print(f" SKIP image {image_id}: download failed ({exc})") + continue + img = fetched_images[image_id] + cat_name = cats_by_id.get(ann["category_id"], "?") + samples.append( + EvalSample( + image=img, + gt_mask=gt, + name=f"coco_{image_id}_{ann['id']}_{cat_name}", + image_id=f"coco_{image_id}", + ), + ) + if not samples: + sys.exit("No usable COCO samples (download failures or all filtered).") + print(f" loaded {len(samples)} annotations across {len({s.image_id for s in samples})} images") + return samples + + +def load_eval_samples(dataset: str, n: int) -> list[EvalSample]: + if dataset == "human_parsing": + return _load_human_parsing(n) + if dataset == "coco": + return _load_coco(n) + raise ValueError(f"Unknown dataset {dataset!r}") + + +# --------------------------------------------------------------------------- # +# CLI. +# --------------------------------------------------------------------------- # + + +def _build_providers(ep: str) -> tuple[list[str], list[dict]]: + """Return (providers, provider_options) for the requested EP. + + Falls back to CPU with a warning if the requested EP isn't installed. + VitisAI (AMD NPU) requires extra config on Phoenix/Hawk Point; if + ``RYZEN_AI_INSTALLATION_PATH`` is set we wire the xclbin + automatically, otherwise we warn and let ORT fall back. + """ + providers_map = { + "cpu": ["CPUExecutionProvider"], + "dml": ["DmlExecutionProvider", "CPUExecutionProvider"], + "vitisai": ["VitisAIExecutionProvider", "CPUExecutionProvider"], + } + providers = list(providers_map[ep]) + avail = set(ort.get_available_providers()) + if providers[0] not in avail: + print( + f"WARNING: EP '{providers[0]}' not available " + f"(available: {sorted(avail)}). 
Falling back to CPU only.", + ) + return ["CPUExecutionProvider"], [{}] + + provider_options: list[dict] = [{} for _ in providers] + if providers[0] == "VitisAIExecutionProvider": + install_dir = os.environ.get("RYZEN_AI_INSTALLATION_PATH", "") + xclbin = os.path.join( + install_dir, "voe-4.0-win_amd64", "xclbins", "phoenix", "4x4.xclbin", + ) + if install_dir and os.path.exists(xclbin): + provider_options[0] = { + "target": "X1", + "xlnx_enable_py3_round": 0, + "xclbin": xclbin, + } + print(f" VitisAI PHX config: target=X1, xclbin={xclbin}") + else: + print( + " WARNING: RYZEN_AI_INSTALLATION_PATH not set or xclbin " + f"not found at '{xclbin}'. VitisAI will likely fall back " + "to CPU. Run inside the Ryzen AI conda env so the env var " + "is populated.", + ) + return providers, provider_options + + +def _resolve_local(ref: str) -> Path: + """Resolve a Hub-ONNX ref or local path to a local file. + + Lazily imports our own resolver so the script also runs standalone + (e.g. from a checkout where ``src/`` is on PYTHONPATH but the package + isn't installed in editable mode). + """ + p = Path(ref) + if p.exists(): + return p + sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src")) + from winml.modelkit.loader.onnx_hub import resolve_hf_onnx_path + return resolve_hf_onnx_path(ref) + + +def main() -> int: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--preset", + choices=sorted(PROFILES), + default=None, + help="Built-in profile. 
Overrides --encoder/--decoder/--target-size/--mean/--std.", + ) + parser.add_argument("--encoder", help="Hub ONNX ref or local path to encoder.") + parser.add_argument("--decoder", help="Hub ONNX ref or local path to decoder.") + parser.add_argument("--target-size", type=int, default=1024) + parser.add_argument( + "--mean", + default="0.485,0.456,0.406", + help="Comma-separated per-channel mean for normalization.", + ) + parser.add_argument( + "--std", + default="0.229,0.224,0.225", + help="Comma-separated per-channel std for normalization.", + ) + parser.add_argument("--name", default="custom", help="Label used in output table.") + parser.add_argument( + "--dataset", + choices=("human_parsing", "coco"), + default="human_parsing", + help="Evaluation dataset.", + ) + parser.add_argument( + "--num-samples", + type=int, + default=10, + help="For COCO this is the number of annotations (instances), not images.", + ) + parser.add_argument( + "--prompt-type", + choices=(*PROMPT_TYPES, "all"), + default="point+box", + help="Prompt strategy. 
'all' runs every type per sample and reports a breakdown.", + ) + parser.add_argument("--out-dir", type=Path, default=Path("out/mask_gen_eval")) + parser.add_argument( + "--ep", + choices=["cpu", "dml", "vitisai"], + default="cpu", + help="Execution provider.", + ) + args = parser.parse_args() + + if args.preset: + profile = PROFILES[args.preset] + else: + if not (args.encoder and args.decoder): + parser.error("--encoder and --decoder are required when --preset is not used.") + profile = MaskGenProfile( + name=args.name, + encoder_ref=args.encoder, + decoder_ref=args.decoder, + target_size=args.target_size, + mean=tuple(float(x) for x in args.mean.split(",")), # type: ignore[arg-type] + std=tuple(float(x) for x in args.std.split(",")), # type: ignore[arg-type] + ) + + prompt_types = list(PROMPT_TYPES) if args.prompt_type == "all" else [args.prompt_type] + + out_dir = args.out_dir / profile.name / args.dataset + out_dir.mkdir(parents=True, exist_ok=True) + + print(f"Profile: {profile.name}") + print(f" encoder: {profile.encoder_ref}") + print(f" decoder: {profile.decoder_ref}") + print(f" target_size={profile.target_size} mean={profile.mean} std={profile.std}") + print(f" dataset: {args.dataset} prompt_types: {prompt_types}") + + enc_path = _resolve_local(profile.encoder_ref) + dec_path = _resolve_local(profile.decoder_ref) + print(f" encoder local: {enc_path}") + print(f" decoder local: {dec_path}") + + providers, provider_options = _build_providers(args.ep) + print(f"Creating ORT sessions (providers={providers})...") + sess_opts = ort.SessionOptions() + sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + enc_sess = ort.InferenceSession( + str(enc_path), sess_options=sess_opts, + providers=providers, provider_options=provider_options, + ) + dec_sess = ort.InferenceSession( + str(dec_path), sess_options=sess_opts, + providers=providers, provider_options=provider_options, + ) + print(f" encoder providers: {enc_sess.get_providers()}") + 
print(f" decoder providers: {dec_sess.get_providers()}") + + samples = load_eval_samples(args.dataset, args.num_samples) + print(f"Got {len(samples)} samples. Running {profile.name}...") + + # Per-prompt-type rows: name, iou_gt, dice, iou_pred, dec_sec, enc_sec + rows: dict[str, list[tuple[str, float, float, float, float, float]]] = { + pt: [] for pt in prompt_types + } + cache: dict[str, EncoderState] = {} + encode_count = 0 + + for i, sample in enumerate(samples): + if sample.image_id not in cache: + try: + cache[sample.image_id] = encode_image(enc_sess, profile, sample.image) + encode_count += 1 + except Exception as exc: # noqa: BLE001 + print(f" [{i + 1:3d}/{len(samples)}] {sample.name} ENCODE FAILED: {exc}") + continue + state = cache[sample.image_id] + orig_h, orig_w = sample.gt_mask.shape + for pt in prompt_types: + prompts = make_prompts(pt, sample.gt_mask, state.scale) + if prompts is None: + continue + try: + pred, iou_pred, dec_sec = decode_with_prompt( + dec_sess, profile, state, prompts, orig_h, orig_w, + ) + except Exception as exc: # noqa: BLE001 + print(f" [{i + 1:3d}/{len(samples)}] {sample.name} {pt} FAILED: {exc}") + continue + score_iou = iou(pred, sample.gt_mask) + score_dice = dice(pred, sample.gt_mask) + rows[pt].append( + (sample.name, score_iou, score_dice, iou_pred, dec_sec, state.encode_time), + ) + print( + f" [{i + 1:3d}/{len(samples)}] {sample.name[:50]:50s} " + f"{pt:11s} IoU={score_iou:.3f} Dice={score_dice:.3f} " + f"pIoU={iou_pred:.3f} dec={dec_sec * 1000:.1f}ms", + ) + vis = np.stack( + [ + (sample.gt_mask * 255).astype(np.uint8), + (pred * 255).astype(np.uint8), + np.zeros_like(sample.gt_mask, dtype=np.uint8), + ], + axis=-1, + ) + vis_path = out_dir / f"{i:03d}_{Path(sample.name).stem}_{pt.replace('+', '_')}.png" + Image.fromarray(vis).save(vis_path) + + # Aggregate. 
+ if not any(rows.values()): + print("No successful runs.") + return 1 + + print("\n" + "=" * 80) + print( + f"{profile.name} -- mask-generation eval (dataset={args.dataset}, " + f"unique_images={encode_count}, EP={args.ep})", + ) + print("=" * 80) + header = ( + f"{'prompt':<11} {'n':>4} {'mIoU':>7} {'Dice':>7} {'pIoU':>7} " + f"{'mIoU>=0.5':>9} {'mIoU>=0.75':>10} {'enc_s/img':>10} {'dec_ms':>7}" + ) + print(header) + print("-" * len(header)) + enc_times = [] + for pt in prompt_types: + pt_rows = rows[pt] + if not pt_rows: + print(f"{pt:<11} (no successful runs)") + continue + ious = np.array([r[1] for r in pt_rows]) + dices = np.array([r[2] for r in pt_rows]) + pious = np.array([r[3] for r in pt_rows]) + dec_ms = np.array([r[4] for r in pt_rows]) * 1000 + enc_s = np.array([r[5] for r in pt_rows]) + enc_times.extend(enc_s.tolist()) + miou = float(ious.mean()) + rate50 = float((ious >= 0.5).mean()) + rate75 = float((ious >= 0.75).mean()) + print( + f"{pt:<11} {len(pt_rows):>4d} {miou:>7.4f} {dices.mean():>7.4f} " + f"{pious.mean():>7.4f} {rate50:>9.2%} {rate75:>10.2%} " + f"{enc_s.mean():>10.2f} {dec_ms.mean():>7.1f}", + ) + if enc_times: + unique_enc = list({s.image_id: cache[s.image_id].encode_time for s in samples if s.image_id in cache}.values()) + if unique_enc: + print(f"\nEncoder: {len(unique_enc)} unique runs, mean {np.mean(unique_enc):.2f}s/image") + print(f"Visualizations: {out_dir}/ (red=GT, green=prediction)") + print("=" * 80) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/sam3_decoder_shapes.json b/scripts/sam3_decoder_shapes.json new file mode 100644 index 000000000..56b5fb5e6 --- /dev/null +++ b/scripts/sam3_decoder_shapes.json @@ -0,0 +1,8 @@ +{ + "input_points": [1, 1, 1, 2], + "input_labels": [1, 1, 1], + "input_boxes": [1, 1, 4], + "image_embeddings.0": [1, 32, 288, 288], + "image_embeddings.1": [1, 64, 144, 144], + "image_embeddings.2": [1, 256, 72, 72] +} diff --git a/scripts/sam3_reference_check.py 
b/scripts/sam3_reference_check.py new file mode 100644 index 000000000..2ab0306b8 --- /dev/null +++ b/scripts/sam3_reference_check.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Spot-check SAM 3 ONNX against the reference outputs published by the +model authors in the onnx-community/sam3-tracker-ONNX model card. + +Reference (from README.md of onnx-community/sam3-tracker-ONNX): + Input: truck.jpg + point [[500, 375]] + label [[1]] (no boxes) + Expected iou_scores: [0.9313147068023682, 0.037515610456466675, 0.5128555297851562] + +Uses the same preprocessing as WinMLMaskGenerationEvaluator so this also +sanity-checks our pipeline end-to-end. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import onnxruntime as ort +import requests +from huggingface_hub import snapshot_download +from PIL import Image + +# Import preprocessing helpers from our evaluator to keep the comparison +# consistent with the production code path. 
+from winml.modelkit.eval.mask_generation_evaluator import ( + _preprocess_image, +) + + +REFERENCE_IOU_SCORES = [0.9313147068023682, 0.037515610456466675, 0.5128555297851562] +TRUCK_URL = ( + "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/" + "resolve/main/truck.jpg" +) +POINT = (500.0, 375.0) + + +def main() -> None: + snap = Path( + snapshot_download( + "onnx-community/sam3-tracker-ONNX", + allow_patterns=["onnx/*_int8.onnx"], + ) + ) + encoder_path = snap / "onnx" / "vision_encoder_int8.onnx" + decoder_path = snap / "onnx" / "prompt_encoder_mask_decoder_int8.onnx" + + print(f"Encoder: {encoder_path}") + print(f"Decoder: {decoder_path}") + + image_path = Path.home() / ".cache" / "winml" / "truck.jpg" + image_path.parent.mkdir(parents=True, exist_ok=True) + if not image_path.exists(): + print(f"Downloading {TRUCK_URL}") + image_path.write_bytes(requests.get(TRUCK_URL, timeout=60).content) + img = Image.open(image_path).convert("RGB") + print(f"Image size: {img.size}") + + # Preprocess via the evaluator's helper (longest-side 1008 + pad + normalization; NOTE(review): other files in this change describe the evaluator as normalizing with mean/std 0.5 rather than ImageNet stats -- confirm) + pixel_values, scale_x, scale_y = _preprocess_image(img) + print( + f"Preprocessed: pixel_values={pixel_values.shape}, " + f"scale_x={scale_x:.4f}, scale_y={scale_y:.4f}" + ) + + # Encoder forward + enc = ort.InferenceSession(str(encoder_path), providers=["CPUExecutionProvider"]) + enc_inputs = {"pixel_values": pixel_values} + enc_names = [o.name for o in enc.get_outputs()] + enc_out = enc.run(None, enc_inputs) + emb = dict(zip(enc_names, enc_out)) + print(f"Encoder outputs: {[(k, v.shape) for k, v in emb.items()]}") + + # Build decoder inputs in point-prompt mode + points = np.array( + [[[[POINT[0] * scale_x, POINT[1] * scale_y]]]], dtype=np.float32 + ) + labels = np.array([[[1]]], dtype=np.int64) + boxes = np.zeros((1, 0, 4), dtype=np.float32) + + dec = ort.InferenceSession(str(decoder_path), providers=["CPUExecutionProvider"]) + dec_inputs = { + "input_points": points, + "input_labels": 
labels, + "input_boxes": boxes, + **emb, + } + dec_out = dec.run(None, dec_inputs) + dec_names = [o.name for o in dec.get_outputs()] + out = dict(zip(dec_names, dec_out)) + + iou = np.asarray(out["iou_scores"]).reshape(-1).tolist() + print() + print("Reference iou_scores : ", [f"{v:.6f}" for v in REFERENCE_IOU_SCORES]) + print("Our iou_scores : ", [f"{v:.6f}" for v in iou]) + diffs = [a - b for a, b in zip(iou, REFERENCE_IOU_SCORES)] + print("Absolute diff : ", [f"{d:+.6f}" for d in diffs]) + print(f"Max abs diff : {max(abs(d) for d in diffs):.6f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/sam3_smoke_eval.py b/scripts/sam3_smoke_eval.py new file mode 100644 index 000000000..7f2e46126 --- /dev/null +++ b/scripts/sam3_smoke_eval.py @@ -0,0 +1,35 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Back-compat shim: run the SAM 3 preset of the generic mask-gen harness. + +The original SAM 3 logic has been generalized into +:mod:`scripts.mask_generation_eval` so other promptable-segmentation +models (SAM 2 / SAM 2.1 / future ``onnx-community`` exports) can reuse +the same harness via ``--preset``. This wrapper preserves the original +``python scripts/sam3_smoke_eval.py`` entrypoint for anyone with +bookmarks or CI invocations. + +The old standalone script had a bug -- it used ImageNet mean/std for +SAM 3, but the SAM 3 Tracker image processor uses ``[0.5, 0.5, 0.5]`` +for both mean and std (matches ``Sam3TrackerImageProcessor`` defaults). +The generic harness encodes the correct values in the ``sam3`` preset. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + + +if __name__ == "__main__": + # Delegate to the generic harness with the SAM 3 preset baked in. 
+ sys.path.insert(0, str(Path(__file__).resolve().parent)) + from mask_generation_eval import main # type: ignore[import-not-found] + + # Inject ``--preset sam3`` if the user didn't override it. + if "--preset" not in sys.argv and "--encoder" not in sys.argv: + sys.argv.insert(1, "--preset") + sys.argv.insert(2, "sam3") + sys.exit(main()) diff --git a/src/winml/modelkit/build/common.py b/src/winml/modelkit/build/common.py index e16f79aed..742267d16 100644 --- a/src/winml/modelkit/build/common.py +++ b/src/winml/modelkit/build/common.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any from ..analyze import analyze_onnx -from ..onnx import copy_onnx_model +from ..onnx import copy_onnx_model, is_quantized_onnx from ..optim import optimize_onnx @@ -28,6 +28,43 @@ logger = logging.getLogger(__name__) +def ensure_pre_quantized_stamped( + config: WinMLBuildConfig, onnx_path: Path, *, force: bool = False +) -> None: + """Stamp ``config.skip_optimize`` (and clear ``config.quant``) once. + + Sets ``config.skip_optimize = True`` and clears ``config.quant`` if the + input ONNX is already quantized. + + This is the **single defensive detection point** for the library entry + points (``build_onnx_model``, ``build_hf_model``). It is a no-op when + ``config.skip_optimize`` is already True (i.e. the unified CLI path + via :func:`generate_onnx_build_config` already stamped the config) so + ``is_quantized_onnx()`` runs at most **once per build**. + + Args: + config: Build config to stamp in place. + onnx_path: Path to the ONNX file under consideration. + force: When True, stamp unconditionally without running + ``is_quantized_onnx`` (used to honor the legacy + ``skip_optimize=True`` kwarg from direct callers). 
+ """ + if config.skip_optimize: + return + if force: + config.skip_optimize = True + config.quant = None + return + + if is_quantized_onnx(onnx_path): + config.skip_optimize = True + config.quant = None + logger.info( + "Pre-quantized model detected (QDQ or QOperator nodes present). " + "Skipping optimize + quantize stages." + ) + + def run_optimize_analyze_loop( model_path: Path, optimized_path: Path, @@ -63,7 +100,11 @@ def run_optimize_analyze_loop( ep: Target execution provider for the analyzer. device: Target device for the analyzer. max_optim_iterations: Maximum autoconf re-optimization rounds. - 0 means optimize+analyze only (no autoconf re-optimization). + 0 disables the autoconf re-optimize/analyze loop entirely + (i.e. ``_run_analyze_loop`` is not invoked), in which case + this function performs the initial ``optimize_onnx`` pass + only (or, when ``skip_optimize=True``, just copies the input + to ``optimized_path``). allow_unsupported_nodes: If True, log a warning instead of raising when unsupported nodes persist after analysis, letting the build proceed (the EP may still run them, e.g. via CPU fallback). @@ -87,6 +128,13 @@ def run_optimize_analyze_loop( if not config.auto: max_optim_iterations = 0 + # Enforce the skip_optimize invariant: autoconf re-optimize would + # crash on pre-quantized models for the same reason the initial + # optimize was skipped (ORT lacks kernels for the integer ops on the + # host EP). Drop iterations to 0 so callers can pass any value safely. + if skip_optimize: + max_optim_iterations = 0 + t0 = time.monotonic() # 1. 
Optimize (or skip for pre-quantized models) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index a67fd8480..67c4e771c 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -29,9 +29,9 @@ from ..compiler import compile_onnx from ..export import export_onnx -from ..onnx import copy_onnx_model, is_quantized_onnx +from ..onnx import copy_onnx_model from ..quant import quantize_onnx -from .common import run_optimize_analyze_loop +from .common import ensure_pre_quantized_stamped, run_optimize_analyze_loop if TYPE_CHECKING: @@ -241,22 +241,24 @@ def _name(base: str) -> str: # duplicated between build_hf_model() and build_onnx_model(). Extract # into a shared run_build_stages() function in common.py. # ========================================================================= - skip_optimize: bool = kwargs.pop("skip_optimize", False) - # Defensive fallback: when called through the unified pipeline, - # generate_*_build_config() already detects QDQ models and sets - # config.quant=None. This is_quantized_onnx() check is redundant in that - # path but kept for backward compatibility when build_hf_model() - # is called directly with a hand-built config. - is_pre_quantized = is_quantized_onnx(current_path) or skip_optimize + # Single defensive detection on the freshly exported ONNX. No-op when + # the caller already stamped ``config.skip_optimize``. HF export rarely + # produces a pre-quantized ONNX (Optimum exports float weights), but a + # direct caller could plausibly hand a pre-quantized + # ``pytorch_model`` and reach this branch. + skip_optimize_kwarg: bool = kwargs.pop("skip_optimize", False) + ensure_pre_quantized_stamped(config, current_path, force=skip_optimize_kwarg) + is_pre_quantized = config.skip_optimize if is_pre_quantized: - logger.info( - "Pre-quantized model detected (QDQ or QOperator nodes present). " - "Skipping optimize + quantize, running analyze-only." 
- ) + logger.info("Skipping optimize + quantize stages (config.skip_optimize=True)") stages_skipped.append("optimize") - # Analyze-only: skip ORT-based graph optimization (no kernel for - # QOperator ops like ConvInteger on the host EP), no autoconf loop. + # Skip the ORT-based graph optimization (no kernel for QOperator + # ops like ConvInteger on the host EP). The autoconf re-optim/ + # analyze loop is disabled too -- ``run_optimize_analyze_loop`` + # forces ``max_optim_iterations=0`` when ``skip_optimize=True``, + # so ``_run_analyze_loop`` is not invoked. The model still flows + # through later stages (quantize-skip + compile) for validation. current_path, _, analyze_iterations, analyze_unsupported_nodes, analyze_details = ( run_optimize_analyze_loop( model_path=current_path, @@ -299,37 +301,31 @@ def _name(base: str) -> str: # ========================================================================= # [4] QUANTIZE (optional — config.quant=None means skip) # ========================================================================= + # No defensive ``is_quantized_onnx`` re-check here: when the model is + # pre-quantized, ``ensure_pre_quantized_stamped`` has already set + # ``config.quant = None`` at stage [3], so this branch naturally + # falls through to the ``quant is None`` skip path. quant_result = None if is_pre_quantized: if "quantize" not in stages_skipped: stages_skipped.append("quantize") logger.info("Quantize skipped (pre-quantized model)") elif config.quant is not None: - # Defensive fallback: catches the edge case where a direct caller - # provides config.quant != None but the model already has QDQ nodes - # (e.g., hand-built config without running generate_*_build_config). - if is_quantized_onnx(current_path): - logger.warning( - "Model already contains QDQ nodes, skipping quantization. " - "Set config.quant=None to silence this warning." 
- ) - stages_skipped.append("quantize") - else: - logger.info("Quantizing model...") - t0 = time.monotonic() - quant_result = quantize_onnx( - model_path=current_path, - output_path=quantized_path, - config=config.quant, - **onnx_kwargs, - ) - if not quant_result.success: - errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown" - raise RuntimeError(f"Quantization failed: {errors}") - current_path = quantized_path - stage_timings["quantize"] = time.monotonic() - t0 - stages_completed.append("quantize") - logger.info("Quantize done (%.1fs) -> %s", stage_timings["quantize"], quantized_path) + logger.info("Quantizing model...") + t0 = time.monotonic() + quant_result = quantize_onnx( + model_path=current_path, + output_path=quantized_path, + config=config.quant, + **onnx_kwargs, + ) + if not quant_result.success: + errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown" + raise RuntimeError(f"Quantization failed: {errors}") + current_path = quantized_path + stage_timings["quantize"] = time.monotonic() - t0 + stages_completed.append("quantize") + logger.info("Quantize done (%.1fs) -> %s", stage_timings["quantize"], quantized_path) else: stages_skipped.append("quantize") logger.info("Quantize skipped (config.quant is None)") diff --git a/src/winml/modelkit/build/onnx.py b/src/winml/modelkit/build/onnx.py index 04d088ab0..fb0560370 100644 --- a/src/winml/modelkit/build/onnx.py +++ b/src/winml/modelkit/build/onnx.py @@ -21,9 +21,9 @@ from typing import TYPE_CHECKING, Any from ..compiler import compile_onnx -from ..onnx import copy_onnx_model, is_quantized_onnx +from ..onnx import copy_onnx_model from ..quant import quantize_onnx -from .common import run_optimize_analyze_loop +from .common import ensure_pre_quantized_stamped, run_optimize_analyze_loop from .hf import BuildResult @@ -141,27 +141,27 @@ def build_onnx_model( copy_onnx_model(onnx_path, current_path) # 
========================================================================= - # [1] OPTIMIZE + ANALYZE (or ANALYZE-ONLY for pre-quantized) + # [1] OPTIMIZE + ANALYZE (or SKIP-BOTH for pre-quantized) # FIXME: Stages [1]-[4] (optimize, quantize, compile, finalize) are # duplicated between build_onnx_model() and build_hf_model(). Extract # into a shared run_build_stages() function in common.py. # ========================================================================= - skip_optimize: bool = kwargs.pop("skip_optimize", False) - # Defensive fallback: when called through the unified pipeline, - # generate_onnx_build_config() already detects QDQ models and sets - # config.quant=None. This is_quantized_onnx() check is redundant in that - # path but kept for backward compatibility when build_onnx_model() - # is called directly with a hand-built config. - is_pre_quantized = is_quantized_onnx(current_path) or skip_optimize + # Single defensive detection. No-op when the CLI path (via + # ``generate_onnx_build_config``) already stamped ``config.skip_optimize``. + # Direct callers who hand-built a config trigger the one detection here. + skip_optimize_kwarg: bool = kwargs.pop("skip_optimize", False) + ensure_pre_quantized_stamped(config, current_path, force=skip_optimize_kwarg) + is_pre_quantized = config.skip_optimize if is_pre_quantized: - logger.info( - "Pre-quantized model detected (QDQ or QOperator nodes present). " - "Skipping optimize + quantize, running analyze-only." - ) + logger.info("Skipping optimize + quantize stages (config.skip_optimize=True)") stages_skipped.append("optimize") - # Analyze-only: skip ORT-based graph optimization (no kernel for - # QOperator ops like ConvInteger on the host EP), no autoconf loop. + # Skip the ORT-based graph optimization (no kernel for QOperator + # ops like ConvInteger on the host EP). 
The autoconf re-optim/ + # analyze loop is disabled too -- ``run_optimize_analyze_loop`` + # forces ``max_optim_iterations=0`` when ``skip_optimize=True``, + # so ``_run_analyze_loop`` is not invoked. The model still flows + # through later stages (quantize-skip + compile) for validation. current_path, _, analyze_iters, analyze_unsupported, analyze_details = ( run_optimize_analyze_loop( model_path=current_path, @@ -199,6 +199,10 @@ def build_onnx_model( # ========================================================================= # [2] QUANTIZE (optional — config.quant=None means skip) # ========================================================================= + # No defensive ``is_quantized_onnx`` re-check here: when the model is + # pre-quantized, ``ensure_pre_quantized_stamped`` has already set + # ``config.quant = None`` at stage [1], so this branch naturally + # falls through to the ``quant is None`` skip path. quant_result = None if is_pre_quantized: # Already handled above -- skip quantize for pre-quantized models @@ -206,31 +210,21 @@ def build_onnx_model( stages_skipped.append("quantize") logger.info("Quantize skipped (pre-quantized model)") elif config.quant is not None: - # Defensive fallback: catches the edge case where a direct caller - # provides config.quant != None but the model already has QDQ nodes - # (e.g., hand-built config without running generate_*_build_config). - if is_quantized_onnx(current_path): - logger.warning( - "Model already contains QDQ nodes, skipping quantization. " - "Set config.quant=None to silence this warning." 
- ) - stages_skipped.append("quantize") - else: - logger.info("Quantizing model...") - t0 = time.monotonic() - quant_result = quantize_onnx( - model_path=current_path, - output_path=quantized_path, - config=config.quant, - **onnx_kwargs, - ) - if not quant_result.success: - errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown" - raise RuntimeError(f"Quantization failed: {errors}") - current_path = quantized_path - stage_timings["quantize"] = time.monotonic() - t0 - stages_completed.append("quantize") - logger.info("Quantize done (%.1fs) -> %s", stage_timings["quantize"], quantized_path) + logger.info("Quantizing model...") + t0 = time.monotonic() + quant_result = quantize_onnx( + model_path=current_path, + output_path=quantized_path, + config=config.quant, + **onnx_kwargs, + ) + if not quant_result.success: + errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown" + raise RuntimeError(f"Quantization failed: {errors}") + current_path = quantized_path + stage_timings["quantize"] = time.monotonic() - t0 + stages_completed.append("quantize") + logger.info("Quantize done (%.1fs) -> %s", stage_timings["quantize"], quantized_path) else: stages_skipped.append("quantize") logger.info("Quantize skipped (config.quant is None)") diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 128d5de40..79b523544 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -34,7 +34,6 @@ print_error, print_final, print_setup, - print_stage_skip, print_stages_header, ) from ..utils.logging import configure_logging @@ -579,9 +578,7 @@ def build( # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx file thereafter. 
if model_id is not None: - from ..loader import maybe_resolve_hf_onnx_path - - model_id = maybe_resolve_hf_onnx_path(model_id) + model_id = cli_utils.normalize_model_arg(model_id) # Load or auto-generate config if config_file is not None: @@ -595,11 +592,22 @@ def build( raise click.UsageError("-m/--model is required when -c is not provided.") from ..config import generate_build_config - config_or_configs = generate_build_config( - model_id, - trust_remote_code=trust_remote_code, - device=device, - ) + # When ``model_id`` resolves to an .onnx file (either a local path or + # a Hub-hosted ONNX ref that was just downloaded by + # ``normalize_model_arg``), route to the ONNX config generator instead + # of treating the path as a HuggingFace repo id (which would try to + # load the .onnx file as a JSON config and crash). + if cli_utils.is_onnx_file_path(model_id): + config_or_configs = generate_build_config( + onnx_path=model_id, + device=device, + ) + else: + config_or_configs = generate_build_config( + model_id, + trust_remote_code=trust_remote_code, + device=device, + ) if not quant: config_or_configs.quant = None # Auto-generated configs: compile disabled by default unless @@ -1150,16 +1158,14 @@ def _run_quantize_stage( Returns: Updated current_path (quantized_path if quantization ran, else unchanged). """ - from ..onnx import is_quantized_onnx from ..quant import quantize_onnx from ..utils.console import StageLive if config.quant is None: - return current_path - - if is_quantized_onnx(current_path): - print_stage_skip(console, "quantize", "(QDQ nodes already present)") - stage_timings.append(("Quantize", None)) + # ``generate_onnx_build_config`` and ``ensure_pre_quantized_stamped`` + # (in build/common.py) both clear ``config.quant`` for pre-quantized + # inputs, so this single check covers both "user-explicit None" and + # "auto-detected pre-quantized" cases. 
return current_path with StageLive("quantize", console) as sl: @@ -1445,7 +1451,7 @@ def _build_onnx_pipeline( Returns list of (stage_name, elapsed_seconds | None) for summary, or None if build was reused. """ - from ..onnx import copy_onnx_model, is_quantized_onnx + from ..onnx import copy_onnx_model max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3) allow_unsupported_nodes: bool = extra_kwargs.pop("allow_unsupported_nodes", False) @@ -1488,11 +1494,11 @@ def _build_onnx_pipeline( # Pre-quantized models (QDQ or QOperator format) cannot pass through # ORT-based graph optimization on hosts that lack kernels for ops like - # ``ConvInteger``. Skip the optimize pass and the autoconf re-optim - # loop; analyze still runs lint-only. - is_pre_quantized = is_quantized_onnx(current_path) - if is_pre_quantized: - max_iters = 0 + # ``ConvInteger``. The unified pipeline stamps ``config.skip_optimize`` + # exactly once in ``generate_onnx_build_config`` -- downstream stages + # (here and inside ``build_onnx_model``) read the flag instead of + # re-running ``is_quantized_onnx`` on the same file. + is_pre_quantized = config.skip_optimize # ── Optimize stage (first stage for ONNX — show I/O here) ──── current_path, _ = _run_optimize_stage( diff --git a/src/winml/modelkit/commands/config.py b/src/winml/modelkit/commands/config.py index 834b12c19..75d6824cb 100644 --- a/src/winml/modelkit/commands/config.py +++ b/src/winml/modelkit/commands/config.py @@ -227,11 +227,10 @@ def config( generate_hf_build_config, generate_onnx_build_config, ) - from ..loader import maybe_resolve_hf_onnx_path # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx file thereafter. 
- hf_model = maybe_resolve_hf_onnx_path(hf_model) + hf_model = cli_utils.normalize_model_arg(hf_model) # Load override config from JSON file if provided override = None diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index e1161bfdc..00d22daec 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -453,6 +453,10 @@ def _resolve_model_path( raise click.UsageError( "--model-id is required when using composite `-m role=path` options." ) + # Each role's path may be either a local .onnx file OR a Hub-hosted + # ONNX ref (``org/repo/path/file.onnx``). ``normalize_model_arg`` + # resolves Hub refs to local cached paths so downstream code sees + # only filesystem paths. sub_model_paths: dict[str, str] = {} for v in role_assigned: role, _, path = v.partition("=") @@ -467,6 +471,7 @@ def _resolve_model_path( f"Duplicate role {role!r} in -m options.", param_hint="-m/--model", ) + path = cli_utils.normalize_model_arg(path) or path if not Path(path).exists(): raise click.BadParameter( f"ONNX file not found: {path}", @@ -484,11 +489,8 @@ def _resolve_model_path( if Path(value).suffix.lower() == ".onnx": # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx path thereafter. - from ..loader import is_hf_onnx_path, resolve_hf_onnx_path - - if is_hf_onnx_path(value): - value = str(resolve_hf_onnx_path(value)) - elif not Path(value).exists(): + value = cli_utils.normalize_model_arg(value) or value + if not Path(value).exists(): raise click.BadParameter( f"ONNX file not found: {value}", param_hint="-m/--model", diff --git a/src/winml/modelkit/commands/inspect.py b/src/winml/modelkit/commands/inspect.py index 4f10484b9..f2c68748e 100644 --- a/src/winml/modelkit/commands/inspect.py +++ b/src/winml/modelkit/commands/inspect.py @@ -197,10 +197,11 @@ def inspect( # Hub-hosted ONNX (e.g. 
``onnx-community/sam3-tracker-ONNX/onnx/...``) # is not downloadable for inspect (which targets HF architecture # metadata, not raw ONNX graphs), but surfacing the same friendly - # error keeps the UX consistent with local .onnx inputs. - from ..loader import is_hf_onnx_path + # error keeps the UX consistent with local .onnx inputs. Detect via + # the unified classifier so we don't trigger an unwanted download. + from ..utils.model_input import classify_model_input - if model_id and is_hf_onnx_path(model_id): + if model_id and classify_model_input(model_id).kind == "hub_onnx": raise click.ClickException( "ONNX file inspection is not yet supported. " "Use 'winml config -m model.onnx' for ONNX build config." diff --git a/src/winml/modelkit/commands/perf.py b/src/winml/modelkit/commands/perf.py index a7cac9854..7e7b9dd99 100644 --- a/src/winml/modelkit/commands/perf.py +++ b/src/winml/modelkit/commands/perf.py @@ -191,6 +191,7 @@ def to_dict(self) -> dict[str, Any]: def generate_random_inputs( io_config: dict[str, Any], batch_size: int = 1, + shape_config: dict[str, Any] | None = None, ) -> dict[str, np.ndarray]: """Generate random inputs based on model io_config. @@ -200,26 +201,47 @@ def generate_random_inputs( Args: io_config: Model I/O configuration from WinMLSession.io_config. Expected keys: ``input_names``, ``input_shapes``, ``input_types``. - Optional key: ``input_value_ranges`` -- a dict mapping input names - to ``[low, high)`` integer ranges sourced from the build config. + Optional keys: ``input_value_ranges`` -- a dict mapping input + names to ``[low, high)`` integer ranges sourced from the build + config; ``input_symbolic_shapes`` -- a list of shapes whose + dynamic dims hold the declared symbolic dim_param name. batch_size: Override batch dimension + shape_config: Optional overrides for dynamic dimensions. 
Two forms + are supported and may be mixed: + + * Per-input full-shape override: + ``{"input_points": [1, 1, 1, 2], ...}`` -- the value is used + as the resolved shape verbatim. + * Symbolic dim override: + ``{"num_points_per_image": 1, "num_boxes_per_image": 1}`` -- + applied to any dim whose ``dim_param`` matches the key. + + Symbolic overrides take precedence over positional defaults. Returns: Dictionary of input_name -> numpy array """ from ..core import generate_dummy_inputs_from_specs + symbolic_shapes = io_config.get("input_symbolic_shapes") or [ + [None] * len(s) for s in io_config["input_shapes"] + ] + overrides = shape_config or {} + specs: dict[str, dict[str, Any]] = {} - for name, shape, dtype_str in zip( + for name, shape, symbolic, dtype_str in zip( io_config["input_names"], io_config["input_shapes"], + symbolic_shapes, io_config["input_types"], strict=True, ): resolved_shape = _resolve_shape( shape=shape, + symbolic_shape=symbolic, input_name=name, batch_size=batch_size, + shape_config=overrides, ) np_dtype = np.dtype(dtype_str) @@ -240,17 +262,37 @@ def _resolve_shape( shape: list | tuple | None, input_name: str, batch_size: int, + symbolic_shape: list | tuple | None = None, + shape_config: dict[str, Any] | None = None, ) -> tuple[int, ...]: - """Resolve dynamic dimensions in shape.""" + """Resolve dynamic dimensions in shape. + + Resolution priority for each dim: + 1. ``shape_config[input_name]`` -- full per-input shape override. + 2. ``shape_config[dim_param]`` -- symbolic dim override (when the + ONNX graph exposed a ``dim_param`` for this dim). + 3. ``batch_size`` for the first dim. + 4. ``DYNAMIC_DIM_DEFAULTS`` positional fallback (defaults to 128). + """ + overrides = shape_config or {} + + # Form 1: full per-input shape override. 
+ if input_name in overrides and isinstance(overrides[input_name], (list, tuple)): + return tuple(int(d) for d in overrides[input_name]) + if shape is None: logger.warning("Shape unknown for '%s', using (batch_size,)", input_name) return (batch_size,) + sym = list(symbolic_shape) if symbolic_shape is not None else [None] * len(shape) resolved = [] for i, dim in enumerate(shape): - if dim is None or dim == -1 or (isinstance(dim, str)): + if dim is None or dim == -1 or isinstance(dim, str): # Dynamic dimension - resolve - if i == 0: + sym_name = sym[i] if i < len(sym) else None + if isinstance(sym_name, str) and sym_name in overrides: + resolved.append(int(overrides[sym_name])) + elif i == 0: # First dimension is almost always batch resolved.append(batch_size) else: @@ -394,6 +436,7 @@ def _generate_inputs(self) -> None: self._inputs = generate_random_inputs( io_config=io_config, batch_size=self.config.batch_size, + shape_config=self.config.shape_config, ) def _run_benchmark(self) -> PerfStats: @@ -1198,11 +1241,10 @@ def perf( # is downloaded once and treated as a local .onnx path thereafter. # Must run BEFORE the ``Path(hf_model).suffix == ".onnx"`` check below # so a Hub ref is not mistaken for a missing local file. - from ..loader import maybe_resolve_hf_onnx_path - - model = maybe_resolve_hf_onnx_path(model) - - hf_model = model + # ``normalize_model_arg`` returns ``str | None`` per its signature; + # the ``or model`` keeps the narrowed ``str`` type for downstream use. + hf_model: str = cli_utils.normalize_model_arg(model) or model + model = hf_model # Apply build config defaults (CLI explicit options take precedence). # Read raw JSON so missing keys are distinguishable from dataclass defaults. 
diff --git a/src/winml/modelkit/compiler/configs.py b/src/winml/modelkit/compiler/configs.py index 50f2c71a5..d769edfd4 100644 --- a/src/winml/modelkit/compiler/configs.py +++ b/src/winml/modelkit/compiler/configs.py @@ -211,10 +211,24 @@ def for_openvino(cls, device: str | None = None) -> WinMLCompileConfig: @classmethod def for_vitisai(cls, device: str | None = None) -> WinMLCompileConfig: - """Factory for Vitis AI (AMD NPU) compilation.""" + """Factory for Vitis AI (AMD NPU) compilation. + + Populates Phoenix XDNA defaults from ``RYZEN_AI_INSTALLATION_PATH`` + when available (target=X1, xclbin=/voe-4.0-win_amd64/ + xclbins/phoenix/4x4.xclbin, xlnx_enable_py3_round=0). VitisAI EP + ignores ``device_type``; the correct device hint is the xclbin path. + """ + import os + from pathlib import Path as _Path + provider_options: dict[str, str] = {} - if device: - provider_options["device_type"] = device.upper() + ryzen_ai = os.environ.get("RYZEN_AI_INSTALLATION_PATH") + if ryzen_ai: + xclbin = _Path(ryzen_ai) / "voe-4.0-win_amd64" / "xclbins" / "phoenix" / "4x4.xclbin" + if xclbin.exists(): + provider_options["target"] = "X1" + provider_options["xclbin"] = str(xclbin) + provider_options["xlnx_enable_py3_round"] = "0" ep_cfg = EPConfig( provider="vitisai", enable_ep_context=True, diff --git a/src/winml/modelkit/compiler/utils.py b/src/winml/modelkit/compiler/utils.py index e109e0e3f..9c0fa9583 100644 --- a/src/winml/modelkit/compiler/utils.py +++ b/src/winml/modelkit/compiler/utils.py @@ -47,10 +47,39 @@ ) -# Union of all quantization op types (QDQ + QOperator). Use this for -# "is the model already quantized?" detection regardless of which format -# the producer used. -QUANTIZATION_OP_TYPES: frozenset[str] = QDQ_OP_TYPES | QOPERATOR_OP_TYPES +# Dynamic quantization op types. Produced by ``onnxruntime.quantization`` +# in dynamic mode (e.g. ``QuantType.QUInt8`` without static calibration). 
+# These ops compute the input scale/zero-point at inference time rather +# than baking them into the graph, so a model containing them is already +# quantized and must not be re-optimized or re-quantized. +DYNAMIC_QUANT_OP_TYPES: frozenset[str] = frozenset( + { + "DynamicQuantizeLinear", + "DynamicQuantizeMatMul", + } +) + + +# Union of all quantization op types (QDQ + QOperator + dynamic). Use +# this for "is the model already quantized?" detection regardless of +# which format the producer used. +QUANTIZATION_OP_TYPES: frozenset[str] = ( + QDQ_OP_TYPES | QOPERATOR_OP_TYPES | DYNAMIC_QUANT_OP_TYPES +) + + +# CodeQL flagged ``QUANTIZATION_OP_TYPES`` as unused because it is +# consumed via the lazy re-export in ``modelkit.onnx`` (see +# ``onnx/__init__.py``'s ``_LAZY_MAP``) rather than a direct import. +# Declaring ``__all__`` makes the public surface explicit for both the +# import system and static analyzers. +__all__ = [ + "DYNAMIC_QUANT_OP_TYPES", + "QDQ_OP_TYPES", + "QOPERATOR_OP_TYPES", + "QUANTIZATION_OP_TYPES", + "needs_format_conversion", +] def needs_format_conversion(model_path: Path, ep: EPAlias) -> bool: diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index d1724c1ae..e50dcb210 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -136,6 +136,14 @@ class WinMLBuildConfig: compile: WinMLCompileConfig | None = field(default_factory=WinMLCompileConfig) eval: WinMLEvaluationConfig | None = None auto: bool = True + # Stamped True by generate_*_build_config (or by the build_*_model + # entry-point defensive fallback) when the input ONNX is already + # quantized (QDQ or QOperator format). When True, the optimize stage + # is bypassed for downstream pipelines (no ORT graph optimization, + # no autoconf analyze loop). This is the SINGLE source of truth for + # "is this model pre-quantized?" — downstream stages must read this + # flag instead of calling ``is_quantized_onnx`` again. 
+ skip_optimize: bool = False def __post_init__(self) -> None: # Lazy import: inject into module globals so typing.get_type_hints() @@ -169,6 +177,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: ), eval=eval_cfg, auto=config_dict.get("auto", True), + skip_optimize=config_dict.get("skip_optimize", False), ) def to_dict(self) -> dict: @@ -176,6 +185,8 @@ def to_dict(self) -> dict: result: dict = {} if not self.auto: result["auto"] = False + if self.skip_optimize: + result["skip_optimize"] = True result.update( { "export": self.export.to_dict() if self.export is not None else None, @@ -428,8 +439,11 @@ def generate_onnx_build_config( ) if is_quantized_onnx(onnx_path_resolved): - # Skip optimize+quantize, compile with resolved policy + # Skip optimize+quantize, compile with resolved policy. + # ``skip_optimize`` is the single source of truth — downstream + # pipelines must read this flag and not re-detect. config.quant = None + config.skip_optimize = True config.compile = resolved_compile logger.info("Quantized model (QDQ) detected") else: diff --git a/src/winml/modelkit/core/onnx_utils.py b/src/winml/modelkit/core/onnx_utils.py index bbc1b8e01..7737994fa 100644 --- a/src/winml/modelkit/core/onnx_utils.py +++ b/src/winml/modelkit/core/onnx_utils.py @@ -394,6 +394,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype: io_config: dict[str, list] = { "input_names": [], "input_shapes": [], + "input_symbolic_shapes": [], "input_types": [], "output_names": [], "output_shapes": [], @@ -418,18 +419,25 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype: except (KeyError, AttributeError): dtype = np.dtype(np.float32) # Default fallback - # Extract shape (None for dynamic dims) + # Extract shape (None for dynamic dims) and capture symbolic + # dim_param names in a parallel list so downstream consumers + # can resolve dynamic dims by their declared name. 
shape: list[int | None] = [] + symbolic_shape: list[int | str | None] = [] if tensor_type.HasField("shape"): for dim in tensor_type.shape.dim: if dim.HasField("dim_value"): shape.append(dim.dim_value) + symbolic_shape.append(dim.dim_value) else: shape.append(None) # Dynamic dimension + symbolic_shape.append(dim.dim_param or None) io_config[f"{prefix}_names"].append(name) io_config[f"{prefix}_shapes"].append(shape) io_config[f"{prefix}_types"].append(dtype) + if prefix == "input": + io_config["input_symbolic_shapes"].append(symbolic_shape) return io_config diff --git a/src/winml/modelkit/data/hub_models.json b/src/winml/modelkit/data/hub_models.json index 66266ea4e..404a8156b 100644 --- a/src/winml/modelkit/data/hub_models.json +++ b/src/winml/modelkit/data/hub_models.json @@ -1656,13 +1656,17 @@ }, { "model_id": "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx", - "task": "mask-generation", - "model_type": "sam3_tracker" + "task": "image-feature-extraction", + "model_type": "sam3_tracker", + "supported_eps": {}, + "size_mb": 504.2 }, { "model_id": "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx", "task": "mask-generation", - "model_type": "sam3_tracker" + "model_type": "sam3_tracker", + "supported_eps": {}, + "size_mb": 9.6 } ] } diff --git a/src/winml/modelkit/datasets/__init__.py b/src/winml/modelkit/datasets/__init__.py index ebbde0cfb..c77f0ca47 100644 --- a/src/winml/modelkit/datasets/__init__.py +++ b/src/winml/modelkit/datasets/__init__.py @@ -20,6 +20,7 @@ from .depth_estimation import DEFAULT_DEPTH_ESTIMATION_SIZE, DepthEstimationDataset from .image import ImageDataset from .image_segmentation import ImageSegmentationDataset +from .mask_generation import MaskGenerationDataset from .object_detection import DEFAULT_OBJECT_DETECTION_SIZE, ObjectDetectionDataset from .processor_utils import get_image_processor_config from .random_dataset import RandomDataset @@ -47,6 +48,7 @@ "fill-mask": TextDataset, 
"zero-shot-classification": TextDataset, "image-segmentation": ImageSegmentationDataset, + "mask-generation": MaskGenerationDataset, "depth-estimation": DepthEstimationDataset, "random": RandomDataset, # Add more task types as needed diff --git a/src/winml/modelkit/datasets/mask_generation.py b/src/winml/modelkit/datasets/mask_generation.py new file mode 100644 index 000000000..aa7ccb31e --- /dev/null +++ b/src/winml/modelkit/datasets/mask_generation.py @@ -0,0 +1,394 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Mask-generation dataset for promptable segmentation models (SAM/SAM2/SAM3). + +Unlike ``ImageSegmentationDataset`` (semantic segmentation, single model → +pixel-wise class map), promptable mask-generation requires +``(image, prompt, gt_mask)`` triples: the model receives an image *and* a +user prompt (bbox / point / text concept), and emits the corresponding +binary mask. + +This dataset yields raw image + GT mask + an auto-derived prompt -- it +does **not** apply SAM-specific preprocessing (1008x1008 padding, +ImageNet-mean normalization), because that lives in the evaluator +alongside the encoder/decoder ONNX sessions. + +Supported prompt modes: +- ``"bbox"`` -- tight axis-aligned bbox derived from the GT mask. + Matches the SAM family's standard mIoU benchmark protocol. +- ``"point"`` — single foreground point sampled from the GT mask + centroid (or a random foreground pixel if centroid is outside the + mask). +- ``"text"`` — free-form text concept. Requires the dataset to expose a + text column via ``text_col`` config (typically a class name / caption). + SAM 3's flagship "concept-prompted segmentation" mode. + +The yielded GT mask is collapsed to **binary foreground vs background** +(``mask > 0``) by default. 
Datasets that distinguish multiple instances +or classes can opt into instance-level GT via ``binarize=False``, but +multi-instance evaluation is left to the evaluator (this dataset still +emits one prompt per sample; instance-level AP requires repeating +samples once per instance, which the COCO evaluator handles). +""" + +from __future__ import annotations + +import logging +from random import Random +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +from datasets import load_dataset +from datasets.features import Image as HFImage +from PIL import Image as PILImage + +from .base import BaseTaskDataset + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +logger = logging.getLogger(__name__) + +PromptMode = Literal["bbox", "point", "text"] +VALID_PROMPT_MODES: tuple[PromptMode, ...] = ("bbox", "point", "text") + +# Mask coverage filter defaults (fraction of pixels that are foreground). +# Excludes degenerate empty masks (no signal) and near-full masks (trivial). +DEFAULT_MIN_COVERAGE = 0.005 # 0.5% +DEFAULT_MAX_COVERAGE = 0.95 # 95% + + +class MaskGenerationDataset(BaseTaskDataset): + """Dataset for promptable mask-generation tasks (SAM/SAM2/SAM3 family). + + Each sample is a ``dict`` with keys: + + - ``image``: ``PIL.Image.Image`` in RGB, at original resolution. + - ``gt_mask``: ``np.ndarray`` of shape ``(H, W)`` and dtype ``bool``, + same size as ``image``. + - ``prompt``: ``dict`` whose key depends on ``prompt_mode``: + - bbox: ``{"bbox": [x1, y1, x2, y2]}`` (xyxy, in original pixels). + - point: ``{"point": [x, y], "label": 1}`` (foreground point). + - text: ``{"text": ""}``. + - ``sample_id``: ``str`` -- stable identifier for logging/visualization. 
+ """ + + DEFAULT_DATASET = "mattmdjaga/human_parsing_dataset" + DEFAULT_SPLIT = "train" + + def __init__( + self, + model_name: str, + dataset_name: str | None = None, + max_samples: int | None = None, + data_split: str | None = None, + prompt_mode: PromptMode = "bbox", + binarize: bool = True, + min_mask_coverage: float = DEFAULT_MIN_COVERAGE, + max_mask_coverage: float = DEFAULT_MAX_COVERAGE, + text_col: str | None = None, + seed: int = 42, + **kwargs: Any, + ) -> None: + """Initialize the mask-generation dataset. + + Args: + model_name: HuggingFace model identifier (kept for API parity + with other datasets; mask-generation does not consult the + model's image processor since SAM does its own + preprocessing in the evaluator). + dataset_name: Source dataset (defaults to + ``mattmdjaga/human_parsing_dataset``). + max_samples: Cap the number of samples (None = all). + data_split: HF dataset split (defaults to ``"train"``). + prompt_mode: One of ``"bbox"``, ``"point"``, ``"text"``. + binarize: Collapse multi-class masks to foreground-vs-background. + min_mask_coverage: Drop samples whose foreground fraction is + below this (filters empty/near-empty masks). + max_mask_coverage: Drop samples whose foreground fraction is + above this (filters trivial near-full masks). + text_col: For ``prompt_mode="text"``, the dataset column holding + the text prompt. Ignored otherwise. + seed: RNG seed for point sampling + sample subselection. + **kwargs: forwarded to ``BaseTaskDataset``. 
+ """ + if prompt_mode not in VALID_PROMPT_MODES: + raise ValueError( + f"prompt_mode={prompt_mode!r} is not one of {VALID_PROMPT_MODES}" + ) + if not 0.0 <= min_mask_coverage <= max_mask_coverage <= 1.0: + raise ValueError( + "Require 0 <= min_mask_coverage <= max_mask_coverage <= 1; " + f"got min={min_mask_coverage}, max={max_mask_coverage}" + ) + + self._prompt_mode = prompt_mode + self._binarize = binarize + self._min_coverage = min_mask_coverage + self._max_coverage = max_mask_coverage + self._text_col = text_col + self._seed = seed + self._rng = Random(seed) + + if data_split is None: + data_split = self.DEFAULT_SPLIT + + super().__init__( + model_name=model_name, + dataset_name=dataset_name, + max_samples=max_samples, + data_split=data_split, + **kwargs, + ) + + # ------------------------------------------------------------------ + # Initialization + # ------------------------------------------------------------------ + + def _initialize(self) -> None: + if self._dataset_name is None: + self._dataset_name = self.DEFAULT_DATASET + + logger.info( + "Loading mask-generation dataset %s (split=%s, prompt_mode=%s)", + self._dataset_name, self._data_split, self._prompt_mode, + ) + dataset = load_dataset(self._dataset_name, split=self._data_split) + self._detect_columns(dataset) + + if self._prompt_mode == "text" and self._text_col is None: + raise ValueError( + "prompt_mode='text' requires text_col to be set explicitly " + f"(dataset {self._dataset_name!r} columns: " + f"{list(dataset.features) if hasattr(dataset, 'features') else 'unknown'})" + ) + if self._text_col is not None and self._text_col not in dataset.features: + raise ValueError( + f"text_col={self._text_col!r} not found in dataset features " + f"({list(dataset.features)})" + ) + + # Optional shuffle BEFORE truncation so different seeds give + # different sample windows. 
+ shuffle = self._config.get("shuffle", False) + if shuffle: + dataset = dataset.shuffle(seed=self._seed) + + # Apply max_samples cap (or take all). We over-fetch slightly to + # absorb the coverage filter so a hard cap of N still yields ~N + # samples in most cases. + if self._max_samples is not None: + cap = self._max_samples + over = min(max(2 * cap, cap + 20), len(dataset)) + dataset = dataset.select(range(over)) + # else: keep full dataset (caller manages cost) + + self._dataset = dataset + logger.info( + "Mask-generation dataset ready: %d candidate samples, image='%s', mask='%s'", + len(self._dataset), self._image_col, self._mask_col, + ) + + def _detect_columns(self, dataset: Any) -> None: + """Pick image + mask columns. + + Uses the same heuristics as ``ImageSegmentationDataset``: prefer + name-matched HF Image features, else fall back to the + only-two-Image-columns convention. + """ + if not hasattr(dataset, "features"): + raise ValueError(f"Dataset {self._dataset_name} has no features metadata") + features = dataset.features + + image_cands: list[str] = [] + mask_cands: list[str] = [] + for col, feat in features.items(): + if not isinstance(feat, HFImage): + continue + lc = col.lower() + mask_keywords = ("annotation", "mask", "label", "segmentation", "target", "gt") + if any(k in lc for k in mask_keywords): + mask_cands.append(col) + else: + image_cands.append(col) + + # Fallback: if only two Image columns and we couldn't classify both, + # assume order [image, mask]. + all_image_feats = [c for c, f in features.items() if isinstance(f, HFImage)] + if len(all_image_feats) == 2 and (not image_cands or not mask_cands): + image_cands = [all_image_feats[0]] + mask_cands = [all_image_feats[1]] + + if not image_cands or not mask_cands: + raise ValueError( + f"Could not auto-detect image + mask columns in " + f"{self._dataset_name!r}; available features: {list(features)}" + ) + + # Prefer canonical names if present. 
+ self._image_col = next((c for c in image_cands if c.lower() == "image"), image_cands[0]) + preferred_mask = ("mask", "annotation", "label", "segmentation") + self._mask_col = next( + (c for c in mask_cands if c.lower() in preferred_mask), + mask_cands[0], + ) + + # ------------------------------------------------------------------ + # ABC overrides + # ------------------------------------------------------------------ + + @property + def label_col(self) -> str: + """Dataset column holding the per-sample mask (alias of ``mask_col``).""" + return self._mask_col + + @property + def mask_col(self) -> str: + """Dataset column holding the per-sample mask.""" + return self._mask_col + + @property + def image_col(self) -> str: + """Dataset column holding the per-sample image.""" + return self._image_col + + @property + def prompt_mode(self) -> PromptMode: + """Configured prompt mode (``bbox`` / ``point`` / ``text``).""" + return self._prompt_mode + + def __len__(self) -> int: + # Note: len() reflects the candidate window; the coverage filter + # is applied lazily in __getitem__, so iterators should check + # the returned dict for None (skip) or use iter_valid() below. + return len(self._dataset) if self._dataset is not None else 0 + + def __getitem__(self, idx: int) -> dict[str, Any] | None: # type: ignore[override] + """Return one ``(image, gt_mask, prompt)`` triple, or ``None``. + + Returns ``None`` when the sample fails the coverage filter (caller + should skip). 
+ """ + if self._dataset is None: + raise IndexError("Dataset not initialized") + row = self._dataset[idx] + image = row[self._image_col] + mask = row[self._mask_col] + if not isinstance(image, PILImage.Image): + raise TypeError(f"Expected PIL.Image at row {idx}, got {type(image).__name__}") + image = image.convert("RGB") + + gt = self._to_binary_mask(mask, target_size=(image.height, image.width)) + coverage = float(gt.mean()) + if not self._min_coverage <= coverage <= self._max_coverage: + return None + + prompt = self._derive_prompt(gt, row) + return { + "image": image, + "gt_mask": gt, + "prompt": prompt, + "sample_id": f"sample_{idx:04d}", + "coverage": coverage, + } + + def iter_valid(self, max_samples: int | None = None) -> Iterator[dict[str, Any]]: + """Yield only samples that pass the coverage filter. + + Args: + max_samples: stop after this many valid samples (None = all + that pass). + """ + if self._dataset is None: + raise RuntimeError("Dataset not initialized") + cap = max_samples if max_samples is not None else (self._max_samples or len(self._dataset)) + yielded = 0 + for i in range(len(self._dataset)): + if yielded >= cap: + break + sample = self[i] + if sample is None: + continue + yield sample + yielded += 1 + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _to_binary_mask(self, mask: Any, target_size: tuple[int, int]) -> np.ndarray: + """Convert raw mask to ``(H, W) bool`` foreground array. + + ``target_size`` is ``(H, W)`` of the paired image. Some datasets + (notably LIP / human_parsing) store masks at transposed + resolution; if shapes disagree we resize the mask to match the + image. 
+ """ + if not isinstance(mask, PILImage.Image): + raise TypeError(f"Expected PIL.Image for mask, got {type(mask).__name__}") + if mask.size != (target_size[1], target_size[0]): + mask = mask.resize((target_size[1], target_size[0]), PILImage.Resampling.NEAREST) + arr = np.array(mask) + if arr.ndim == 3: + arr = arr[..., 0] + if self._binarize: + return arr > 0 + # Instance mode: caller will handle multi-class. + return arr.astype(np.int32) + + def _derive_prompt(self, gt: np.ndarray, row: dict[str, Any]) -> dict[str, Any]: + """Derive a prompt of the configured mode from the GT mask + row.""" + if self._prompt_mode == "bbox": + x1, y1, x2, y2 = _bbox_from_mask(gt) + return {"bbox": [int(x1), int(y1), int(x2), int(y2)]} + if self._prompt_mode == "point": + x, y = _foreground_point(gt, self._rng) + return {"point": [int(x), int(y)], "label": 1} + # text mode + text = row[self._text_col] # type: ignore[index] + if not isinstance(text, str): + text = str(text) + return {"text": text} + + +# ---------------------------------------------------------------------- +# Geometry helpers (small, pure, easy to unit-test) +# ---------------------------------------------------------------------- + + +def _bbox_from_mask(mask: np.ndarray) -> tuple[int, int, int, int]: + """Tight axis-aligned bbox ``(x1, y1, x2, y2)`` for ``mask``. + + Computed around the True pixels of ``mask`` (bool 2D). + """ + if mask.dtype != np.bool_: + mask = mask.astype(bool) + ys, xs = np.where(mask) + if ys.size == 0: + raise ValueError("Cannot derive bbox from empty mask") + return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()) + + +def _foreground_point(mask: np.ndarray, rng: Random) -> tuple[int, int]: + """Single foreground point. + + Returns the mask centroid when it lies inside the mask, else a + random foreground pixel. Falling back to a random pixel handles + concave / disconnected masks (e.g., people with bags) where the + centroid lands in background. 
+ """ + if mask.dtype != np.bool_: + mask = mask.astype(bool) + ys, xs = np.where(mask) + if ys.size == 0: + raise ValueError("Cannot derive point from empty mask") + cy, cx = int(ys.mean()), int(xs.mean()) + if 0 <= cy < mask.shape[0] and 0 <= cx < mask.shape[1] and mask[cy, cx]: + return cx, cy + # Fall back to a random foreground pixel. + i = rng.randrange(ys.size) + return int(xs[i]), int(ys[i]) diff --git a/src/winml/modelkit/eval/evaluate.py b/src/winml/modelkit/eval/evaluate.py index cab10d26e..2882f9521 100644 --- a/src/winml/modelkit/eval/evaluate.py +++ b/src/winml/modelkit/eval/evaluate.py @@ -64,6 +64,8 @@ "winml.modelkit.eval.depth_estimation_evaluator:WinMLDepthEstimationEvaluator", "compare-tensor": "winml.modelkit.eval.tensor_similarity_evaluator:TensorSimilarityEvaluator", + "mask-generation": + "winml.modelkit.eval.mask_generation_evaluator:WinMLMaskGenerationEvaluator", } # fmt: on @@ -172,6 +174,13 @@ def get_evaluator_class(config: WinMLEvaluationConfig) -> type[WinMLEvaluator]: # the legacy `nyu_depth_v2.py` loader script. "revision": "refs/convert/parquet", }, + "mask-generation": { + # LIP-derived multi-class body-part labels, collapsed to a single + # binary foreground/background mask by ``MaskGenerationDataset``. + # Same dataset used by ``scripts/sam3_smoke_eval.py``. + "path": "mattmdjaga/human_parsing_dataset", + "split": "train", + }, } @@ -190,13 +199,26 @@ def to_dict(self) -> dict[str, Any]: } -def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel: - """Load model from ONNX path or HF model ID.""" +def _load_model(config: WinMLEvaluationConfig) -> WinMLPreTrainedModel | None: + """Load model from ONNX path or HF model ID. + + For evaluators that handle their own ORT session construction from a + composite ``role=path`` model dict (currently only + ``mask-generation``), returns ``None`` -- the evaluator reads + ``config.model_path`` directly. 
Going through ``WinMLAutoModel``'s + composite registry would require registering the model type (e.g., + SAM 3), which is a heavier follow-up; this bypass lets standalone + ONNX exports be evaluated today. + """ from ..models import WinMLAutoModel if config.model_id is None: raise ValueError("model_id is required.") + if isinstance(config.model_path, dict) and config.task == "mask-generation": + # Evaluator-driven session loading; skip WinMLAutoModel entirely. + return None + if config.model_path is not None: # Pre-built ONNX: precision is already baked into the model and is # ignored here (mirrors winml perf's ONNX path). @@ -306,7 +328,11 @@ def evaluate(config: WinMLEvaluationConfig) -> EvalResult: cls = get_evaluator_class(config) try: console.print("[bold]Loading dataset and evaluating...[/bold]") - task_evaluator = cls(config, model) + # ``model`` is ``None`` for composite evaluators that load ORT + # sessions directly from ``config.model_path`` (currently only + # mask-generation). Type-checker can't follow the per-task + # invariant, so suppress here at the unified call site. + task_evaluator = cls(config, model) # type: ignore[arg-type] metrics = task_evaluator.compute() except DatasetValidationError as error: raise ValueError( diff --git a/src/winml/modelkit/eval/mask_generation_evaluator.py b/src/winml/modelkit/eval/mask_generation_evaluator.py new file mode 100644 index 000000000..637932298 --- /dev/null +++ b/src/winml/modelkit/eval/mask_generation_evaluator.py @@ -0,0 +1,593 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Promptable mask-generation evaluator for SAM-family ONNX models. + +Does *not* go through HF's ``pipeline`` / ``evaluate`` libraries because: + +1. 
HF's ``mask-generation`` task is a high-level wrapper around the + *full* PyTorch SAM model -- it isn't compatible with raw ORT sessions. +2. Mask-generation here is *composite*: encoder + decoder must be + orchestrated manually (the same as :file:`scripts/sam3_smoke_eval.py` + does informally). The base :class:`WinMLEvaluator`'s single-model + pipeline assumption doesn't fit. + +The evaluator instead drives two ORT sessions directly: + +* **image-encoder** -- consumes ``pixel_values``, emits 3 multi-scale + image embeddings (``image_embeddings.0/1/2`` for SAM 3). +* **prompt-decoder** -- consumes a prompt (bbox or point) plus the + embeddings, emits up to 3 candidate masks plus their predicted IoU. + +For each sample we derive the prompt from the GT mask (so we're measuring +the model's ability to *trace boundaries* given a known prompt -- the +standard SAM eval setup), pick the highest predicted-IoU mask, map it +back to the original image resolution, and accumulate mIoU + Dice via +:class:`~winml.modelkit.eval.metrics.BinarySegmentationMetric`. + +Text-prompt mode is intentionally not implemented yet -- the publicly +cached SAM 3 ONNX decoder does not accept a text input port (see +``input_points``/``input_labels``/``input_boxes`` only). Text-concept +prompting requires a separate text-encoder ONNX that is not yet on the +Hub; tracked as a follow-up. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np + +from .base_evaluator import WinMLEvaluator + + +if TYPE_CHECKING: + from PIL import Image + from transformers.pipelines.base import Pipeline + + from ..models.winml.base import WinMLPreTrainedModel + from .config import WinMLEvaluationConfig + + +logger = logging.getLogger(__name__) + + +# ---------------------------------------------------------------------- +# Per-family preprocessing profiles. 
+# +# The ONNX-community SAM 2.1 and SAM 3 exports share the *decoder* I/O +# schema (``input_points``/``input_labels``/``input_boxes`` -> +# ``iou_scores``/``pred_masks``/``object_score_logits``) and the encoder +# output names (``image_embeddings.{0,1,2}``); only the *image* side +# differs: +# +# * SAM 3 Tracker: direct bilinear resize to 1008x1008 (no padding) with +# mean/std = 0.5/0.5 (preprocessor_config.json on +# onnx-community/sam3-tracker-ONNX). +# * SAM 2.1: longest-side bilinear resize to 1024 with zero-pad to a +# 1024x1024 square; ImageNet mean/std. Matches the SAM-paper +# convention and ``onnx-community/sam2.1-hiera-small-ONNX``. +# +# A profile bundles those constants together. The active profile is +# resolved per-evaluator from the encoder's static ``pixel_values`` shape +# (falling back to a ``model_id`` substring heuristic, then SAM 3). +# ---------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _MaskGenProfile: + """Per-family preprocessing constants for a SAM-style ONNX export.""" + + name: str + target_size: int + mean: tuple[float, float, float] + std: tuple[float, float, float] + # "direct" -> resize per-axis to target_size x target_size + # "longest_side_pad" -> longest-side resize, zero-pad bottom/right to square + resize_mode: str + + +SAM3_PROFILE = _MaskGenProfile( + name="sam3", + target_size=1008, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + resize_mode="direct", +) + + +SAM2_PROFILE = _MaskGenProfile( + name="sam2", + target_size=1024, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + resize_mode="longest_side_pad", +) + + +# Back-compat module-level SAM 3 constant (preserved so existing imports +# from tests/scripts keep working unchanged). +_TARGET_SIZE = SAM3_PROFILE.target_size + +__all__ = ["_TARGET_SIZE", "WinMLMaskGenerationEvaluator"] + + +class WinMLMaskGenerationEvaluator(WinMLEvaluator): + """Evaluator for SAM-style promptable mask generation. 
+ + Constructor accepts the standard ``(config, model)`` signature so the + registry dispatch in :mod:`~winml.modelkit.eval.evaluate` works + unmodified. The ``model`` argument may be ``None`` -- this evaluator + reads ``config.model_path`` (a ``dict[str, str]`` mapping + ``image-encoder`` / ``prompt-decoder`` to ONNX file paths) and + constructs its own ORT sessions, bypassing the + ``WinMLAutoModel`` composite-registry path. + """ + + # Required sub-model role names (must appear as keys in + # ``config.model_path`` when it is a dict). + _ENCODER_ROLE = "image-encoder" + _DECODER_ROLE = "prompt-decoder" + + def __init__( + self, + config: WinMLEvaluationConfig, + model: WinMLPreTrainedModel | None, + ) -> None: + if not isinstance(config.model_path, dict): + raise TypeError( + "Mask-generation evaluation requires composite `-m role=path` " + "model arguments. Pass --model image-encoder= and " + f"--model {self._DECODER_ROLE}=.", + ) + for role in (self._ENCODER_ROLE, self._DECODER_ROLE): + if role not in config.model_path: + raise ValueError( + f"Missing required `-m {role}=` argument. " + f"Got roles: {sorted(config.model_path)}.", + ) + + # Pre-seed the attributes that ``prepare_data`` (invoked from the + # base ``WinMLEvaluator.__init__``) depends on. ``self.config`` + # is needed by ``_load_sessions`` (it reads ``config.model_path`` + # and ``config.ep``) so we set it before calling super. + self.config = config + mapping = config.dataset.columns_mapping or {} + self._prompt_mode: str = mapping.get("prompt_mode", "bbox") + self._enc_sess, self._dec_sess = self._load_sessions() + # Pick the per-family preprocessing profile from the encoder's + # static input shape (falling back to a model_id heuristic, then + # SAM 3). Threaded through preprocess + postprocess in _predict. 
+ self._profile = _resolve_profile(self.config, self._enc_sess) + logger.info( + "Mask-generation profile: %s (target=%d, mean=%s, std=%s, resize=%s)", + self._profile.name, self._profile.target_size, + self._profile.mean, self._profile.std, self._profile.resize_mode, + ) + + # Defer the rest of attribute setup (``self.model``, ``self.data``, + # ``self.pipe``) to the base class so we satisfy the evaluator + # contract and CodeQL's ``py/missing-call-to-init`` rule. The + # base ``prepare_pipeline`` is overridden here to return ``None``, + # so it is safe to call from ``WinMLEvaluator.__init__``. + super().__init__(config, model) # type: ignore[arg-type] + + # ------------------------------------------------------------------ + # WinMLEvaluator overrides + # ------------------------------------------------------------------ + + def prepare_data(self) -> Any: + """Build a :class:`MaskGenerationDataset` from ``config.dataset``.""" + from ..datasets.mask_generation import MaskGenerationDataset + + ds = self.config.dataset + mapping = ds.columns_mapping or {} + + # ``model_name`` is required by ``BaseTaskDataset`` for API parity + # with other datasets; mask-generation does not actually consult + # the model's image processor. Fall back to a safe sentinel when + # ``model_id`` is unset (composite mask-gen sometimes runs without + # a single canonical model_id). 
+ return MaskGenerationDataset( + model_name=self.config.model_id or "sam-mask-generation", + dataset_name=ds.path or MaskGenerationDataset.DEFAULT_DATASET, + data_split=ds.split or MaskGenerationDataset.DEFAULT_SPLIT, + max_samples=ds.samples, + prompt_mode=self._prompt_mode, # type: ignore[arg-type] + text_col=mapping.get("text_column"), + seed=ds.seed, + ) + + def prepare_pipeline(self) -> Pipeline | None: # type: ignore[override] + """No HF pipeline -- ORT sessions are driven directly in ``compute``.""" + return None + + def compute(self) -> dict[str, Any]: + """Run mask-generation eval and return mIoU / Dice.""" + from .metrics import BinarySegmentationMetric + + metric = BinarySegmentationMetric() + # ``self.data`` length is the *over-fetch* candidate window the + # dataset built to absorb coverage-filter drops; the user's actual + # requested count lives in ``config.dataset.samples``. Iterating + # past that would silently inflate cost (we saw 23 evaluations for + # ``--samples 3`` before this cap). 
+ requested = self.config.dataset.samples + logger.info( + "Mask-generation eval: requesting %d samples (candidate window=%d)", + requested, len(self.data), + ) + + processed = 0 + for sample in self.data.iter_valid(max_samples=requested): + try: + pred_mask = self._predict(sample) + except Exception as e: + logger.warning( + "Skipping sample %s: prediction failed (%s)", + sample.get("sample_id", "?"), + e, + ) + continue + metric.update(pred_mask, sample["gt_mask"]) + processed += 1 + if processed % 10 == 0: + logger.info(" processed %d samples", processed) + + return metric.compute() + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _load_sessions(self) -> tuple[Any, Any]: + """Construct ORT sessions for the encoder + decoder.""" + import onnxruntime as ort + + paths = self.config.model_path + assert isinstance(paths, dict) # already validated in __init__ + + providers, provider_options = _build_providers(self.config.ep or "cpu") + logger.info( + "Creating ORT sessions for mask-generation (providers=%s)", providers, + ) + + sess_opts = ort.SessionOptions() + sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + enc = ort.InferenceSession( + paths[self._ENCODER_ROLE], + sess_options=sess_opts, + providers=providers, + provider_options=provider_options, + ) + dec = ort.InferenceSession( + paths[self._DECODER_ROLE], + sess_options=sess_opts, + providers=providers, + provider_options=provider_options, + ) + logger.info(" encoder providers: %s", enc.get_providers()) + logger.info(" decoder providers: %s", dec.get_providers()) + return enc, dec + + def _predict(self, sample: dict[str, Any]) -> np.ndarray: + """Run encoder + decoder for one sample, return binary mask.""" + image = sample["image"] + gt = sample["gt_mask"] + prompt = sample["prompt"] + + pixel_values, scale_x, scale_y, new_h, new_w = _preprocess_for_profile( + 
self._profile, image, + ) + enc_out = self._enc_sess.run(None, {"pixel_values": pixel_values}) + enc_names = [o.name for o in self._enc_sess.get_outputs()] + emb = dict(zip(enc_names, enc_out, strict=True)) + + dec_inputs = _build_decoder_inputs( + prompt=prompt, + prompt_mode=self._prompt_mode, + scale_x=scale_x, + scale_y=scale_y, + emb=emb, + ) + iou_scores, pred_masks, _ = self._dec_sess.run( + ["iou_scores", "pred_masks", "object_score_logits"], dec_inputs, + ) + + # pred_masks: (1, num_prompts, num_masks, H, W); pick the + # best-scoring of the candidate masks for the first prompt. + iou_preds = iou_scores[0, 0] # (num_masks,) + best_idx = int(iou_preds.argmax()) + best_low_res = pred_masks[0, 0, best_idx] + + return _postprocess_for_profile( + self._profile, + best_low_res, + orig_h=gt.shape[0], + orig_w=gt.shape[1], + new_h=new_h, + new_w=new_w, + ) + + +# ---------------------------------------------------------------------- +# Pure helpers (kept at module scope so they're easy to test in isolation) +# ---------------------------------------------------------------------- + + +def _resolve_profile( + config: WinMLEvaluationConfig, + enc_sess: Any, +) -> _MaskGenProfile: + """Pick the per-family preprocessing profile for the active model. + + Resolution priority: + + 1. **Encoder static input shape**. If the encoder's ``pixel_values`` + has a static last dim that matches a registered profile's + ``target_size`` (e.g. 1024 -> SAM 2, 1008 -> SAM 3), use that. + This is the most reliable signal because it comes from the actual + ONNX export. + 2. **``config.model_id`` substring**. Falls back to matching common + family identifiers (``sam2`` / ``sam-2`` -> SAM 2; + ``sam3`` / ``sam-3`` -> SAM 3) when the encoder shape is dynamic. + 3. **Default SAM 3** -- preserves the original evaluator behaviour. 
+ """ + known = (SAM3_PROFILE, SAM2_PROFILE) + + try: + shape = enc_sess.get_inputs()[0].shape + except Exception: + shape = [] + if len(shape) >= 4 and isinstance(shape[-1], int): + for prof in known: + if shape[-1] == prof.target_size: + return prof + + mid = (config.model_id or "").lower() + if "sam2" in mid or "sam-2" in mid: + return SAM2_PROFILE + if "sam3" in mid or "sam-3" in mid: + return SAM3_PROFILE + + return SAM3_PROFILE + + +def _preprocess_for_profile( + profile: _MaskGenProfile, + img: Image.Image, +) -> tuple[np.ndarray, float, float, int, int]: + """Profile-driven preprocessing. + + Returns ``(pixel_values, scale_x, scale_y, new_h, new_w)``: + + * ``pixel_values`` -- ``(1, 3, T, T)`` fp32 NCHW. + * ``scale_x`` / ``scale_y`` -- multiply original pixel coords to map + into encoder-input space (so prompts can be transformed for the + decoder). For ``longest_side_pad`` they are equal (single uniform + scale); for ``direct`` they differ per axis. + * ``new_h`` / ``new_w`` -- post-resize, pre-pad dimensions; needed by + the postprocess step to undo the padding before resizing to the + original image. For ``direct`` they equal ``T``. + """ + from PIL import Image as PILImage + + img = img.convert("RGB") + orig_w, orig_h = img.size + target = profile.target_size + mean = np.asarray(profile.mean, dtype=np.float32) + std = np.asarray(profile.std, dtype=np.float32) + + if profile.resize_mode == "direct": + scale_x = target / orig_w + scale_y = target / orig_h + resized = img.resize((target, target), PILImage.Resampling.BILINEAR) + arr = np.asarray(resized, dtype=np.float32) / 255.0 + arr = (arr - mean) / std + new_h = target + new_w = target + elif profile.resize_mode == "longest_side_pad": + # SAM 2.1 convention: longest-side resize preserving aspect ratio, + # then zero-pad bottom/right to a square. Prompts use a single + # uniform scale (``scale_x == scale_y``). 
+ scale = target / max(orig_h, orig_w) + new_h = round(orig_h * scale) + new_w = round(orig_w * scale) + resized = img.resize((new_w, new_h), PILImage.Resampling.BILINEAR) + arr = np.asarray(resized, dtype=np.float32) / 255.0 + arr = (arr - mean) / std + pad_h = target - new_h + pad_w = target - new_w + arr = np.pad(arr, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant") + scale_x = scale + scale_y = scale + else: + raise ValueError( + f"Unsupported resize_mode={profile.resize_mode!r} for profile " + f"{profile.name!r}; expected 'direct' or 'longest_side_pad'.", + ) + + pixel_values = arr.transpose(2, 0, 1)[None, ...] + return pixel_values.astype(np.float32), scale_x, scale_y, new_h, new_w + + +def _postprocess_for_profile( + profile: _MaskGenProfile, + pred_mask: np.ndarray, + orig_h: int, + orig_w: int, + new_h: int, + new_w: int, +) -> np.ndarray: + """Profile-driven postprocessing. + + * ``direct`` -- low-res mask maps 1:1 to the full original image; a + single resize is enough. + * ``longest_side_pad`` -- up-sample the low-res mask to the encoder + input size, crop off the zero-pad region (back to ``new_h x new_w``), + then resize to the original image dimensions. 
+ """ + from PIL import Image as PILImage + + if profile.resize_mode == "direct": + pil = PILImage.fromarray(pred_mask.astype(np.float32)) + final = pil.resize((orig_w, orig_h), PILImage.Resampling.BILINEAR) + return np.asarray(final, dtype=np.float32) > 0 + + if profile.resize_mode == "longest_side_pad": + target = profile.target_size + pil = PILImage.fromarray(pred_mask.astype(np.float32)) + up = pil.resize((target, target), PILImage.Resampling.BILINEAR) + up_arr = np.asarray(up, dtype=np.float32) + cropped = up_arr[:new_h, :new_w] + pil2 = PILImage.fromarray(cropped) + final = pil2.resize((orig_w, orig_h), PILImage.Resampling.BILINEAR) + return np.asarray(final, dtype=np.float32) > 0 + + raise ValueError( + f"Unsupported resize_mode={profile.resize_mode!r} for profile " + f"{profile.name!r}; expected 'direct' or 'longest_side_pad'.", + ) + + +# ---------------------------------------------------------------------- +# Back-compat SAM 3 wrappers. Preserved so existing imports / tests that +# call ``_preprocess_image(img)`` -> 3-tuple keep working unchanged. +# ---------------------------------------------------------------------- + + +def _preprocess_image( + img: Image.Image, +) -> tuple[np.ndarray, float, float]: + """SAM 3 preprocessing wrapper -- direct resize to 1008x1008, mean=std=0.5. + + Returns the original 3-tuple ``(pixel_values, scale_x, scale_y)``; + profile-aware callers should use :func:`_preprocess_for_profile`. 
+ """ + pv, sx, sy, _new_h, _new_w = _preprocess_for_profile(SAM3_PROFILE, img) + return pv, sx, sy + + +def _postprocess_mask( + pred_mask: np.ndarray, + orig_h: int, + orig_w: int, +) -> np.ndarray: + """SAM 3 postprocessing wrapper -- direct resize back to original.""" + return _postprocess_for_profile( + SAM3_PROFILE, + pred_mask, + orig_h=orig_h, + orig_w=orig_w, + new_h=SAM3_PROFILE.target_size, + new_w=SAM3_PROFILE.target_size, + ) + + +def _build_decoder_inputs( + prompt: dict[str, Any], + prompt_mode: str, + scale_x: float, + scale_y: float, + emb: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + """Assemble the decoder feed dict for bbox or point prompts. + + See decoder signature: + + * ``input_points``: ``(batch=1, 1, num_points, 2)`` fp32 in resized + (1008) coordinates. + * ``input_labels``: ``(batch=1, 1, num_points)`` int64 (1=foreground, + 0=background, -1=padding/null). + * ``input_boxes``: ``(batch=1, num_boxes, 4)`` fp32 in resized coords, + ``[x0, y0, x1, y1]`` order. + + For *point* prompts we still must satisfy ``input_boxes``; ORT does + not accept a zero-size box dim across all builds, so we pass a sentinel + full-image box (rejected by SAM's prompt encoder via the all-foreground + point) plus a single fg point. This matches SAM 1/2 reference impls. + """ + if prompt_mode == "bbox": + x0, y0, x1, y1 = prompt["bbox"] + box = np.array( + [[[x0 * scale_x, y0 * scale_y, x1 * scale_x, y1 * scale_y]]], + dtype=np.float32, + ) # (1, 1, 4) + points: np.ndarray = np.zeros((1, 1, 0, 2), dtype=np.float32) + labels = np.zeros((1, 1, 0), dtype=np.int64) + elif prompt_mode == "point": + px, py = prompt["point"] + points = np.array( + [[[[px * scale_x, py * scale_y]]]], + dtype=np.float32, + ) # (1, 1, 1, 2) + labels = np.ones((1, 1, 1), dtype=np.int64) # 1 = foreground + # Empty box (0 num_boxes). If a future runtime build rejects + # zero-size dims here, switch to a [0, 0, _TARGET_SIZE, _TARGET_SIZE] + # sentinel and rely on the point to override. 
+ box = np.zeros((1, 0, 4), dtype=np.float32) + else: + raise ValueError( + f"Unsupported prompt_mode={prompt_mode!r} (expected 'bbox' or 'point'). " + "Text-prompt mode is not yet supported for SAM 3 ONNX -- the cached " + "decoder export has no text input port; tracked as a follow-up.", + ) + + return { + "input_points": points, + "input_labels": labels, + "input_boxes": box, + "image_embeddings.0": emb["image_embeddings.0"], + "image_embeddings.1": emb["image_embeddings.1"], + "image_embeddings.2": emb["image_embeddings.2"], + } + + +def _build_providers(ep: str) -> tuple[list[str], list[dict[str, Any]]]: + """Map ``--ep`` to ORT provider list + per-provider options. + + Mirrors :file:`scripts/sam3_smoke_eval.py`. Falls back to CPU with a + warning if the requested EP is not present in this ORT install. + """ + import onnxruntime as ort + + providers_map = { + "cpu": ["CPUExecutionProvider"], + "dml": ["DmlExecutionProvider", "CPUExecutionProvider"], + "vitisai": ["VitisAIExecutionProvider", "CPUExecutionProvider"], + "directml": ["DmlExecutionProvider", "CPUExecutionProvider"], + } + providers = providers_map.get(ep.lower(), ["CPUExecutionProvider"]) + avail = set(ort.get_available_providers()) + if providers[0] not in avail: + logger.warning( + "Requested EP %r not available (have %s); falling back to CPU.", + providers[0], + sorted(avail), + ) + providers = ["CPUExecutionProvider"] + + provider_options: list[dict[str, Any]] = [{} for _ in providers] + if providers[0] == "VitisAIExecutionProvider": + install_dir = os.environ.get("RYZEN_AI_INSTALLATION_PATH", "") + xclbin = Path(install_dir) / "voe-4.0-win_amd64" / "xclbins" / "phoenix" / "4x4.xclbin" + if install_dir and xclbin.exists(): + provider_options[0] = { + "target": "X1", + "xlnx_enable_py3_round": 0, + "xclbin": str(xclbin), + } + else: + logger.warning( + "RYZEN_AI_INSTALLATION_PATH unset or xclbin missing; VitisAI may " + "fall back to CPU. 
class BinarySegmentationMetric:
    """Running mean IoU and Dice over per-instance binary masks.

    Feed one prediction / ground-truth pair at a time through
    :meth:`update`, then read the macro-averaged scores from
    :meth:`compute`. Only two sums and two counters are kept, so memory
    use does not grow with the number of samples.
    """

    def __init__(self) -> None:
        self._count = 0
        # Samples whose ground truth has no foreground are tallied here
        # instead of being scored, so a degenerate dataset cannot quietly
        # skew the averages; the tally is reported by ``compute``.
        self._skipped = 0
        self._iou_sum = 0.0
        self._dice_sum = 0.0

    def update(self, pred: np.ndarray, gt: np.ndarray) -> None:
        """Fold one ``(pred, gt)`` mask pair into the running totals.

        Every nonzero element counts as foreground. The two arrays must
        have equal shapes. A pair whose ``gt`` has no foreground only
        increments the skip counter.

        Raises:
            ValueError: If the shapes of ``pred`` and ``gt`` differ.
        """
        if pred.shape != gt.shape:
            raise ValueError(
                f"pred shape {pred.shape} != gt shape {gt.shape}",
            )

        pred_fg = pred.astype(bool)
        gt_fg = gt.astype(bool)

        n_gt = int(np.count_nonzero(gt_fg))
        if n_gt == 0:
            self._skipped += 1
            return

        n_pred = int(np.count_nonzero(pred_fg))
        n_inter = int(np.count_nonzero(pred_fg & gt_fg))
        # |P ∪ G| = |P| + |G| - |P ∩ G|
        n_union = n_pred + n_gt - n_inter
        n_total = n_pred + n_gt

        if n_union > 0:
            iou = n_inter / n_union
        else:
            iou = 0.0
        if n_total > 0:
            dice = (2.0 * n_inter) / n_total
        else:
            dice = 0.0

        self._count += 1
        self._iou_sum += iou
        self._dice_sum += dice

    def compute(self) -> dict[str, Any]:
        """Return the averaged scores plus the sample bookkeeping.

        The result always has ``mIoU``, ``dice``, ``num_samples`` and
        ``num_skipped``. With no scored samples both scores are ``0.0``.
        """
        scored = self._count
        if scored == 0:
            mean_iou, mean_dice = 0.0, 0.0
        else:
            mean_iou = self._iou_sum / scored
            mean_dice = self._dice_sum / scored
        return {
            "mIoU": mean_iou,
            "dice": mean_dice,
            "num_samples": scored,
            "num_skipped": self._skipped,
        }
- from ..loader import maybe_resolve_hf_onnx_path + from ..utils.model_input import resolve_model_input - model_path = maybe_resolve_hf_onnx_path(str(model_path)) or str(model_path) + model_path = resolve_model_input(str(model_path)).local_path or str(model_path) self._model_path = str(model_path) self._device = device diff --git a/src/winml/modelkit/loader/__init__.py b/src/winml/modelkit/loader/__init__.py index ce76e5c24..265d485ec 100644 --- a/src/winml/modelkit/loader/__init__.py +++ b/src/winml/modelkit/loader/__init__.py @@ -26,7 +26,7 @@ """ from .config import WinMLLoaderConfig, resolve_loader_config -from .onnx_hub import is_hf_onnx_path, maybe_resolve_hf_onnx_path, resolve_hf_onnx_path +from .onnx_hub import resolve_hf_onnx_path from .task import ( HF_TASK_DEFAULTS, KNOWN_TASKS, @@ -49,9 +49,7 @@ "detect_task", "get_supported_tasks", "get_task_abbrev", - "is_hf_onnx_path", "load_hf_model", - "maybe_resolve_hf_onnx_path", "normalize_task", "resolve_hf_model_class", "resolve_hf_onnx_path", diff --git a/src/winml/modelkit/loader/onnx_hub.py b/src/winml/modelkit/loader/onnx_hub.py index 900ff2c8a..e6c3c50b5 100644 --- a/src/winml/modelkit/loader/onnx_hub.py +++ b/src/winml/modelkit/loader/onnx_hub.py @@ -4,25 +4,16 @@ # -------------------------------------------------------------------------- """Download pre-exported ONNX files hosted on the HuggingFace Hub. -ModelKit accepts two model input forms today: a HuggingFace model ID -(``org/name``) for the standard ``transformers`` + ``optimum-onnx`` export -path, and a local ``.onnx`` file path for the Scenario D pipeline in -``modelkit.build.build_onnx_model``. - -This module recognizes a third form -- a path-style reference to a -pre-exported ONNX artifact in a Hub repo, e.g.:: - - onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx - -The first two ``/``-separated components are interpreted as the repo ID; -everything that follows is the file path inside the repo. 
The file is -downloaded once via ``huggingface_hub.hf_hub_download`` and the local -path is then handed to the existing Scenario D code path. This is the -supported route for models like SAM 3 whose ``transformers`` requirement -exceeds what ``optimum-onnx`` currently pins. - -Any sibling ``.onnx_data`` external-data sidecar is fetched -best-effort so the ONNX loader can resolve external initializers. +This module is the **download** half of Hub-hosted ONNX support. +Classification (deciding whether a ``-m/--model`` value is a Hub ONNX +ref, a local ``.onnx`` file, an HF model ID, or a build directory) lives +in :mod:`winml.modelkit.utils.model_input` and is the single entry point +that all CLI commands and library APIs should go through. + +The function exposed here, :func:`resolve_hf_onnx_path`, is called by +``resolve_model_input`` for the ``hub_onnx`` case and downloads the +``.onnx`` file (plus any ``.onnx_data`` sidecar) via +``huggingface_hub.hf_hub_download``. """ from __future__ import annotations @@ -34,25 +25,6 @@ logger = logging.getLogger(__name__) -def is_hf_onnx_path(model_id: str | None) -> bool: - """Check whether ``model_id`` is a Hub-style reference to a pre-exported ONNX file. - - Returns True only when the value has at least three ``/``-separated - components, ends with ``.onnx``, and does not point at an existing - local file or directory. Local paths always win over the Hub - interpretation so users can keep working with paths that happen to - look like repo IDs. - """ - if not model_id: - return False - if not model_id.endswith(".onnx"): - return False - if Path(model_id).exists(): - return False - parts = [p for p in model_id.split("/") if p] - return len(parts) >= 3 - - def resolve_hf_onnx_path( model_id: str, *, @@ -80,6 +52,11 @@ def resolve_hf_onnx_path( Raises: ValueError: If ``model_id`` does not have at least three ``/``-separated components. + FileNotFoundError: If the referenced ``.onnx`` file does not exist in + the repo. 
The error message lists the ``.onnx`` files that *are* + present so the user can correct the path. + huggingface_hub.utils.RepositoryNotFoundError: If the repo itself does + not exist (re-raised unchanged). """ from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError @@ -87,15 +64,27 @@ def resolve_hf_onnx_path( repo_id, filename = _split_hf_onnx_path(model_id) logger.info("Downloading ONNX from Hub: repo=%s file=%s", repo_id, filename) - local_path = Path( - hf_hub_download( - repo_id=repo_id, - filename=filename, - revision=revision, - cache_dir=cache_dir, - token=token, + try: + local_path = Path( + hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + cache_dir=cache_dir, + token=token, + ) + ) + except EntryNotFoundError as e: + # The repo exists but ``filename`` does not. Surface the available + # ``.onnx`` files so the user can pick the right one without leaving + # the terminal. Re-raise as ``FileNotFoundError`` so callers that + # already handle local-file-missing errors get a consistent type. + hint = _format_available_onnx_files( + repo_id, revision=revision, token=token ) - ) + raise FileNotFoundError( + f"ONNX file '{filename}' not found in Hub repo '{repo_id}'.\n{hint}" + ) from e # External-data sidecars (used for >2 GiB models) live next to the .onnx # file with a ``.onnx_data`` suffix. The main download above just @@ -144,44 +133,43 @@ def _split_hf_onnx_path(model_id: str) -> tuple[str, str]: return "/".join(parts[:2]), "/".join(parts[2:]) -def maybe_resolve_hf_onnx_path( - model_id: str | None, +def _format_available_onnx_files( + repo_id: str, *, revision: str | None = None, - cache_dir: str | Path | None = None, token: str | bool | None = None, -) -> str | None: - """Resolve ``model_id`` to a local ONNX path if it is a Hub ONNX reference. - - Convenience wrapper that combines :func:`is_hf_onnx_path` and - :func:`resolve_hf_onnx_path`. 
Non-Hub inputs (HF model IDs, local - paths, ``None``) are returned unchanged so callers can use this as a - transparent normalization step before dispatching to existing code. +) -> str: + """Build a human-readable hint listing ``.onnx`` files in a Hub repo. + + Used to enrich ``EntryNotFoundError`` messages so users who guessed the + wrong filename can see the available options without leaving the + terminal. Best-effort: if listing fails for any reason (network, + auth, gated repo) we return a generic fallback hint instead of + masking the original error. + """ + from huggingface_hub import list_repo_files - Args: - model_id: HF model ID, local path, Hub ONNX ref, or ``None``. - revision: Optional Hub revision (forwarded when downloading). - cache_dir: Optional cache override (forwarded when downloading). - token: Optional auth token (forwarded when downloading). + try: + files = list_repo_files(repo_id, revision=revision, token=token) + except Exception as list_err: + logger.debug("Could not list files for %s: %s", repo_id, list_err) + return ( + f"Could not list available .onnx files in '{repo_id}' " + f"(see https://huggingface.co/{repo_id}/tree/main)." + ) - Returns: - Local ``.onnx`` path string when ``model_id`` was a Hub ref; the - original ``model_id`` otherwise. - """ - if not is_hf_onnx_path(model_id): - return model_id - return str( - resolve_hf_onnx_path( - model_id, # type: ignore[arg-type] # is_hf_onnx_path() rejects None - revision=revision, - cache_dir=cache_dir, - token=token, + onnx_files = sorted(f for f in files if f.lower().endswith(".onnx")) + if not onnx_files: + return ( + f"No .onnx files were found in '{repo_id}'. " + f"This repo may not host pre-exported ONNX weights; " + f"see https://huggingface.co/{repo_id}/tree/main." 
) - ) + + listing = "\n".join(f" - {repo_id}/{f}" for f in onnx_files) + return f"Available .onnx files in '{repo_id}':\n{listing}" __all__ = [ - "is_hf_onnx_path", - "maybe_resolve_hf_onnx_path", "resolve_hf_onnx_path", ] diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 5c7a6564d..17de13b0d 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -313,11 +313,9 @@ def from_pretrained( # Hub-hosted ONNX (e.g. ``onnx-community/sam3-tracker-ONNX/onnx/...``) # is downloaded once and treated as a local .onnx path thereafter. - from ..loader import maybe_resolve_hf_onnx_path + from ..utils.model_input import resolve_model_input - # ``model_id`` is already coerced to ``str`` above, so the helper's - # ``str | None`` return type is always ``str`` here. - model_id = maybe_resolve_hf_onnx_path(model_id) # type: ignore[assignment] + model_id = resolve_model_input(model_id).local_path or model_id # ===================================================================== # ONNX FAST PATH -- skip HF loading and export when given an .onnx file diff --git a/src/winml/modelkit/onnx/io.py b/src/winml/modelkit/onnx/io.py index 56013b59e..06384c29f 100644 --- a/src/winml/modelkit/onnx/io.py +++ b/src/winml/modelkit/onnx/io.py @@ -222,6 +222,7 @@ def get_io_config( io_config: dict[str, Any] = { "input_names": [], "input_shapes": [], + "input_symbolic_shapes": [], "input_types": [], "output_names": [], "output_shapes": [], @@ -243,18 +244,25 @@ def get_io_config( # Extract dtype dtype = ONNX_ELEM_TYPE_TO_NUMPY.get(tensor_type.elem_type, np.dtype("float32")) - # Extract shape (None for dynamic dims) - shape = [] + # Extract shape (None for dynamic dims) and capture symbolic + # dim_param strings in a parallel list so downstream consumers + # can resolve dynamic dims by their declared symbolic name. 
+ shape: list[Any] = [] + symbolic_shape: list[Any] = [] if tensor_type.HasField("shape"): for dim in tensor_type.shape.dim: if dim.HasField("dim_value"): shape.append(dim.dim_value) + symbolic_shape.append(dim.dim_value) else: shape.append(None) + symbolic_shape.append(dim.dim_param or None) io_config[f"{prefix}_names"].append(io.name) io_config[f"{prefix}_shapes"].append(shape) io_config[f"{prefix}_types"].append(dtype) + if prefix == "input": + io_config["input_symbolic_shapes"].append(symbolic_shape) # Enhance with value ranges from winml.io.inputs metadata for prop in model.metadata_props: diff --git a/src/winml/modelkit/sysinfo/device.py b/src/winml/modelkit/sysinfo/device.py index 6f6fdb5b9..e97366d8d 100644 --- a/src/winml/modelkit/sysinfo/device.py +++ b/src/winml/modelkit/sysinfo/device.py @@ -91,6 +91,15 @@ def get_device_ep_map() -> dict[str, list[EPName]]: return {device: list(eps) for device, eps in _DEVICE_EP_MAP.items()} +# EPs that exist in ``onnxruntime.get_available_providers()`` but are not yet +# exposed via the new ``get_ep_devices()``/AutoEP machinery. Mapped to the +# canonical device they target so they can still be selected via the legacy +# ``SessionOptions.add_provider`` code path. +_LEGACY_EP_DEVICE_FALLBACK: dict[EPName, str] = { + "VitisAIExecutionProvider": "npu", # AMD Phoenix/Strix XDNA NPU +} + + @functools.lru_cache(maxsize=1) def _get_device_ep_map_from_ort() -> dict[str, tuple[EPName, ...]]: """Return device -> EPs targeting it, derived from registered ORT EP devices. @@ -100,6 +109,11 @@ def _get_device_ep_map_from_ort() -> dict[str, tuple[EPName, ...]]: :func:`_get_available_devices`, :func:`resolve_device`, and :func:`resolve_eps`. Cached for the process lifetime since hardware/EPs do not change at runtime. + + Also merges in EPs from :data:`_LEGACY_EP_DEVICE_FALLBACK` that are + advertised by ``onnxruntime.get_available_providers()`` but not yet + registered as ``OrtEpDevice`` instances (e.g. 
``VitisAIExecutionProvider`` + in ``onnxruntime-vitisai`` 1.23.x). """ result: dict[str, list[EPName]] = {} try: @@ -113,6 +127,19 @@ def _get_device_ep_map_from_ort() -> dict[str, tuple[EPName, ...]]: # map and raises "No execution providers detected" — the user needs # the root cause visible at default verbosity to act on it. logger.warning("Failed to enumerate registered EP devices", exc_info=True) + + # Legacy-API fallback: some EPs (e.g. VitisAI) only register via + # ``get_available_providers()``, not via ``get_ep_devices()``. + try: + import onnxruntime as ort + + available = set(ort.get_available_providers()) + for ep_name, device_name in _LEGACY_EP_DEVICE_FALLBACK.items(): + if ep_name in available and ep_name not in result.get(device_name, ()): + result.setdefault(device_name, []).append(ep_name) + except Exception: + logger.debug("Legacy EP fallback enumeration failed", exc_info=True) + return {dev: tuple(eps) for dev, eps in result.items()} diff --git a/src/winml/modelkit/utils/__init__.py b/src/winml/modelkit/utils/__init__.py index 093c8cada..47d572c45 100644 --- a/src/winml/modelkit/utils/__init__.py +++ b/src/winml/modelkit/utils/__init__.py @@ -12,6 +12,12 @@ load_hf_components_from_onnx, save_local_model_configs, ) +from .model_input import ( + ModelInput, + ModelInputKind, + classify_model_input, + resolve_model_input, +) from .optimum_loader import ( OptimumONNXModel, load_optimum_model, @@ -19,12 +25,16 @@ __all__ = [ + "ModelInput", + "ModelInputKind", "OptimumONNXModel", + "classify_model_input", "inject_hub_metadata", "is_hub_model", "load_hf_components_from_onnx", "load_optimum_model", "merge_config", "normalize_ep_name", + "resolve_model_input", "save_local_model_configs", ] diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index d82f6617a..9f293f9b0 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -465,11 +465,46 @@ def load_build_config(config_path: Path) -> 
def is_onnx_file_path(model_input: str) -> bool:
    """Check if input is a path to an existing ``.onnx`` file.

    Thin wrapper kept for backwards-compatible callers; new code should
    use :func:`~winml.modelkit.utils.model_input.classify_model_input`
    directly and inspect the returned ``ModelInput.kind``.

    Args:
        model_input: Raw ``-m/--model`` value.

    Returns:
        ``True`` only when the value is classified as ``local_onnx`` *and*
        the path is present on disk.
    """
    from .model_input import classify_model_input

    classified = classify_model_input(model_input)
    if classified.kind != "local_onnx" or classified.local_path is None:
        return False
    # The classifier labels path-shaped strings (``./x.onnx``, ``/x.onnx``,
    # ``C:\x.onnx``) as ``local_onnx`` without requiring them to exist.
    # This helper promises an *existing* file, so check the disk here.
    return Path(classified.local_path).exists()
+ """ + if value is None: + return None + from .model_input import resolve_model_input + + return resolve_model_input(value).local_path or value def is_cli_provided(ctx: click.Context, param_name: str) -> bool: diff --git a/src/winml/modelkit/utils/eval_utils.py b/src/winml/modelkit/utils/eval_utils.py index 5a854cff9..fa5d58c7b 100644 --- a/src/winml/modelkit/utils/eval_utils.py +++ b/src/winml/modelkit/utils/eval_utils.py @@ -287,6 +287,22 @@ class TaskSchema: ), ) +_MASK_GENERATION_SCHEMA = TaskSchema( + columns=( + SchemaItem( + "input_column", "input image (PIL.Image)", + default="image", remap_hint="", + ), + SchemaItem( + "mask_column", "binary or instance ground-truth mask (PIL.Image)", + default="mask", remap_hint="", + ), + ), + # SAM-family composite: ``image-encoder`` runs once, ``prompt-decoder`` + # consumes the embeddings plus point / box prompts to produce masks. + roles=("image-encoder", "prompt-decoder"), +) + TASK_SCHEMAS: dict[str, TaskSchema] = { "image-classification": _IMAGE_CLASSIFICATION_SCHEMA, "text-classification": _TEXT_CLASSIFICATION_SCHEMA, @@ -304,6 +320,7 @@ class TaskSchema: "zero-shot-classification": _ZERO_SHOT_CLASSIFICATION_SCHEMA, "zero-shot-image-classification": _ZERO_SHOT_IMAGE_CLASSIFICATION_SCHEMA, "depth-estimation": _DEPTH_ESTIMATION_SCHEMA, + "mask-generation": _MASK_GENERATION_SCHEMA, } diff --git a/src/winml/modelkit/utils/hub_utils.py b/src/winml/modelkit/utils/hub_utils.py index 1de403c8d..73340ad27 100644 --- a/src/winml/modelkit/utils/hub_utils.py +++ b/src/winml/modelkit/utils/hub_utils.py @@ -17,6 +17,43 @@ logger = logging.getLogger(__name__) +# Local-path indicators used to short-circuit Hub detection. Centralized +# here so every callsite that classifies a model input string applies the +# same rejection rules (existing path, ./ ../ /. ~/ prefixes, Windows +# drive letter). 
Without this shared helper, each detector tends to +# reimplement only the easiest check (Path.exists) and accept inputs +# like ``./model.onnx`` as Hub references. +_LOCAL_PATH_PREFIXES: tuple[str, ...] = ("./", "../", "/", "~/") +_WIN_DRIVE_RE = re.compile(r"^[A-Za-z]:[\\/]") + + +def _is_local_path(value: str | None) -> bool: + r"""Return ``True`` if *value* looks like a local filesystem path. + + Heuristic check used to reject local paths before treating an input as + a Hub identifier. Detects: + + * existing filesystem entries (``Path.exists()``); + * Unix-style relative/absolute/home prefixes (``./``, ``../``, ``/``, ``~/``); + * Windows drive-letter absolute paths (``C:\``, ``D:/``). + + ``None`` and empty strings are not local paths. + """ + if not value: + return False + try: + if Path(value).exists(): + return True + except (OSError, ValueError): + # Path may be syntactically invalid on the current platform + # (e.g. control characters); treat as "not a local path" so the + # caller can apply Hub-format heuristics instead. + pass + if value.startswith(_LOCAL_PATH_PREFIXES): + return True + return bool(_WIN_DRIVE_RE.match(value)) + + def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: """Comprehensive Hub model detection with metadata extraction. @@ -26,16 +63,8 @@ def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: Returns: Tuple of (is_hub_model, metadata_dict) """ - # Quick rejection for obvious local paths - if Path(model_name_or_path).exists(): - return False, {"type": "local", "path": model_name_or_path} - - # Check for local path indicators - if any(model_name_or_path.startswith(prefix) for prefix in ["./", "../", "/", "~/"]): - return False, {"type": "local", "path": model_name_or_path} - - # Check for Windows absolute paths - if re.match(r"^[A-Za-z]:[\\/]", model_name_or_path): + # Local-path rejection (existing path, ./ ../ /. 
~/ prefixes, Windows drive) + if _is_local_path(model_name_or_path): return False, {"type": "local", "path": model_name_or_path} # Parse potential Hub model format diff --git a/src/winml/modelkit/utils/model_input.py b/src/winml/modelkit/utils/model_input.py new file mode 100644 index 000000000..4d99f5c62 --- /dev/null +++ b/src/winml/modelkit/utils/model_input.py @@ -0,0 +1,156 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Single classifier + resolver for ``-m/--model`` input values. + +ModelKit accepts four shapes for a model reference: + +* a HuggingFace model ID (``org/name``); +* a local ``.onnx`` file path; +* a Hub-hosted ONNX artifact (``org/repo/path/file.onnx``); and +* a local build-output directory (containing a ModelKit manifest + cached ONNX). + +Historically the codebase had separate detectors for each form +(``is_hub_model``, ``is_hf_onnx_path``, ``is_onnx_file_path``, plus +scattered ``path.suffix == ".onnx"`` checks). This module replaces them +with a single classifier (:func:`classify_model_input`) and a single +resolver (:func:`resolve_model_input`) so adding a fourth input form +later means editing one function, not seven. + +The classifier is pure (no I/O beyond the shared ``_is_local_path`` +existence check). The resolver downloads Hub-hosted ONNX artifacts via +:func:`~winml.modelkit.loader.onnx_hub.resolve_hf_onnx_path` and +populates ``local_path`` on the returned :class:`ModelInput`. 
def classify_model_input(value: str) -> ModelInput:
    """Sort a ``-m/--model`` value into one of the known input kinds.

    No network access happens here. Checks run in this order and the
    first match decides the result:

    1. A falsy value is ``invalid``.
    2. A value that ``_is_local_path`` accepts is ``local_onnx`` when its
       suffix is ``.onnx`` (case-insensitive), ``build_dir`` when it is an
       existing directory, and ``invalid`` otherwise. ``local_path`` is
       set in all three cases.
    3. A value ending in ``.onnx`` (case-insensitive) is ``hub_onnx`` when
       it has at least three non-empty ``/``-separated parts; the first
       two parts become ``hf_id``. With fewer parts it is ``invalid``.
    4. Anything else is ``hf_id`` with the value itself as the repo id.

    Args:
        value: Raw model reference as typed by the user.

    Returns:
        A :class:`ModelInput` whose ``raw`` is the input unchanged (or
        ``""`` when the input was falsy).
    """
    if not value:
        return ModelInput(kind="invalid", raw=value or "")

    if _is_local_path(value):
        local = Path(value)
        kind: ModelInputKind
        if local.suffix.lower() == ".onnx":
            kind = "local_onnx"
        elif local.is_dir():
            kind = "build_dir"
        else:
            # Path-shaped but neither an .onnx file nor a directory (for
            # example a mistyped name); the caller reports the error.
            kind = "invalid"
        return ModelInput(kind=kind, raw=value, local_path=str(local))

    if not value.lower().endswith(".onnx"):
        return ModelInput(kind="hf_id", raw=value, hf_id=value)

    # ``org/repo/<file path>.onnx``: the first two parts name the repo.
    segments = [seg for seg in value.split("/") if seg]
    if len(segments) < 3:
        return ModelInput(kind="invalid", raw=value)
    return ModelInput(kind="hub_onnx", raw=value, hf_id="/".join(segments[:2]))
+ """ + mi = classify_model_input(value) + if mi.kind != "hub_onnx": + return mi + + # Lazy import: keeps huggingface_hub off the CLI startup path for + # commands that never touch the Hub. Tests patch the downloader on + # the loader package so the lookup picks up the mock at call time. + from ..loader.onnx_hub import resolve_hf_onnx_path + + local = resolve_hf_onnx_path( + value, revision=revision, cache_dir=cache_dir, token=token, + ) + return replace(mi, local_path=str(local)) + + +__all__ = [ + "ModelInput", + "ModelInputKind", + "classify_model_input", + "resolve_model_input", +] diff --git a/src/winml/modelkit/winml.py b/src/winml/modelkit/winml.py index d43d94b07..81a8f39ba 100644 --- a/src/winml/modelkit/winml.py +++ b/src/winml/modelkit/winml.py @@ -83,4 +83,16 @@ def add_ep_for_device( [ep_device], {} if ep_options is None else ep_options ) return True + + # Legacy-API fallback: some EPs (notably VitisAIExecutionProvider in + # ``onnxruntime-vitisai`` 1.23.x) are present in + # ``get_available_providers()`` but not in ``get_ep_devices()``. Use the + # legacy ``SessionOptions.add_provider`` path for those. ``add_provider`` + # takes ``dict[str, str]`` so coerce values. + if ep_name in ort.get_available_providers(): + str_options = {k: str(v) for k, v in (ep_options or {}).items()} + logger.info("Adding %s via legacy add_provider (no OrtEpDevice)", ep_name) + session_options.add_provider(ep_name, str_options) + return True + return False diff --git a/tests/integration/test_sam3_e2e.py b/tests/integration/test_sam3_e2e.py index 4be922588..b558e5b15 100644 --- a/tests/integration/test_sam3_e2e.py +++ b/tests/integration/test_sam3_e2e.py @@ -12,8 +12,9 @@ Pipeline verified by this test: -1. ``is_hf_onnx_path`` recognizes the Hub-style ONNX reference and - ``resolve_hf_onnx_path`` downloads the file via ``huggingface_hub``. +1. 
``classify_model_input`` recognizes the Hub-style ONNX reference as + ``kind == "hub_onnx"`` and ``resolve_hf_onnx_path`` downloads the file + via ``huggingface_hub``. 2. ``generate_onnx_build_config`` produces a valid build config for the already-quantized ONNX (skips optimize and quantize stages). 3. ``build_onnx_model`` produces a final ``model.onnx`` artifact that loads @@ -63,13 +64,15 @@ def sam3_onnx_path(self) -> Path: """ from huggingface_hub.utils import HfHubHTTPError - from winml.modelkit.loader import is_hf_onnx_path, resolve_hf_onnx_path + from winml.modelkit.loader import resolve_hf_onnx_path + from winml.modelkit.utils.model_input import classify_model_input - assert is_hf_onnx_path(SAM3_ONNX_REF) + assert classify_model_input(SAM3_ONNX_REF).kind == "hub_onnx" try: return resolve_hf_onnx_path(SAM3_ONNX_REF) except (HfHubHTTPError, OSError) as e: pytest.skip(f"Network unavailable to download {SAM3_ONNX_REF}: {e}") + raise # unreachable (pytest.skip raises Skipped); satisfies static analyzers def test_resolves_to_local_onnx_file(self, sam3_onnx_path: Path) -> None: """The Hub reference resolves to an on-disk .onnx file.""" @@ -198,13 +201,15 @@ def encoder_onnx_path(self) -> Path: """ from huggingface_hub.utils import HfHubHTTPError - from winml.modelkit.loader import is_hf_onnx_path, resolve_hf_onnx_path + from winml.modelkit.loader import resolve_hf_onnx_path + from winml.modelkit.utils.model_input import classify_model_input - assert is_hf_onnx_path(SAM3_ENCODER_ONNX_REF) + assert classify_model_input(SAM3_ENCODER_ONNX_REF).kind == "hub_onnx" try: return resolve_hf_onnx_path(SAM3_ENCODER_ONNX_REF) except (HfHubHTTPError, OSError) as e: pytest.skip(f"Network unavailable to download {SAM3_ENCODER_ONNX_REF}: {e}") + raise # unreachable (pytest.skip raises Skipped); satisfies static analyzers def test_encoder_is_detected_as_quantized(self, encoder_onnx_path: Path) -> None: """The QOperator-quantized encoder is recognized by is_quantized_onnx.""" diff 
--git a/tests/unit/build/test_hf.py b/tests/unit/build/test_hf.py index 969366219..067e9c1e2 100644 --- a/tests/unit/build/test_hf.py +++ b/tests/unit/build/test_hf.py @@ -150,7 +150,7 @@ def mock_pipeline(): return_value=_default_analyze_result(), ) as m_analyze, patch( - "winml.modelkit.build.hf.is_quantized_onnx", + "winml.modelkit.build.common.is_quantized_onnx", return_value=False, ) as m_has_qdq, patch( diff --git a/tests/unit/build/test_onnx.py b/tests/unit/build/test_onnx.py index 2b10fe783..e333aa368 100644 --- a/tests/unit/build/test_onnx.py +++ b/tests/unit/build/test_onnx.py @@ -144,7 +144,7 @@ def mock_onnx_pipeline(): side_effect=_create_file_side_effect("output_path", compile_result), ) as m_compile, patch( - "winml.modelkit.build.onnx.is_quantized_onnx", + "winml.modelkit.build.common.is_quantized_onnx", return_value=False, ) as m_has_qdq, patch( diff --git a/tests/unit/commands/test_eval.py b/tests/unit/commands/test_eval.py index 459d93280..6ec330c54 100644 --- a/tests/unit/commands/test_eval.py +++ b/tests/unit/commands/test_eval.py @@ -121,11 +121,13 @@ def test_hub_onnx_ref_is_resolved(self, tmp_path): local.write_bytes(b"") hub_ref = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" - # eval.py does ``from ..loader import resolve_hf_onnx_path``, which - # binds the helper name lazily INSIDE _resolve_model_path. Patch on - # the loader package re-export so the lazy import sees the mock. + # eval.py routes all Hub-ONNX resolution through + # ``cli_utils.normalize_model_arg`` -> ``resolve_model_input`` + # (the single CLI-layer entry point). Patch the underlying + # downloader so the lazy ``from ..loader.onnx_hub import + # resolve_hf_onnx_path`` picks up the mock at call time. 
with patch( - "winml.modelkit.loader.resolve_hf_onnx_path", + "winml.modelkit.loader.onnx_hub.resolve_hf_onnx_path", return_value=local, ) as mock_resolve: path, mid = _resolve_model_path( @@ -219,6 +221,84 @@ def test_whitespace_stripped(self, onnx_vision): ) assert path == {"image-encoder": str(onnx_vision)} + def test_composite_hub_refs_resolved(self, tmp_path): + """role=org/repo/path/file.onnx resolves via Hub-ONNX loader. + + Regression test for the multi-role variant of the single-model + Hub-ref resolution -- needed by SAM 3's + ``-m image-encoder=...ONNX/onnx/vision_encoder_int8.onnx + -m prompt-decoder=...ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx`` + invocation in ``winml eval``. + """ + enc_local = tmp_path / "vision_encoder_int8.onnx" + enc_local.write_bytes(b"") + dec_local = tmp_path / "prompt_encoder_mask_decoder_int8.onnx" + dec_local.write_bytes(b"") + enc_ref = ( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + ) + dec_ref = ( + "onnx-community/sam3-tracker-ONNX/onnx/" + "prompt_encoder_mask_decoder_int8.onnx" + ) + + # Map each Hub ref to its (different) local cache location. 
+ def fake_resolve(ref, **kwargs): + return { + enc_ref: enc_local, + dec_ref: dec_local, + }[str(ref)] + + with patch( + "winml.modelkit.loader.onnx_hub.resolve_hf_onnx_path", + side_effect=fake_resolve, + ) as mock_resolve: + path, mid = _resolve_model_path( + model=( + f"image-encoder={enc_ref}", + f"prompt-decoder={dec_ref}", + ), + model_id="facebook/sam3-tracker", + ) + assert mock_resolve.call_count == 2 + assert path == { + "image-encoder": str(enc_local), + "prompt-decoder": str(dec_local), + } + assert mid == "facebook/sam3-tracker" + + def test_composite_mixed_hub_and_local(self, onnx_vision, tmp_path): + """One role is a Hub ref, the other is a local path -- both work.""" + dec_local = tmp_path / "decoder.onnx" + dec_local.write_bytes(b"") + enc_ref = ( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + ) + + # ``resolve_hf_onnx_path`` is the underlying downloader; the + # unified classifier+resolver only calls it for hub_onnx inputs. + # Local paths short-circuit in the classifier and never reach + # this mock. + with patch( + "winml.modelkit.loader.onnx_hub.resolve_hf_onnx_path", + return_value=dec_local, + ) as mock_resolve: + path, mid = _resolve_model_path( + model=( + f"image-encoder={enc_ref}", + f"prompt-decoder={onnx_vision}", + ), + model_id="facebook/sam3-tracker", + ) + # Only the Hub ref triggers a download; local path passes + # through the classifier without touching the resolver. 
+ mock_resolve.assert_called_once() + assert path == { + "image-encoder": str(dec_local), + "prompt-decoder": str(onnx_vision), + } + assert mid == "facebook/sam3-tracker" + # --------------------------------------------------------------------------- # Mixing forms diff --git a/tests/unit/commands/test_hub_onnx_ref.py b/tests/unit/commands/test_hub_onnx_ref.py index c600c1aea..4c8104fb5 100644 --- a/tests/unit/commands/test_hub_onnx_ref.py +++ b/tests/unit/commands/test_hub_onnx_ref.py @@ -27,9 +27,28 @@ HUB_ONNX_REF = "onnx-community/sam3-tracker-ONNX/onnx/prompt_encoder_mask_decoder_int8.onnx" +_DEVICE_TO_EPS = { + "npu": ["QNNExecutionProvider"], + "gpu": ["DmlExecutionProvider"], + "cpu": ["CPUExecutionProvider"], +} + + +def _fake_resolve_check_device_ep(*, device: str = "auto", ep: str | None = None): + """Side effect for resolve_check_device_ep that honours the requested device.""" + resolved = device.lower() if device != "auto" else "npu" + eps = _DEVICE_TO_EPS.get(resolved, ["CPUExecutionProvider"]) + return resolved, ["npu", "gpu", "cpu"], eps + + @pytest.fixture(autouse=True) def mock_resolve_device(): - """Mock hardware detection so config/build tests run on any host.""" + """Mock hardware detection so config/build tests run on any host. + + Build/config CLIs auto-resolve device + EP at the top of the command, + so ``resolve_device``, ``resolve_eps``, and ``resolve_check_device_ep`` + must all be patched (mirrors ``tests/unit/commands/test_build.py``). 
+ """ mock_registry = MagicMock() mock_registry.is_ep_available.return_value = False @@ -38,6 +57,14 @@ def mock_resolve_device(): "winml.modelkit.sysinfo.resolve_device", return_value=("npu", ["npu", "gpu", "cpu"]), ), + patch( + "winml.modelkit.sysinfo.resolve_eps", + side_effect=lambda device: list(_DEVICE_TO_EPS.get(device, [])), + ), + patch( + "winml.modelkit.sysinfo.resolve_check_device_ep", + side_effect=_fake_resolve_check_device_ep, + ), patch( "winml.modelkit.session.ep_registry.WinMLEPRegistry.get_instance", return_value=mock_registry, diff --git a/tests/unit/commands/test_perf_cli.py b/tests/unit/commands/test_perf_cli.py index daff125ea..c1ac7031a 100644 --- a/tests/unit/commands/test_perf_cli.py +++ b/tests/unit/commands/test_perf_cli.py @@ -387,15 +387,26 @@ def test_cli_hub_onnx_ref_is_resolved(self, runner: CliRunner, tmp_path: Path) - local.write_bytes(b"fake onnx") hub_ref = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + mock_result = MagicMock() + mock_result.to_dict = MagicMock(return_value={}) + + # Stub PerfBenchmark so the test stays fast and EP-independent; + # capture the BenchmarkConfig it was constructed with so we can + # assert ``model_id`` is the resolved local path, not the Hub ref. 
+ captured_configs: list = [] + original_init = PerfBenchmark.__init__ + + def _capturing_init(self_, config, *args, **kwargs): + captured_configs.append(config) + original_init(self_, config, *args, **kwargs) + with ( patch( - "winml.modelkit.loader.maybe_resolve_hf_onnx_path", - return_value=str(local), + "winml.modelkit.loader.onnx_hub.resolve_hf_onnx_path", + return_value=local, ) as mock_resolve, - patch( - "winml.modelkit.commands.perf._run_onnx_benchmark", - return_value=MagicMock(), - ) as mock_run, + patch.object(PerfBenchmark, "__init__", _capturing_init), + patch.object(PerfBenchmark, "run", return_value=mock_result) as mock_run, patch("winml.modelkit.commands.perf.display_console_report"), patch("winml.modelkit.commands.perf.write_json_report"), ): @@ -406,12 +417,15 @@ def test_cli_hub_onnx_ref_is_resolved(self, runner: CliRunner, tmp_path: Path) - ) assert result.exit_code == 0, result.output - mock_resolve.assert_called_once_with(hub_ref) - # After resolution, the Hub ref reaches _run_onnx_benchmark as - # the LOCAL path -- not the original Hub ref string. + # ``resolve_model_input`` forwards revision/cache_dir/token kwargs + # to the downloader; only the positional Hub ref is meaningful here. + mock_resolve.assert_called_once() + assert mock_resolve.call_args.args == (hub_ref,) + # After resolution, the PerfBenchmark sees the LOCAL path on its + # config.model_id -- not the original Hub ref string. 
mock_run.assert_called_once() - called_path = mock_run.call_args.args[0] - assert called_path == local + assert len(captured_configs) == 1 + assert Path(captured_configs[0].model_id) == local def test_onnx_load_model_passes_ep(self, tmp_path: Path) -> None: """EP argument should be forwarded to from_onnx.""" diff --git a/tests/unit/config/test_build.py b/tests/unit/config/test_build.py index 6e298772f..17bed4ede 100644 --- a/tests/unit/config/test_build.py +++ b/tests/unit/config/test_build.py @@ -985,6 +985,9 @@ def test_config_cli_with_override_file( return_value=mock_export_config, ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), + # config now inspects the HF config to route seq2seq composites (#850); + # stub that load (bert -> no composite) so the placeholder -m isn't fetched. + patch("transformers.AutoConfig.from_pretrained", return_value=BertConfig()), ): runner = CliRunner() result = runner.invoke( @@ -1018,6 +1021,9 @@ def test_config_cli_without_override( return_value=mock_export_config, ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), + # config now inspects the HF config to route seq2seq composites (#850); + # stub that load (bert -> no composite) so the placeholder -m isn't fetched. + patch("transformers.AutoConfig.from_pretrained", return_value=BertConfig()), ): runner = CliRunner() result = runner.invoke( @@ -1454,6 +1460,9 @@ def test_empty_override_json_is_noop( return_value=mock_export_config, ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), + # config now inspects the HF config to route seq2seq composites (#850); + # stub that load (bert -> no composite) so the placeholder -m isn't fetched. 
+ patch("transformers.AutoConfig.from_pretrained", return_value=BertConfig()), ): runner = CliRunner() result = runner.invoke( @@ -1649,6 +1658,9 @@ def test_shape_config_cli_with_config_combined( return_value=mock_export_config, ) as mock_gen_export, patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), + # config now inspects the HF config to route seq2seq composites (#850); + # stub that load (bert -> no composite) so the placeholder -m isn't fetched. + patch("transformers.AutoConfig.from_pretrained", return_value=BertConfig()), ): runner = CliRunner() result = runner.invoke( @@ -2153,6 +2165,11 @@ def _mock_deps( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("npu", ["npu", "gpu", "cpu"], ["QNNExecutionProvider"]), ), + # config now inspects the HF config to route seq2seq composites (#850); + # stub that load (bert -> no composite) so the placeholder -m isn't fetched. + "autoconfig": patch( + "transformers.AutoConfig.from_pretrained", return_value=BertConfig() + ), } def _invoke(self, tmp_path, extra_args: list[str] | None = None): @@ -2167,6 +2184,7 @@ def _invoke(self, tmp_path, extra_args: list[str] | None = None): self._patches["export"], self._patches["registry"], self._patches["device"], + self._patches["autoconfig"], ): runner = CliRunner() result = runner.invoke(config_command, args) diff --git a/tests/unit/core/test_onnx_utils.py b/tests/unit/core/test_onnx_utils.py index 353776031..c76035b1f 100644 --- a/tests/unit/core/test_onnx_utils.py +++ b/tests/unit/core/test_onnx_utils.py @@ -207,6 +207,7 @@ def test_returns_correct_structure(self) -> None: expected_keys = { "input_names", "input_shapes", + "input_symbolic_shapes", "input_types", "output_names", "output_shapes", diff --git a/tests/unit/datasets/test_mask_generation.py b/tests/unit/datasets/test_mask_generation.py new file mode 100644 index 000000000..70a36c001 --- /dev/null +++ b/tests/unit/datasets/test_mask_generation.py @@ -0,0 +1,298 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for MaskGenerationDataset and its geometry helpers. + +These tests mock ``datasets.load_dataset`` so they don't hit the network -- +each test builds a small in-memory fake dataset that mimics the HF +``datasets.Dataset`` interface enough for ``MaskGenerationDataset`` to +operate on. +""" + +from __future__ import annotations + +from random import Random +from typing import Any +from unittest.mock import patch + +import numpy as np +import pytest +from datasets.features import Features, Value +from datasets.features import Image as HFImage +from PIL import Image as PILImage + +from winml.modelkit.datasets.mask_generation import ( + MaskGenerationDataset, + _bbox_from_mask, + _foreground_point, +) + + +# --------------------------------------------------------------------------- +# Pure-function helpers +# --------------------------------------------------------------------------- + + +class TestBboxFromMask: + def test_centered_square(self): + mask = np.zeros((10, 10), dtype=bool) + mask[3:7, 4:9] = True + assert _bbox_from_mask(mask) == (4, 3, 8, 6) + + def test_single_pixel(self): + mask = np.zeros((5, 5), dtype=bool) + mask[2, 3] = True + assert _bbox_from_mask(mask) == (3, 2, 3, 2) + + def test_empty_mask_raises(self): + mask = np.zeros((4, 4), dtype=bool) + with pytest.raises(ValueError, match="empty mask"): + _bbox_from_mask(mask) + + def test_accepts_non_bool_dtype(self): + """Helper should coerce uint8/int masks transparently.""" + mask = np.zeros((6, 6), dtype=np.uint8) + mask[1:4, 2:5] = 1 + assert _bbox_from_mask(mask) == (2, 1, 4, 3) + + +class TestForegroundPoint: + def test_centroid_inside_mask(self): + mask = np.zeros((20, 20), dtype=bool) + mask[5:15, 5:15] = True # convex square; centroid at 
(9, 9) + rng = Random(0) + x, y = _foreground_point(mask, rng) + # Centroid of pixels 5..14 = (9.5 floored to 9) + assert mask[y, x] # point IS on mask + assert (x, y) == (9, 9) + + def test_concave_mask_falls_back_to_random(self): + """Two disjoint blobs -- centroid lands in background.""" + mask = np.zeros((20, 20), dtype=bool) + mask[2:5, 2:5] = True + mask[15:18, 15:18] = True + # Mean of all foreground ys/xs ~= (9.5, 9.5) which is background. + rng = Random(42) + x, y = _foreground_point(mask, rng) + assert mask[y, x] # picked an actual foreground pixel + + def test_empty_mask_raises(self): + mask = np.zeros((4, 4), dtype=bool) + with pytest.raises(ValueError, match="empty mask"): + _foreground_point(mask, Random(0)) + + +# --------------------------------------------------------------------------- +# Fake HuggingFace dataset +# --------------------------------------------------------------------------- + + +def _make_pil(image_arr: np.ndarray) -> PILImage.Image: + """Wrap a numpy array as a PIL image (uint8 RGB or single-channel).""" + if image_arr.ndim == 2: + return PILImage.fromarray(image_arr.astype(np.uint8), mode="L") + return PILImage.fromarray(image_arr.astype(np.uint8), mode="RGB") + + +class _FakeHFDataset: + """Minimal stand-in for ``datasets.Dataset`` -- enough for + ``MaskGenerationDataset`` to call ``len()``, ``[]``, ``.select()``, + ``.shuffle()``, ``.features``. 
+ """ + + def __init__(self, rows: list[dict[str, Any]], features: Features): + self._rows = rows + self.features = features + + def __len__(self) -> int: + return len(self._rows) + + def __getitem__(self, idx): + return self._rows[idx] + + def select(self, indices): + return _FakeHFDataset([self._rows[i] for i in indices], self.features) + + def shuffle(self, seed: int = 0): + rng = Random(seed) + rows = list(self._rows) + rng.shuffle(rows) + return _FakeHFDataset(rows, self.features) + + +def _build_fake_ds(n_rows: int = 5, with_text: bool = False) -> _FakeHFDataset: + """Build a fake dataset of (image, mask) pairs. + + Image: 32x40 RGB (W=40, H=32 in PIL terms). + Mask: same size, foreground = a centered 10x12 rectangle. + """ + feats = { + "image": HFImage(), + "mask": HFImage(), + } + if with_text: + feats["text"] = Value("string") + features = Features(feats) + + rows: list[dict[str, Any]] = [] + for i in range(n_rows): + img = np.full((32, 40, 3), 200, dtype=np.uint8) + mask = np.zeros((32, 40), dtype=np.uint8) + # Cover ~10*12/(32*40) = 9.4% (passes default 0.5%-95% filter). 
+ mask[10:22, 14:24] = 1 + row = {"image": _make_pil(img), "mask": _make_pil(mask)} + if with_text: + row["text"] = f"object_{i}" + rows.append(row) + return _FakeHFDataset(rows, features) + + +# --------------------------------------------------------------------------- +# MaskGenerationDataset +# --------------------------------------------------------------------------- + + +def _make_ds(monkeypatch, fake_ds, **kwargs) -> MaskGenerationDataset: + """Construct a MaskGenerationDataset with load_dataset patched to fake_ds.""" + with patch( + "winml.modelkit.datasets.mask_generation.load_dataset", + return_value=fake_ds, + ): + return MaskGenerationDataset( + model_name="dummy/model", + dataset_name="dummy/dataset", + **kwargs, + ) + + +class TestBboxPromptMode: + def test_basic_bbox_sample(self, monkeypatch): + fake = _build_fake_ds(n_rows=3) + ds = _make_ds(monkeypatch, fake) + assert ds.prompt_mode == "bbox" + sample = ds[0] + assert sample is not None + assert isinstance(sample["image"], PILImage.Image) + assert sample["image"].size == (40, 32) + assert sample["gt_mask"].shape == (32, 40) + assert sample["gt_mask"].dtype == np.bool_ + assert sample["prompt"] == {"bbox": [14, 10, 23, 21]} + assert sample["sample_id"] == "sample_0000" + + def test_image_col_property(self, monkeypatch): + fake = _build_fake_ds(n_rows=1) + ds = _make_ds(monkeypatch, fake) + assert ds.image_col == "image" + assert ds.mask_col == "mask" + assert ds.label_col == ds.mask_col + + +class TestPointPromptMode: + def test_point_inside_foreground(self, monkeypatch): + fake = _build_fake_ds(n_rows=1) + ds = _make_ds(monkeypatch, fake, prompt_mode="point") + sample = ds[0] + assert sample is not None + assert "point" in sample["prompt"] + x, y = sample["prompt"]["point"] + assert sample["gt_mask"][y, x] + assert sample["prompt"]["label"] == 1 + + +class TestTextPromptMode: + def test_text_pulled_from_configured_col(self, monkeypatch): + fake = _build_fake_ds(n_rows=2, with_text=True) + ds = 
_make_ds(monkeypatch, fake, prompt_mode="text", text_col="text") + sample = ds[1] + assert sample is not None + assert sample["prompt"] == {"text": "object_1"} + + def test_text_mode_requires_text_col(self, monkeypatch): + fake = _build_fake_ds(n_rows=1, with_text=False) + with pytest.raises(ValueError, match="requires text_col"): + _make_ds(monkeypatch, fake, prompt_mode="text") + + def test_text_col_must_exist(self, monkeypatch): + fake = _build_fake_ds(n_rows=1, with_text=False) + with pytest.raises(ValueError, match="not found"): + _make_ds(monkeypatch, fake, prompt_mode="text", text_col="missing") + + +class TestCoverageFilter: + def test_below_min_returns_none(self, monkeypatch): + # Build a dataset with a 1-pixel mask -> coverage ~0.08% + feats = Features({"image": HFImage(), "mask": HFImage()}) + img = np.full((32, 40, 3), 0, dtype=np.uint8) + mask = np.zeros((32, 40), dtype=np.uint8) + mask[0, 0] = 1 + fake = _FakeHFDataset( + [{"image": _make_pil(img), "mask": _make_pil(mask)}], + feats, + ) + ds = _make_ds(monkeypatch, fake, min_mask_coverage=0.01) + assert ds[0] is None # filtered out + + def test_above_max_returns_none(self, monkeypatch): + feats = Features({"image": HFImage(), "mask": HFImage()}) + img = np.full((10, 10, 3), 0, dtype=np.uint8) + mask = np.ones((10, 10), dtype=np.uint8) # 100% coverage + fake = _FakeHFDataset( + [{"image": _make_pil(img), "mask": _make_pil(mask)}], + feats, + ) + ds = _make_ds(monkeypatch, fake, max_mask_coverage=0.95) + assert ds[0] is None + + def test_iter_valid_skips_filtered(self, monkeypatch): + feats = Features({"image": HFImage(), "mask": HFImage()}) + rows = [] + # 3 good, 2 too-small interleaved + for cov in [1.0, 0.0, 1.0, 0.0, 1.0]: + img = np.zeros((32, 40, 3), dtype=np.uint8) + mask = np.zeros((32, 40), dtype=np.uint8) + if cov > 0: + mask[5:15, 5:15] = 1 + rows.append({"image": _make_pil(img), "mask": _make_pil(mask)}) + fake = _FakeHFDataset(rows, feats) + ds = _make_ds(monkeypatch, fake) + valid = 
list(ds.iter_valid()) + assert len(valid) == 3 + assert [s["sample_id"] for s in valid] == [ + "sample_0000", "sample_0002", "sample_0004", + ] + + def test_iter_valid_respects_cap(self, monkeypatch): + fake = _build_fake_ds(n_rows=10) + ds = _make_ds(monkeypatch, fake) + assert len(list(ds.iter_valid(max_samples=4))) == 4 + + +class TestValidation: + def test_bad_prompt_mode(self, monkeypatch): + fake = _build_fake_ds(n_rows=1) + with pytest.raises(ValueError, match="prompt_mode"): + _make_ds(monkeypatch, fake, prompt_mode="silly") # type: ignore[arg-type] + + def test_bad_coverage_bounds(self, monkeypatch): + fake = _build_fake_ds(n_rows=1) + with pytest.raises(ValueError, match="min_mask_coverage"): + _make_ds( + monkeypatch, fake, + min_mask_coverage=0.9, max_mask_coverage=0.1, + ) + + def test_missing_columns_raises(self, monkeypatch): + """Dataset with no Image features should error out clearly.""" + feats = Features({"label": Value("int32")}) + fake = _FakeHFDataset([{"label": 1}], feats) + with pytest.raises(ValueError, match="auto-detect"): + _make_ds(monkeypatch, fake) + + +class TestRegistry: + def test_mask_generation_in_task_dataset_mapping(self): + """Mask-generation task is wired into the dataset registry.""" + from winml.modelkit.datasets import TASK_DATASET_MAPPING + + assert TASK_DATASET_MAPPING["mask-generation"] is MaskGenerationDataset diff --git a/tests/unit/eval/test_binary_segmentation_metric.py b/tests/unit/eval/test_binary_segmentation_metric.py new file mode 100644 index 000000000..0b53f700f --- /dev/null +++ b/tests/unit/eval/test_binary_segmentation_metric.py @@ -0,0 +1,97 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +"""Tests for BinarySegmentationMetric.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from winml.modelkit.eval.metrics.binary_segmentation import BinarySegmentationMetric + + +class TestPerfectAndDisjoint: + def test_perfect_overlap(self) -> None: + m = BinarySegmentationMetric() + gt = np.zeros((10, 10), dtype=bool) + gt[2:6, 2:6] = True + m.update(gt.copy(), gt) + out = m.compute() + assert out["mIoU"] == 1.0 + assert out["dice"] == 1.0 + assert out["num_samples"] == 1 + + def test_disjoint(self) -> None: + m = BinarySegmentationMetric() + gt = np.zeros((10, 10), dtype=bool) + gt[0:4, 0:4] = True + pred = np.zeros((10, 10), dtype=bool) + pred[5:9, 5:9] = True + m.update(pred, gt) + out = m.compute() + assert out["mIoU"] == 0.0 + assert out["dice"] == 0.0 + + +class TestKnownValue: + def test_half_overlap(self) -> None: + # GT is 4x4 = 16 px, pred covers half of it + 8 px elsewhere -> + # intersection=8, union=24, IoU=1/3; |pred|+|gt|=32, Dice=16/32=0.5. 
+ m = BinarySegmentationMetric() + gt = np.zeros((10, 10), dtype=bool) + gt[0:4, 0:4] = True + pred = np.zeros((10, 10), dtype=bool) + pred[0:4, 0:2] = True # 8 px overlap with gt + pred[6:8, 0:4] = True # 8 px outside gt + m.update(pred, gt) + out = m.compute() + assert out["mIoU"] == pytest.approx(1 / 3, abs=1e-6) + assert out["dice"] == pytest.approx(0.5, abs=1e-6) + + +class TestAggregation: + def test_mean_across_samples(self) -> None: + m = BinarySegmentationMetric() + gt = np.zeros((4, 4), dtype=bool) + gt[0:2, 0:2] = True + m.update(gt.copy(), gt) # IoU=1.0 + m.update(np.zeros_like(gt), gt) # IoU=0.0 + out = m.compute() + assert out["mIoU"] == 0.5 + assert out["num_samples"] == 2 + + +class TestEdgeCases: + def test_empty_gt_skipped(self) -> None: + m = BinarySegmentationMetric() + empty = np.zeros((5, 5), dtype=bool) + pred = np.ones((5, 5), dtype=bool) + m.update(pred, empty) + out = m.compute() + assert out["num_samples"] == 0 + assert out["num_skipped"] == 1 + assert out["mIoU"] == 0.0 # documented fallback when nothing scored + + def test_shape_mismatch_raises(self) -> None: + m = BinarySegmentationMetric() + with pytest.raises(ValueError, match="shape"): + m.update(np.zeros((4, 4), dtype=bool), np.zeros((5, 5), dtype=bool)) + + def test_nonzero_treated_as_foreground(self) -> None: + # Pass raw uint8 masks (255) -> should be treated as fg + m = BinarySegmentationMetric() + gt = np.zeros((4, 4), dtype=np.uint8) + gt[0:2, 0:2] = 255 + m.update(gt.copy(), gt) + assert m.compute()["mIoU"] == 1.0 + + +class TestRegistryLazyAttr: + def test_registered_in_metrics_package(self) -> None: + from winml.modelkit.eval import metrics + + cls = metrics.BinarySegmentationMetric + assert cls is BinarySegmentationMetric diff --git a/tests/unit/eval/test_mask_generation_evaluator.py b/tests/unit/eval/test_mask_generation_evaluator.py new file mode 100644 index 000000000..f0693dcee --- /dev/null +++ b/tests/unit/eval/test_mask_generation_evaluator.py @@ -0,0 +1,369 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Tests for the mask-generation evaluator (pure helpers + validation). + +End-to-end tests that exercise a real ORT session require the SAM 3 ONNX +files cached and are run only in the integration suite -- see +``scripts/sam3_smoke_eval.py`` for the script form. +""" + +from __future__ import annotations + +import numpy as np +import pytest +from PIL import Image + +from winml.modelkit.eval.config import DatasetConfig, WinMLEvaluationConfig +from winml.modelkit.eval.mask_generation_evaluator import ( + _TARGET_SIZE, + WinMLMaskGenerationEvaluator, + _build_decoder_inputs, + _build_providers, + _postprocess_mask, + _preprocess_image, +) + + +# ---------------------------------------------------------------------- +# _preprocess_image +# ---------------------------------------------------------------------- + + +class TestPreprocessImage: + def test_shape_and_dtype(self) -> None: + img = Image.new("RGB", (640, 480), color=(127, 127, 127)) + pv, _scale_x, _scale_y = _preprocess_image(img) + assert pv.shape == (1, 3, _TARGET_SIZE, _TARGET_SIZE) + assert pv.dtype == np.float32 + + def test_scale_for_landscape(self) -> None: + # SAM 3 image processor does a direct resize to 1008x1008; the per- + # axis scale factors are TARGET / orig_dim independently of aspect. + img = Image.new("RGB", (800, 400), color=(128, 128, 128)) + _, scale_x, scale_y = _preprocess_image(img) + assert scale_x == pytest.approx(_TARGET_SIZE / 800) + assert scale_y == pytest.approx(_TARGET_SIZE / 400) + + def test_no_padding_full_target_filled(self) -> None: + # Direct resize means every pixel in the output corresponds to a + # real input pixel -- there is no zero-padded border region. 
+ img = Image.new("RGB", (200, 100), color=(255, 0, 0)) + pv, _, _ = _preprocess_image(img) + # Red channel after rescale + (x - 0.5)/0.5 normalization == 1.0 + # everywhere; no zero border. + assert np.allclose(pv[0, 0], 1.0) + + def test_normalization_applied(self) -> None: + # SAM 3 normalization: (pixel/255 - 0.5) / 0.5 + # Black input -> (0 - 0.5) / 0.5 == -1.0 + img = Image.new("RGB", (100, 100), color=(0, 0, 0)) + pv, _, _ = _preprocess_image(img) + center_r = pv[0, 0, _TARGET_SIZE // 2, _TARGET_SIZE // 2] + assert center_r == pytest.approx(-1.0, abs=1e-4) + + +# ---------------------------------------------------------------------- +# _postprocess_mask +# ---------------------------------------------------------------------- + + +class TestPostprocessMask: + def test_recovers_original_shape(self) -> None: + # Low-res mask (256x256) -> original 480x640. With direct-resize + # preprocessing the low-res mask maps 1:1 to the full original + # image regardless of aspect ratio. + low = np.random.RandomState(0).rand(256, 256).astype(np.float32) - 0.5 + out = _postprocess_mask(low, orig_h=480, orig_w=640) + assert out.shape == (480, 640) + assert out.dtype == bool + + def test_thresholding_at_zero(self) -> None: + # All-positive logits -> all-True mask + low = np.ones((256, 256), dtype=np.float32) * 5.0 + out = _postprocess_mask(low, orig_h=100, orig_w=100) + assert out.all() + + low_neg = -low + out_neg = _postprocess_mask(low_neg, orig_h=100, orig_w=100) + assert not out_neg.any() + + +# ---------------------------------------------------------------------- +# _build_decoder_inputs +# ---------------------------------------------------------------------- + + +def _fake_emb() -> dict[str, np.ndarray]: + return { + "image_embeddings.0": np.zeros((1, 32, 288, 288), dtype=np.float32), + "image_embeddings.1": np.zeros((1, 64, 144, 144), dtype=np.float32), + "image_embeddings.2": np.zeros((1, 256, 72, 72), dtype=np.float32), + } + + +class TestBuildDecoderInputsBbox: + 
def test_bbox_shape_and_scale(self) -> None: + prompt = {"bbox": [10, 20, 30, 40]} + feed = _build_decoder_inputs( + prompt=prompt, prompt_mode="bbox", scale_x=2.0, scale_y=3.0, + emb=_fake_emb(), + ) + # boxes: (1, 1, 4) -- x scaled by 2, y scaled by 3 + assert feed["input_boxes"].shape == (1, 1, 4) + np.testing.assert_array_almost_equal( + feed["input_boxes"][0, 0], [20.0, 60.0, 60.0, 120.0], + ) + # points / labels are empty for bbox mode + assert feed["input_points"].shape == (1, 1, 0, 2) + assert feed["input_labels"].shape == (1, 1, 0) + assert feed["input_labels"].dtype == np.int64 + + def test_includes_all_three_embeddings(self) -> None: + prompt = {"bbox": [0, 0, 10, 10]} + feed = _build_decoder_inputs( + prompt=prompt, prompt_mode="bbox", scale_x=1.0, scale_y=1.0, + emb=_fake_emb(), + ) + for k in ("image_embeddings.0", "image_embeddings.1", "image_embeddings.2"): + assert k in feed + + +class TestBuildDecoderInputsPoint: + def test_point_shape_and_scale(self) -> None: + prompt = {"point": [15, 25], "label": 1} + feed = _build_decoder_inputs( + prompt=prompt, prompt_mode="point", scale_x=2.0, scale_y=3.0, + emb=_fake_emb(), + ) + assert feed["input_points"].shape == (1, 1, 1, 2) + np.testing.assert_array_almost_equal(feed["input_points"][0, 0, 0], [30.0, 75.0]) + # labels = foreground (1) + assert feed["input_labels"].shape == (1, 1, 1) + assert feed["input_labels"][0, 0, 0] == 1 + # boxes empty + assert feed["input_boxes"].shape == (1, 0, 4) + + +class TestBuildDecoderInputsInvalidMode: + def test_text_mode_rejected_with_helpful_message(self) -> None: + with pytest.raises(ValueError, match="Text-prompt"): + _build_decoder_inputs( + prompt={"text": "person"}, prompt_mode="text", + scale_x=1.0, scale_y=1.0, emb=_fake_emb(), + ) + + +# ---------------------------------------------------------------------- +# _build_providers +# ---------------------------------------------------------------------- + + +class TestBuildProviders: + def 
test_cpu_always_works(self) -> None: + providers, opts = _build_providers("cpu") + assert providers == ["CPUExecutionProvider"] + assert opts == [{}] + + def test_unknown_ep_falls_back_to_cpu(self) -> None: + providers, _opts = _build_providers("nonexistent-ep") + # Falls back via the providers_map.get() default -> CPU + assert providers == ["CPUExecutionProvider"] + + +# ---------------------------------------------------------------------- +# Evaluator constructor validation +# ---------------------------------------------------------------------- + + +def _make_config(model_path) -> WinMLEvaluationConfig: + ds = DatasetConfig(path="mattmdjaga/human_parsing_dataset", split="train", samples=2) + return WinMLEvaluationConfig( + model_id="onnx-community/sam3-tracker-ONNX", + task="mask-generation", + model_path=model_path, + dataset=ds, + device="cpu", + ep="cpu", + ) + + +class TestEvaluatorValidation: + def test_rejects_single_model_path(self) -> None: + cfg = _make_config("some/path.onnx") + with pytest.raises(TypeError, match="role=path"): + WinMLMaskGenerationEvaluator(cfg, model=None) + + def test_rejects_missing_decoder_role(self) -> None: + cfg = _make_config({"image-encoder": "enc.onnx"}) + with pytest.raises(ValueError, match="prompt-decoder"): + WinMLMaskGenerationEvaluator(cfg, model=None) + + def test_rejects_missing_encoder_role(self) -> None: + cfg = _make_config({"prompt-decoder": "dec.onnx"}) + with pytest.raises(ValueError, match="image-encoder"): + WinMLMaskGenerationEvaluator(cfg, model=None) + + +# ---------------------------------------------------------------------- +# Registry wiring +# ---------------------------------------------------------------------- + + +class TestEvaluatorRegistered: + def test_task_resolves_to_evaluator(self) -> None: + from winml.modelkit.eval import WinMLEvaluationConfig + from winml.modelkit.eval.evaluate import get_evaluator_class + + cls = get_evaluator_class(WinMLEvaluationConfig(task="mask-generation")) + 
assert cls is WinMLMaskGenerationEvaluator + + def test_default_dataset_registered(self) -> None: + from winml.modelkit.eval.evaluate import _DEFAULT_DATASETS + + default = _DEFAULT_DATASETS["mask-generation"] + assert default["path"] == "mattmdjaga/human_parsing_dataset" + assert default["split"] == "train" + + +# ---------------------------------------------------------------------- +# Profile-aware preprocessing / postprocessing (SAM 2 family) +# ---------------------------------------------------------------------- + + +from winml.modelkit.eval.mask_generation_evaluator import ( # noqa: E402 + SAM2_PROFILE, + SAM3_PROFILE, + _postprocess_for_profile, + _preprocess_for_profile, + _resolve_profile, +) + + +class TestPreprocessForProfileSam2: + def test_shape_and_dtype(self) -> None: + img = Image.new("RGB", (640, 480), color=(128, 128, 128)) + pv, sx, sy, new_h, new_w = _preprocess_for_profile(SAM2_PROFILE, img) + # SAM 2 target is 1024x1024 + assert pv.shape == (1, 3, 1024, 1024) + assert pv.dtype == np.float32 + # uniform longest-side scale (single factor used for both axes) + assert sx == pytest.approx(sy) + # 640 is the longer side -> scale = 1024 / 640 = 1.6 + assert sx == pytest.approx(1024 / 640) + # post-resize content dims, pre-pad: w fills target, h shorter + assert new_w == 1024 + assert new_h == round(480 * 1024 / 640) + + def test_zero_padding_present_in_letterbox_region(self) -> None: + # ImageNet mean ~0.485; with letterbox padding using zero in the + # raw pixel domain, post-normalization the pad region equals + # (0 - mean) / std ~ -2.1 for R channel; the content region is + # bounded. We just check that the bottom rows are equal across + # x (a constant value indicating a uniform pad) and != to the + # content rows. 
+ img = Image.new("RGB", (640, 320), color=(255, 255, 255)) + pv, _sx, _sy, new_h, _new_w = _preprocess_for_profile(SAM2_PROFILE, img) + assert new_h < 1024 # there IS a pad region + pad_row = pv[0, 0, -1, :] + content_row = pv[0, 0, new_h // 2, :] + # pad row is uniform (single value across x); content row is uniform + # too because input is white, so distinguish by VALUE not constancy. + assert not np.allclose(pad_row, content_row) + + def test_imagenet_normalization(self) -> None: + # White input: (1.0 - 0.485) / 0.229 ~ 2.248 for R channel. + img = Image.new("RGB", (1024, 1024), color=(255, 255, 255)) + pv, _, _, _, _ = _preprocess_for_profile(SAM2_PROFILE, img) + center_r = pv[0, 0, 512, 512] + assert center_r == pytest.approx((1.0 - 0.485) / 0.229, abs=1e-3) + + +class TestPostprocessForProfileSam2: + def test_recovers_original_shape_with_crop_then_resize(self) -> None: + # Letterbox-padded encoder input was 1024x1024; content region + # was 768x1024 (i.e. shorter side padded). Postprocess should + # crop low-res back to that content aspect, then resize to the + # original 600x800 image. 
+ low = np.random.RandomState(0).rand(256, 256).astype(np.float32) - 0.5 + out = _postprocess_for_profile( + SAM2_PROFILE, low, orig_h=600, orig_w=800, new_h=768, new_w=1024, + ) + assert out.shape == (600, 800) + assert out.dtype == bool + + def test_thresholding_at_zero(self) -> None: + low = np.ones((256, 256), dtype=np.float32) * 5.0 + out = _postprocess_for_profile( + SAM2_PROFILE, low, orig_h=100, orig_w=100, new_h=1024, new_w=1024, + ) + assert out.all() + out_neg = _postprocess_for_profile( + SAM2_PROFILE, -low, orig_h=100, orig_w=100, new_h=1024, new_w=1024, + ) + assert not out_neg.any() + + +# ---------------------------------------------------------------------- +# _resolve_profile -- dispatch logic +# ---------------------------------------------------------------------- + + +class _StubInput: + def __init__(self, shape): + self.shape = shape + + +class _StubSession: + def __init__(self, shape): + self._shape = shape + + def get_inputs(self): + return [_StubInput(self._shape)] + + +def _cfg(model_id: str) -> WinMLEvaluationConfig: + ds = DatasetConfig(path="mattmdjaga/human_parsing_dataset", split="train", samples=2) + return WinMLEvaluationConfig( + model_id=model_id, + task="mask-generation", + model_path={"image-encoder": "enc.onnx", "prompt-decoder": "dec.onnx"}, + dataset=ds, + device="cpu", + ep="cpu", + ) + + +class TestResolveProfile: + def test_shape_signal_picks_sam2(self) -> None: + sess = _StubSession([1, 3, 1024, 1024]) + prof = _resolve_profile(_cfg("ambiguous/name"), sess) + assert prof is SAM2_PROFILE + + def test_shape_signal_picks_sam3(self) -> None: + sess = _StubSession([1, 3, 1008, 1008]) + prof = _resolve_profile(_cfg("ambiguous/name"), sess) + assert prof is SAM3_PROFILE + + def test_falls_back_to_model_id_sam2(self) -> None: + # Dynamic shape (strings) -> fall through to model_id matching + sess = _StubSession([1, 3, "H", "W"]) + prof = _resolve_profile( + _cfg("onnx-community/sam2.1-hiera-small-ONNX"), sess, + ) + assert prof 
is SAM2_PROFILE + + def test_falls_back_to_model_id_sam3(self) -> None: + sess = _StubSession([1, 3, "H", "W"]) + prof = _resolve_profile( + _cfg("onnx-community/sam3-tracker-ONNX"), sess, + ) + assert prof is SAM3_PROFILE + + def test_default_is_sam3(self) -> None: + sess = _StubSession([1, 3, "H", "W"]) + prof = _resolve_profile(_cfg("unknown/family"), sess) + assert prof is SAM3_PROFILE + diff --git a/tests/unit/inference/test_engine.py b/tests/unit/inference/test_engine.py index 12590a309..68172e5d2 100644 --- a/tests/unit/inference/test_engine.py +++ b/tests/unit/inference/test_engine.py @@ -382,9 +382,14 @@ def test_hub_onnx_ref_is_resolved_before_routing(self, tmp_path: Any) -> None: hub_ref = "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" engine = InferenceEngine() + # ``WinMLSession.load_schema_only`` now routes Hub-ONNX resolution + # through the unified ``resolve_model_input`` (in + # ``winml.modelkit.utils.model_input``). Patch the underlying + # downloader so the lazy ``from ..loader.onnx_hub import + # resolve_hf_onnx_path`` picks up the mock at call time. with patch( - "winml.modelkit.loader.maybe_resolve_hf_onnx_path", - return_value=str(local), + "winml.modelkit.loader.onnx_hub.resolve_hf_onnx_path", + return_value=local, ) as mock_resolve: engine.load_schema_only(hub_ref, task="mask-generation") mock_resolve.assert_called_once() diff --git a/tests/unit/loader/test_onnx_hub.py b/tests/unit/loader/test_onnx_hub.py index 9771b4fdb..aae4c8bb5 100644 --- a/tests/unit/loader/test_onnx_hub.py +++ b/tests/unit/loader/test_onnx_hub.py @@ -4,8 +4,12 @@ # -------------------------------------------------------------------------- """Tests for winml.modelkit.loader.onnx_hub. -Covers Hub-style ONNX reference detection and download. Uses mock -``hf_hub_download`` callables so no network access is required. +Covers the Hub-style ONNX **download** path. 
Classification and +the combined classify+download wrapper live in +``winml.modelkit.utils.model_input`` and are covered by +``tests/unit/utils/test_model_input.py``. + +Uses mock ``hf_hub_download`` callables so no network access is required. """ from __future__ import annotations @@ -17,8 +21,6 @@ from winml.modelkit.loader.onnx_hub import ( _split_hf_onnx_path, - is_hf_onnx_path, - maybe_resolve_hf_onnx_path, resolve_hf_onnx_path, ) @@ -27,41 +29,6 @@ from pathlib import Path -class TestIsHfOnnxPath: - """Hub ONNX reference detection.""" - - def test_three_segment_onnx_recognized(self) -> None: - """Repo-id + nested file path is a valid Hub ONNX reference.""" - assert is_hf_onnx_path("onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx") - - def test_three_segments_minimum(self) -> None: - """Two segments are treated as a plain HF model ID, not a file ref.""" - assert is_hf_onnx_path("org/repo/file.onnx") - assert not is_hf_onnx_path("org/file.onnx") - - def test_plain_hf_model_id_rejected(self) -> None: - """org/name HF IDs are not Hub ONNX references.""" - assert not is_hf_onnx_path("microsoft/resnet-50") - assert not is_hf_onnx_path("facebook/sam2.1-hiera-small") - - def test_non_onnx_extension_rejected(self) -> None: - """Only .onnx file references match.""" - assert not is_hf_onnx_path("org/repo/path/file.bin") - assert not is_hf_onnx_path("org/repo/path/file") - - def test_existing_local_path_takes_precedence(self, tmp_path: Path) -> None: - """A real on-disk path that looks like a Hub ref is left alone.""" - local = tmp_path / "org" / "repo" / "file.onnx" - local.parent.mkdir(parents=True) - local.write_bytes(b"") - assert not is_hf_onnx_path(str(local)) - - def test_none_and_empty_inputs(self) -> None: - """None and empty string are not Hub references.""" - assert not is_hf_onnx_path(None) - assert not is_hf_onnx_path("") - - class TestSplitHfOnnxPath: """Internal _split_hf_onnx_path helper.""" @@ -181,44 +148,109 @@ def _fake_download(*, 
repo_id, filename, revision, cache_dir, token): ) -class TestMaybeResolveHfOnnxPath: - """Convenience wrapper that combines is_hf_onnx_path + resolve_hf_onnx_path.""" +class TestResolveHfOnnxPathDiscovery: + """``EntryNotFoundError`` on the main file is enriched with a file listing. - def test_none_passes_through(self) -> None: - """``None`` returns ``None`` without touching the network.""" - with patch("huggingface_hub.hf_hub_download") as mock: - assert maybe_resolve_hf_onnx_path(None) is None - mock.assert_not_called() + The user typically gets here by guessing the wrong path inside a valid + Hub repo (e.g. ``onnx/vision_encoder.onnx`` when only ``int8`` and + ``fp16`` variants exist). The error message must list the ``.onnx`` + files that *are* available so the user can correct the path without + having to open the Hub web UI. + """ - def test_plain_hf_model_id_passes_through(self) -> None: - """An HF model id (e.g. ``microsoft/resnet-50``) is returned unchanged.""" - with patch("huggingface_hub.hf_hub_download") as mock: - assert maybe_resolve_hf_onnx_path("microsoft/resnet-50") == "microsoft/resnet-50" - mock.assert_not_called() + def test_missing_file_lists_available_onnx(self) -> None: + """Wrong filename: error names available .onnx files in the repo.""" + from huggingface_hub.utils import EntryNotFoundError - def test_local_path_passes_through(self, tmp_path: Path) -> None: - """Existing local ``.onnx`` paths take precedence over Hub interpretation.""" - local = tmp_path / "model.onnx" - local.write_bytes(b"") - with patch("huggingface_hub.hf_hub_download") as mock: - assert maybe_resolve_hf_onnx_path(str(local)) == str(local) - mock.assert_not_called() + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + # Main file is missing; sidecar should never be reached + # because the main download raises first. 
+ raise EntryNotFoundError(filename) + + repo_files = [ + "README.md", + "config.json", + "onnx/vision_encoder_int8.onnx", + "onnx/vision_encoder_fp16.onnx", + "onnx/prompt_encoder_mask_decoder_int8.onnx", + ] + + with ( + patch("huggingface_hub.hf_hub_download", side_effect=_fake_download), + patch( + "huggingface_hub.list_repo_files", + return_value=repo_files, + ) as mock_list, + pytest.raises(FileNotFoundError) as exc_info, + ): + resolve_hf_onnx_path( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder.onnx" + ) + + msg = str(exc_info.value) + # Names the bad path and the repo + assert "onnx/vision_encoder.onnx" in msg + assert "onnx-community/sam3-tracker-ONNX" in msg + # Lists every .onnx file that *is* present + assert "onnx/vision_encoder_int8.onnx" in msg + assert "onnx/vision_encoder_fp16.onnx" in msg + assert "onnx/prompt_encoder_mask_decoder_int8.onnx" in msg + # Does not include non-ONNX files + assert "README.md" not in msg + assert "config.json" not in msg + # list_repo_files was called with the repo derived from the bad path + mock_list.assert_called_once() + assert ( + mock_list.call_args.args[0] == "onnx-community/sam3-tracker-ONNX" + or mock_list.call_args.kwargs.get("repo_id") + == "onnx-community/sam3-tracker-ONNX" + ) - def test_hub_ref_is_resolved(self, tmp_path: Path) -> None: - """A Hub-style ONNX ref triggers ``resolve_hf_onnx_path``.""" + def test_missing_file_listing_failure_falls_back_gracefully(self) -> None: + """If list_repo_files itself fails, the error still surfaces cleanly. + + The hint is best-effort -- we must not mask the original + ``EntryNotFoundError`` because the listing step also failed + (gated repo, network blip, auth issue). 
+ """ from huggingface_hub.utils import EntryNotFoundError - downloaded = tmp_path / "vision_encoder_int8.onnx" - downloaded.write_bytes(b"") + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + raise EntryNotFoundError(filename) + + with ( + patch("huggingface_hub.hf_hub_download", side_effect=_fake_download), + patch( + "huggingface_hub.list_repo_files", + side_effect=ConnectionError("network down"), + ), + pytest.raises(FileNotFoundError) as exc_info, + ): + resolve_hf_onnx_path("org/repo/onnx/missing.onnx") + + msg = str(exc_info.value) + assert "onnx/missing.onnx" in msg + assert "org/repo" in msg + # Generic fallback hint is included. + assert "Could not list available .onnx files" in msg + + def test_missing_file_no_onnx_in_repo(self) -> None: + """Repo exists but has no .onnx files at all -- hint says so.""" + from huggingface_hub.utils import EntryNotFoundError def _fake_download(*, repo_id, filename, revision, cache_dir, token): - if filename.endswith(".onnx_data"): - raise EntryNotFoundError(filename) - return str(downloaded) + raise EntryNotFoundError(filename) - with patch("huggingface_hub.hf_hub_download", side_effect=_fake_download): - result = maybe_resolve_hf_onnx_path( - "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" - ) + with ( + patch("huggingface_hub.hf_hub_download", side_effect=_fake_download), + patch( + "huggingface_hub.list_repo_files", + return_value=["README.md", "config.json", "pytorch_model.bin"], + ), + pytest.raises(FileNotFoundError) as exc_info, + ): + resolve_hf_onnx_path("org/pytorch-only/onnx/model.onnx") - assert result == str(downloaded) + msg = str(exc_info.value) + assert "No .onnx files were found" in msg + assert "org/pytorch-only" in msg diff --git a/tests/unit/onnx/test_detection.py b/tests/unit/onnx/test_detection.py index ba54b8338..cd4f3e1a4 100644 --- a/tests/unit/onnx/test_detection.py +++ b/tests/unit/onnx/test_detection.py @@ -15,8 +15,7 @@ from typing import TYPE_CHECKING 
import numpy as np -import onnx -from onnx import TensorProto, helper +from onnx import GraphProto, TensorProto, helper, save from winml.modelkit.onnx.detection import is_compiled_onnx, is_quantized_onnx @@ -25,11 +24,11 @@ from pathlib import Path -def _save(graph: onnx.GraphProto, path: Path, *, opset: int = 17) -> Path: +def _save(graph: GraphProto, path: Path, *, opset: int = 17) -> Path: """Save a graph as a minimal ONNX model.""" model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) model.ir_version = 8 - onnx.save(model, str(path)) + save(model, str(path)) return path diff --git a/tests/unit/utils/test_model_input.py b/tests/unit/utils/test_model_input.py new file mode 100644 index 000000000..8f259df38 --- /dev/null +++ b/tests/unit/utils/test_model_input.py @@ -0,0 +1,215 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for winml.modelkit.utils.model_input. + +Covers the single ``-m/--model`` value classifier (:func:`classify_model_input`) +and the classify+download resolver (:func:`resolve_model_input`) that +together replace the previous trio of detectors (``is_hub_model``, +``is_hf_onnx_path``, ``is_onnx_file_path``). 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import patch + +from winml.modelkit.utils.model_input import ( + classify_model_input, + resolve_model_input, +) + + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# classify_model_input: hub_onnx +# --------------------------------------------------------------------------- + + +class TestClassifyHubOnnx: + """``org/repo/.../file.onnx`` -> ``hub_onnx``.""" + + def test_three_segment_onnx_recognized(self) -> None: + """Repo-id + nested file path is a valid Hub ONNX reference.""" + mi = classify_model_input( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + ) + assert mi.kind == "hub_onnx" + assert mi.hf_id == "onnx-community/sam3-tracker-ONNX" + assert mi.local_path is None # not downloaded by classify + + def test_three_segments_minimum(self) -> None: + """Two segments are too few for a Hub ONNX reference.""" + assert classify_model_input("org/repo/file.onnx").kind == "hub_onnx" + # Two segments ending in .onnx is invalid (not a Hub ref, not a + # plausible HF id either). 
+ assert classify_model_input("org/file.onnx").kind == "invalid" + + def test_uppercase_onnx_extension_accepted(self) -> None: + """Case-insensitive ``.onnx`` matches the rest of the CLI.""" + assert classify_model_input("org/repo/path/file.ONNX").kind == "hub_onnx" + assert classify_model_input("org/repo/path/file.OnNx").kind == "hub_onnx" + + +# --------------------------------------------------------------------------- +# classify_model_input: hf_id +# --------------------------------------------------------------------------- + + +class TestClassifyHfId: + """``org/name`` (no .onnx suffix) -> ``hf_id``.""" + + def test_plain_hf_model_id(self) -> None: + mi = classify_model_input("microsoft/resnet-50") + assert mi.kind == "hf_id" + assert mi.hf_id == "microsoft/resnet-50" + assert mi.local_path is None + + def test_single_segment_hf_id(self) -> None: + """Single-segment IDs (e.g. ``bert-base-uncased``) are still hf_id.""" + mi = classify_model_input("bert-base-uncased") + assert mi.kind == "hf_id" + assert mi.hf_id == "bert-base-uncased" + + def test_three_segment_non_onnx_is_hf_id(self) -> None: + """A non-.onnx three-segment string is still treated as a HF id.""" + # Three segments with a non-onnx extension or no extension shouldn't + # claim to be a Hub ONNX ref. 
+ mi = classify_model_input("org/repo/path/file.bin") + assert mi.kind == "hf_id" + mi2 = classify_model_input("org/repo/path/file") + assert mi2.kind == "hf_id" + + +# --------------------------------------------------------------------------- +# classify_model_input: local_onnx / build_dir +# --------------------------------------------------------------------------- + + +class TestClassifyLocal: + """Local-path branch: existing files/dirs + ``./``/``../``/``~/``/abs prefixes.""" + + def test_existing_local_onnx(self, tmp_path: Path) -> None: + """An existing on-disk .onnx file is classified as local_onnx.""" + local = tmp_path / "org" / "repo" / "file.onnx" + local.parent.mkdir(parents=True) + local.write_bytes(b"") + mi = classify_model_input(str(local)) + assert mi.kind == "local_onnx" + assert mi.local_path == str(local) + assert mi.hf_id is None + + def test_existing_build_dir(self, tmp_path: Path) -> None: + """An existing directory is classified as build_dir.""" + d = tmp_path / "build_out" + d.mkdir() + mi = classify_model_input(str(d)) + assert mi.kind == "build_dir" + assert mi.local_path == str(d) + + def test_relative_path_prefixes_rejected_as_hub(self) -> None: + """``./``, ``../``, ``~/`` prefixes block Hub interpretation.""" + # These strings all have three slash-separated segments and end in + # .onnx, so without local-path rejection they would be misclassified + # as Hub references. 
+ for value in ( + "./org/repo/file.onnx", + "../org/repo/file.onnx", + "~/org/repo/file.onnx", + ): + mi = classify_model_input(value) + # .onnx suffix + local-path prefix => local_onnx (download + # attempt would fail later, but classification is correct) + assert mi.kind == "local_onnx", f"expected local_onnx for {value!r}, got {mi}" + + def test_unix_absolute_path_rejected_as_hub(self) -> None: + """Unix-style absolute paths are treated as local even without an existing file.""" + mi = classify_model_input("/tmp/org/repo/file.onnx") # noqa: S108 - fake path is not a real tempfile + assert mi.kind == "local_onnx" + + def test_windows_absolute_path_rejected_as_hub(self) -> None: + """Windows drive-letter absolute paths are treated as local.""" + # Both backslash and forward-slash separators after the drive + # letter are common on Windows; both must be rejected. + for value in ( + r"C:\models\org\repo\file.onnx", + "C:/models/org/repo/file.onnx", + r"D:\org\repo\file.onnx", + ): + mi = classify_model_input(value) + assert mi.kind == "local_onnx", f"expected local_onnx for {value!r}, got {mi}" + + +# --------------------------------------------------------------------------- +# classify_model_input: invalid / edge +# --------------------------------------------------------------------------- + + +class TestClassifyEdge: + """Empty / unparsable inputs.""" + + def test_empty_string(self) -> None: + mi = classify_model_input("") + assert mi.kind == "invalid" + assert mi.raw == "" + + def test_raw_preserved(self) -> None: + """`raw` always echoes the original input regardless of kind.""" + for value in ( + "microsoft/resnet-50", + "org/repo/path/file.onnx", + "./model.onnx", + ): + assert classify_model_input(value).raw == value + + +# --------------------------------------------------------------------------- +# resolve_model_input: pass-through + download +# --------------------------------------------------------------------------- + + +class TestResolveModelInput: 
+ """``resolve_model_input`` == classify + download for hub_onnx only.""" + + def test_hf_id_pass_through_no_network(self) -> None: + """``microsoft/resnet-50`` returns unchanged; no Hub download attempted.""" + with patch("huggingface_hub.hf_hub_download") as mock: + mi = resolve_model_input("microsoft/resnet-50") + assert mi.kind == "hf_id" + assert mi.local_path is None + mock.assert_not_called() + + def test_local_onnx_pass_through_no_network(self, tmp_path: Path) -> None: + """Existing local ``.onnx`` paths take precedence over Hub interpretation.""" + local = tmp_path / "model.onnx" + local.write_bytes(b"") + with patch("huggingface_hub.hf_hub_download") as mock: + mi = resolve_model_input(str(local)) + assert mi.kind == "local_onnx" + assert mi.local_path == str(local) + mock.assert_not_called() + + def test_hub_ref_is_downloaded(self, tmp_path: Path) -> None: + """A Hub-style ONNX ref triggers ``resolve_hf_onnx_path`` and populates local_path.""" + from huggingface_hub.utils import EntryNotFoundError + + downloaded = tmp_path / "vision_encoder_int8.onnx" + downloaded.write_bytes(b"") + + def _fake_download(*, repo_id, filename, revision, cache_dir, token): + if filename.endswith(".onnx_data"): + raise EntryNotFoundError(filename) + return str(downloaded) + + with patch("huggingface_hub.hf_hub_download", side_effect=_fake_download): + mi = resolve_model_input( + "onnx-community/sam3-tracker-ONNX/onnx/vision_encoder_int8.onnx" + ) + + assert mi.kind == "hub_onnx" + assert mi.local_path == str(downloaded) + assert mi.hf_id == "onnx-community/sam3-tracker-ONNX" From 4c8428f0eaf8c816a42d97c0175ddc8655d61151 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 16 Jun 2026 16:28:27 +0000 Subject: [PATCH 4/4] fix: export resolve_task and related symbols from loader __init__.py --- src/winml/modelkit/loader/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/winml/modelkit/loader/__init__.py 
b/src/winml/modelkit/loader/__init__.py index e7114300a..6e112e08a 100644 --- a/src/winml/modelkit/loader/__init__.py +++ b/src/winml/modelkit/loader/__init__.py @@ -27,6 +27,7 @@ from .config import WinMLLoaderConfig, resolve_loader_config from .onnx_hub import resolve_hf_onnx_path +from .resolution import TaskResolution, TaskSource, resolve_composite, resolve_task from .task import ( HF_TASK_DEFAULTS, KNOWN_TASKS,