From 18d439f2588c3205c55825ba554c9451337bd8fa Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 18 Jun 2026 10:40:46 +0800 Subject: [PATCH 1/3] add 3 folder --- .github/workflows/lint.yml | 5 ++++- src/winml/modelkit/optracing/qnn/profiler.py | 8 ++++++-- .../modelkit/optracing/qnn/qhas_parser.py | 2 +- src/winml/modelkit/quant/__init__.py | 4 +++- src/winml/modelkit/serve/app.py | 19 ++++++++++++------- src/winml/modelkit/serve/cli_api.py | 7 +++++-- 6 files changed, 31 insertions(+), 14 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 19fb415d9..e39675668 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ concurrency: jobs: lint: runs-on: windows-latest - # Bumped from 5: combined mypy on 16 packages cold-starts at ~3-4 min on + # Bumped from 5: combined mypy on 19 packages cold-starts at ~3-4 min on # Windows runners; the original 5-min ceiling cancelled mid-run. timeout-minutes: 10 @@ -64,3 +64,6 @@ jobs: -p winml.modelkit.loader -p winml.modelkit.onnx -p winml.modelkit.optim + -p winml.modelkit.optracing + -p winml.modelkit.quant + -p winml.modelkit.serve diff --git a/src/winml/modelkit/optracing/qnn/profiler.py b/src/winml/modelkit/optracing/qnn/profiler.py index 12bb44403..0df9e22f2 100644 --- a/src/winml/modelkit/optracing/qnn/profiler.py +++ b/src/winml/modelkit/optracing/qnn/profiler.py @@ -19,7 +19,7 @@ import logging import os from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np @@ -29,6 +29,10 @@ from .viewer import find_qnn_sdk, run_qhas_viewer +if TYPE_CHECKING: + from collections.abc import Iterator + + logger = logging.getLogger(__name__) @@ -53,7 +57,7 @@ def _resolve_shape(shape: list, default_dim: int = 1) -> list[int]: @contextlib.contextmanager -def _working_directory(path: Path): +def _working_directory(path: Path) -> Iterator[None]: """Temporarily change CWD and restore on exit. QNN EP writes ``*_schematic.bin`` into the process CWD, so we diff --git a/src/winml/modelkit/optracing/qnn/qhas_parser.py b/src/winml/modelkit/optracing/qnn/qhas_parser.py index 9bc267ccc..77a19c0ea 100644 --- a/src/winml/modelkit/optracing/qnn/qhas_parser.py +++ b/src/winml/modelkit/optracing/qnn/qhas_parser.py @@ -110,4 +110,4 @@ def _vtcm_ratio(op: dict) -> float | None: total = vtcm_read + dram_read if total == 0: return None - return vtcm_read / total + return float(vtcm_read / total) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 2e6c2c279..bc8e6ee06 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -16,6 +16,8 @@ result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) """ +from typing import Any + from .config import QuantizeResult, WinMLQuantizationConfig @@ -31,7 +33,7 @@ } -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy-load quantizer (imports onnxruntime.quantization).""" if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] diff --git a/src/winml/modelkit/serve/app.py b/src/winml/modelkit/serve/app.py index d19ee96bd..b38f96f6b 100644 --- a/src/winml/modelkit/serve/app.py +++ b/src/winml/modelkit/serve/app.py @@ -23,6 +23,7 @@ import asyncio import base64 +import binascii import importlib.resources import json import logging @@ -30,7 +31,7 @@ from collections import deque from contextlib import asynccontextmanager from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware @@ -57,6 +58,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from ..utils.constants import EPNameOrAlias logger = logging.getLogger(__name__) @@ -166,7 +169,7 @@ def create_app( """ @asynccontextmanager - async def lifespan(app: FastAPI): + async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.state.start_time = time.time() # Raise the modelkit logger to INFO so the ring handler receives # operational records during `winml serve`. Tests that build the app @@ -187,6 +190,8 @@ async def lifespan(app: FastAPI): logger.info("Multi-model server started (empty — load via POST /v1/models)") app.state.manager = mgr else: + if model_path is None: + raise ValueError("single-model mode requires a model_path") engine = InferenceEngine() engine.load(model_path, task=task, device=device, ep=ep) app.state.manager = SingleModelManager(engine, idle_timeout_sec=idle_timeout_sec) @@ -240,7 +245,7 @@ def _get_mgr() -> SingleModelManager | ModelSlotManager: mgr = getattr(app.state, "manager", None) if mgr is None: raise HTTPException(status_code=503, detail="Model not loaded yet") - return mgr + return cast("SingleModelManager | ModelSlotManager", mgr) def _get_start_time() -> float: return getattr(app.state, "start_time", time.time()) @@ -733,13 +738,13 @@ async def cli_command(command: str, request: CliRequest) -> CliResponse: # --------------------------------------------------------------------------- -def _manifest_from_engine(engine: InferenceEngine) -> dict: +def _manifest_from_engine(engine: InferenceEngine) -> dict[str, Any]: """Build manifest dict from engine, trying build_manifest.json first.""" if engine.model_path: manifest_file = Path(engine.model_path) / "build_manifest.json" if manifest_file.exists(): try: - return json.loads(manifest_file.read_text()) + return cast("dict[str, Any]", json.loads(manifest_file.read_text())) except (json.JSONDecodeError, OSError) as e: logger.warning("Failed to load manifest: %s", e) @@ -753,7 +758,7 @@ def _manifest_from_engine(engine: InferenceEngine) -> dict: def _build_model_schema( - manifest: dict, + manifest: dict[str, Any], engine: InferenceEngine | None = None, task_override: str | None = None, ) -> dict[str, Any]: @@ -828,7 +833,7 @@ def _decode_rest_inputs( if field and field.type in BINARY_TYPES and isinstance(value, str): try: result[name] = base64.b64decode(value) - except (ValueError, base64.binascii.Error) as exc: + except (ValueError, binascii.Error) as exc: raise ValueError(f"Invalid base64 for input '{name}': {exc}") from exc return result diff --git a/src/winml/modelkit/serve/cli_api.py b/src/winml/modelkit/serve/cli_api.py index c47a5cf56..2bb6d8053 100644 --- a/src/winml/modelkit/serve/cli_api.py +++ b/src/winml/modelkit/serve/cli_api.py @@ -27,7 +27,7 @@ import tempfile import time from pathlib import Path -from typing import Any +from typing import Any, cast from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -212,7 +212,10 @@ def _extract_json_from_stdout(stdout: str) -> dict[str, Any] | list[Any] | None: if start == -1: break try: - return json.loads(stdout[start : end + 1]) + return cast( + "dict[str, Any] | list[Any] | None", + json.loads(stdout[start : end + 1]), + ) except json.JSONDecodeError: continue pos = end # try next (earlier) end_char From 7e4a08fd3be7425a4408064aaeff9a072b8be2a3 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 18 Jun 2026 13:49:42 +0800 Subject: [PATCH 2/3] continue --- .github/workflows/lint.yml | 5 +- pyproject.toml | 12 + src/winml/modelkit/session/ep_registry.py | 9 +- src/winml/modelkit/session/monitor/_pdh.py | 32 +- .../modelkit/session/monitor/_xrt_smi.py | 4 +- .../session/qairt/compile_qairt_bin.py | 2 +- .../modelkit/session/qairt/qairt_session.py | 6 +- src/winml/modelkit/session/session.py | 20 +- src/winml/modelkit/telemetry/click_group.py | 7 +- src/winml/modelkit/telemetry/consent.py | 10 +- .../modelkit/telemetry/deviceid/_store.py | 4 +- .../modelkit/telemetry/library/exporter.py | 6 +- src/winml/modelkit/telemetry/telemetry.py | 12 +- src/winml/modelkit/telemetry/utils.py | 13 +- src/winml/modelkit/utils/constants.py | 6 +- src/winml/modelkit/utils/data_utils.py | 131 ++-- src/winml/modelkit/utils/hub_utils.py | 692 +++++++++--------- src/winml/modelkit/utils/native_stderr.py | 8 +- src/winml/modelkit/utils/optimum_loader.py | 320 ++++---- 19 files changed, 688 insertions(+), 611 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e39675668..6089eef89 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ concurrency: jobs: lint: runs-on: windows-latest - # Bumped from 5: combined mypy on 19 packages cold-starts at ~3-4 min on + # Bumped from 5: combined mypy on 22 packages cold-starts at ~3-4 min on # Windows runners; the original 5-min ceiling cancelled mid-run. timeout-minutes: 10 @@ -67,3 +67,6 @@ jobs: -p winml.modelkit.optracing -p winml.modelkit.quant -p winml.modelkit.serve + -p winml.modelkit.session + -p winml.modelkit.sysinfo + -p winml.modelkit.telemetry diff --git a/pyproject.toml b/pyproject.toml index 30c73f58a..1412e0929 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -497,9 +497,21 @@ module = [ "sklearn.*", # used in eval/metrics; no community stubs "evaluate", # HF evaluate, used in eval/; no community stubs "evaluate.*", + # QAIRT (Qualcomm AI Runtime) SDK — imported only inside compile_qairt_bin.py, + # which runs in a separate venv-winml subprocess where the SDK is installed. + # Not a dependency of the main/CI environment, so it has no stubs here. + "qairt", + "qairt.*", ] ignore_missing_imports = true +# windowsml ships no py.typed marker, but its source is installed and usable — +# analyze it directly (PEP 561 opt-in) so its inline annotations are honored +# instead of collapsing every symbol to Any. +[[tool.mypy.overrides]] +module = [ "windowsml", "windowsml.*" ] +follow_untyped_imports = true + # Relaxed modules: tests and WIP code [[tool.mypy.overrides]] diff --git a/src/winml/modelkit/session/ep_registry.py b/src/winml/modelkit/session/ep_registry.py index 244da2a7a..d78112f28 100644 --- a/src/winml/modelkit/session/ep_registry.py +++ b/src/winml/modelkit/session/ep_registry.py @@ -228,6 +228,10 @@ class WinMLEPRegistry: available = registry.get_available_eps() """ + # Set in __new__ before __init__ runs; declared here so mypy can resolve its + # type at the __init__ read site. + _initialized: bool + def __new__(cls) -> WinMLEPRegistry: """Singleton pattern.""" global _winml_ep_registry @@ -269,7 +273,8 @@ def _load_ep_catalog(self) -> None: """Load EP catalog from WinML.""" from windowsml import EpCatalog - checked_eps: set[EPName] = set() + # Dedup of raw windowsml provider-name strings (pre-validation). + checked_eps: set[str] = set() with EpCatalog() as catalog: for provider in catalog.find_all_providers(): if provider.name in checked_eps: @@ -402,4 +407,4 @@ def get_ort_available_providers(use_winml: bool = True) -> list[str]: except Exception as e: logger.debug("WinML discovery skipped: %s", e) - return ort.get_available_providers() + return cast("list[str]", ort.get_available_providers()) diff --git a/src/winml/modelkit/session/monitor/_pdh.py b/src/winml/modelkit/session/monitor/_pdh.py index 0e0364dc4..60100be71 100644 --- a/src/winml/modelkit/session/monitor/_pdh.py +++ b/src/winml/modelkit/session/monitor/_pdh.py @@ -178,22 +178,24 @@ def _collect_once(self) -> dict[str, float | int | None]: continue if entry.fmt == _PDH_FMT_DOUBLE: - val = _PdhFmtDouble() + dval = _PdhFmtDouble() s = _pdh.PdhGetFormattedCounterValue( entry.handle, _PDH_FMT_DOUBLE | _PDH_FMT_NOCAP100, ctypes.byref(ct), - ctypes.byref(val), + ctypes.byref(dval), ) values[entry.name] = ( - val.doubleValue if _pdh_ok(s) and _pdh_ok(val.CStatus) else None + dval.doubleValue if _pdh_ok(s) and _pdh_ok(dval.CStatus) else None ) else: - val = _PdhFmtLarge() + lval = _PdhFmtLarge() s = _pdh.PdhGetFormattedCounterValue( - entry.handle, _PDH_FMT_LARGE, ctypes.byref(ct), ctypes.byref(val) + entry.handle, _PDH_FMT_LARGE, ctypes.byref(ct), ctypes.byref(lval) + ) + values[entry.name] = ( + lval.largeValue if _pdh_ok(s) and _pdh_ok(lval.CStatus) else None ) - values[entry.name] = val.largeValue if _pdh_ok(s) and _pdh_ok(val.CStatus) else None return values @@ -403,16 +405,17 @@ def __init__( self._stop_event = threading.Event() self._lock = threading.Lock() self._util_samples: list[float] = [] - self._memory_local_bytes: list[int] = [] - self._memory_shared_bytes: list[int] = [] + # PDH counters arrive as float|int (double vs large format), so store as float. + self._memory_local_bytes: list[float] = [] + self._memory_shared_bytes: list[float] = [] self._cpu_samples: list[float] = [] - self._ram_used_bytes: list[int] = [] + self._ram_used_bytes: list[float] = [] # Per-engtype snapshots of the monotonic Running Time counter. Stored # as dicts because an adapter exposes multiple engines (e.g. several # Compute_* on an NPU; 3D + Compute_* on a GPU) and the total adapter # time is the sum of per-engine deltas. - self._running_time_start_ns: dict[str, int] = {} - self._running_time_end_ns: dict[str, int] = {} + self._running_time_start_ns: dict[str, float] = {} + self._running_time_end_ns: dict[str, float] = {} def start(self) -> None: """Resolve target device, register PDH counters, start background thread. @@ -532,8 +535,11 @@ def _poll_loop(self) -> None: # normalize so cpu_pct stays 0..100 across machines. cpu_divisor = float(os.cpu_count() or 1) while not self._stop_event.is_set(): + query = self._query + if query is None: + break try: - values = self._query._collect_once() + values = query._collect_once() # util_* counters are per-engine ratios over the same sample # window, so max reports the most-loaded engine on the adapter. # Don't sum — that would exceed 100% and duplicate what the @@ -739,7 +745,7 @@ def running_time_delta_ns(self) -> int: start = self._running_time_start_ns.get(key) if start is None: continue - total += max(0, end - start) + total += int(max(0, end - start)) return total @staticmethod diff --git a/src/winml/modelkit/session/monitor/_xrt_smi.py b/src/winml/modelkit/session/monitor/_xrt_smi.py index aa11e15e5..c9fd5306a 100644 --- a/src/winml/modelkit/session/monitor/_xrt_smi.py +++ b/src/winml/modelkit/session/monitor/_xrt_smi.py @@ -21,7 +21,7 @@ import tempfile from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, cast logger = logging.getLogger(__name__) @@ -109,7 +109,7 @@ def snapshot(self) -> dict[str, Any]: return {} with Path(tmp_path).open(encoding="utf-8") as f: - return json.load(f) + return cast("dict[str, Any]", json.load(f)) except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError) as exc: logger.debug("xrt-smi snapshot failed: %s", exc) diff --git a/src/winml/modelkit/session/qairt/compile_qairt_bin.py b/src/winml/modelkit/session/qairt/compile_qairt_bin.py index 9399361d2..e9225309c 100644 --- a/src/winml/modelkit/session/qairt/compile_qairt_bin.py +++ b/src/winml/modelkit/session/qairt/compile_qairt_bin.py @@ -52,7 +52,7 @@ def extract_input_specs(model_path: Path) -> list[dict]: # This script runs inside the QAIRT SDK's venv; use vanilla onnx.load. model = onnx.load(str(model_path)) - dtype_map = { + dtype_map: dict[int, type] = { onnx.TensorProto.FLOAT: np.float32, onnx.TensorProto.FLOAT16: np.float16, onnx.TensorProto.INT8: np.int8, diff --git a/src/winml/modelkit/session/qairt/qairt_session.py b/src/winml/modelkit/session/qairt/qairt_session.py index 47dfe2b9f..ce6bf60d9 100644 --- a/src/winml/modelkit/session/qairt/qairt_session.py +++ b/src/winml/modelkit/session/qairt/qairt_session.py @@ -11,7 +11,7 @@ import os import subprocess from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from ...utils.python_env import ensure_venv from ..session import SessionState, WinMLSession @@ -197,8 +197,8 @@ def _wrap_bin_to_onnx(self) -> None: qnn_version = qnn_json_obj["info"]["buildId"] for qnn_graph in qnn_json_obj["info"]["graphs"]: - qnn_input_tensor_dic = {} - qnn_output_tensor_dic = {} + qnn_input_tensor_dic: dict[str, Any] = {} + qnn_output_tensor_dic: dict[str, Any] = {} graph_name = gen_qnn_ctx_onnx_model.parse_qnn_graph( qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic ) diff --git a/src/winml/modelkit/session/session.py b/src/winml/modelkit/session/session.py index 9f689084b..2ec224b91 100644 --- a/src/winml/modelkit/session/session.py +++ b/src/winml/modelkit/session/session.py @@ -327,6 +327,16 @@ def run( if self._session is None: self.compile() + # compile() populates self._session or raises; bind a non-None local so + # the narrowing survives into the lambda / comprehension below (mypy drops + # self-attribute narrowing inside nested scopes). + session = self._session + if session is None: + raise InferenceError( + message="Session not available after compile", + context={}, + ) + if self._state == SessionState.ERROR: raise InferenceError( message="Session in error state", @@ -340,16 +350,16 @@ def run( self._validate_inputs(inputs) # Prepare inputs (convert to numpy, enforce dtype) - ort_inputs = self._prepare_inputs(inputs, self._session) + ort_inputs = self._prepare_inputs(inputs, session) # Run inference (with optional perf tracking) - output_names = [o.name for o in self._session.get_outputs()] + output_names = [o.name for o in session.get_outputs()] if self._perf_stats: outputs = self._perf_stats.record( - lambda: self._session.run(output_names, ort_inputs) + lambda: session.run(output_names, ort_inputs) ) else: - outputs = self._session.run(output_names, ort_inputs) + outputs = session.run(output_names, ort_inputs) # Build result dict return dict(zip(output_names, outputs, strict=True)) @@ -396,6 +406,8 @@ def _build_session_options(self, device: str) -> tuple[ort.SessionOptions, str, resolved_device, _ = resolve_device(device, ep=self._ep) resolved_ep = normalize_ep_name(self._ep) if self._ep else resolve_eps(resolved_device)[0] + if resolved_ep is None: + raise ValueError(f"Unknown execution provider: {self._ep!r}") device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper()) opts = self._session_options_factory() diff --git a/src/winml/modelkit/telemetry/click_group.py b/src/winml/modelkit/telemetry/click_group.py index cb391d368..20f14363c 100644 --- a/src/winml/modelkit/telemetry/click_group.py +++ b/src/winml/modelkit/telemetry/click_group.py @@ -39,7 +39,9 @@ class ActionGroup(click.Group): """Click group that auto-instruments every registered command.""" - def resolve_command(self, ctx, args): + def resolve_command( + self, ctx: click.Context, args: list[str] + ) -> tuple[str | None, click.Command | None, list[str]]: """Wrap the resolved subcommand with telemetry instrumentation.""" cmd_name, cmd, remaining = super().resolve_command(ctx, args) if cmd is None: @@ -106,7 +108,8 @@ def wrapped_invoke(ctx: click.Context) -> Any: success=success, ) - cmd.invoke = wrapped_invoke + # Intentional per-instance monkeypatch to wrap this command's invoke(). + cmd.invoke = wrapped_invoke # type: ignore[method-assign] setattr(cmd, _INSTRUMENTED_ATTR, True) return cmd diff --git a/src/winml/modelkit/telemetry/consent.py b/src/winml/modelkit/telemetry/consent.py index 61c6a9343..74cc66800 100644 --- a/src/winml/modelkit/telemetry/consent.py +++ b/src/winml/modelkit/telemetry/consent.py @@ -20,7 +20,7 @@ import sys import tempfile from pathlib import Path -from typing import Literal +from typing import Literal, cast from .utils import _resolve_user_home @@ -122,7 +122,9 @@ def _read_stored_consent() -> Consent | None: and stored_version < _CONSENT_VERSION ): return None - return value # type: ignore[return-value] + # value was validated against ("enabled", "disabled") above, but it comes + # from an untyped JSON dict so mypy still sees it as Any. + return cast("Consent", value) def _write_stored_consent(value: Consent) -> None: @@ -134,7 +136,9 @@ def _write_stored_consent(value: Consent) -> None: if _CONFIG_PATH is None: return data = _load_config() - tele = data.get("telemetry") if isinstance(data.get("telemetry"), dict) else {} + tele = data.get("telemetry") + if not isinstance(tele, dict): + tele = {} tele["consent"] = value tele["consent_version"] = _CONSENT_VERSION data["telemetry"] = tele diff --git a/src/winml/modelkit/telemetry/deviceid/_store.py b/src/winml/modelkit/telemetry/deviceid/_store.py index c75e4d2c5..152a5a719 100644 --- a/src/winml/modelkit/telemetry/deviceid/_store.py +++ b/src/winml/modelkit/telemetry/deviceid/_store.py @@ -12,6 +12,8 @@ from __future__ import annotations +from typing import cast + _REGISTRY_KEY = r"SOFTWARE\Microsoft\DeveloperTools\.modelkit" @@ -27,7 +29,7 @@ def read_key(name: str) -> str | None: return None if value_type != winreg.REG_SZ: return None - return value # already str for REG_SZ + return cast("str", value) # already str for REG_SZ def write_key(name: str, value: str) -> None: diff --git a/src/winml/modelkit/telemetry/library/exporter.py b/src/winml/modelkit/telemetry/library/exporter.py index 7a8600931..1e80757e9 100644 --- a/src/winml/modelkit/telemetry/library/exporter.py +++ b/src/winml/modelkit/telemetry/library/exporter.py @@ -23,6 +23,7 @@ from collections.abc import Sequence from opentelemetry.sdk._logs import ReadableLogRecord + from opentelemetry.sdk.resources import Resource _LOGGER = logging.getLogger(__name__) @@ -226,7 +227,8 @@ def shutdown(self) -> None: def _to_envelope(self, ld: ReadableLogRecord) -> dict: record = ld.log_record - timestamp = _ns_to_datetime(record.timestamp) + # OTel types timestamp as Optional; fall back to now if a record omits it. + timestamp = _ns_to_datetime(record.timestamp or time.time_ns()) data = dict(record.attributes or {}) ext = _resource_to_ext(ld.resource) return _build_envelope( @@ -266,7 +268,7 @@ def _ns_to_datetime(ts_ns: int) -> datetime: return datetime.fromtimestamp(ts_ns / 1_000_000_000, tz=timezone.utc) -def _resource_to_ext(resource) -> dict: +def _resource_to_ext(resource: Resource | None) -> dict: """Translate OpenTelemetry Resource attributes to CS 4.0 ext.* slots. Attribute name → CS slot mapping: diff --git a/src/winml/modelkit/telemetry/telemetry.py b/src/winml/modelkit/telemetry/telemetry.py index ddc74eb8b..ae7bd5cbd 100644 --- a/src/winml/modelkit/telemetry/telemetry.py +++ b/src/winml/modelkit/telemetry/telemetry.py @@ -28,6 +28,10 @@ if TYPE_CHECKING: + from opentelemetry._logs import Logger + from opentelemetry.sdk._logs import LoggerProvider + from opentelemetry.sdk.resources import Resource + from ..utils.constants import EPNameOrAlias @@ -88,8 +92,8 @@ class Telemetry: """ def __init__(self) -> None: - self._logger = None # set when enabled; None when disabled - self._provider = None + self._logger: Logger | None = None # set when enabled; None when disabled + self._provider: LoggerProvider | None = None self._disabled = True # set to False only after successful init self._app_instance_id = str(uuid.uuid4()) # Kept in the event schema for forward-compat: today @@ -147,7 +151,7 @@ def _try_init(self) -> None: self._disabled = True _clear_cache_quietly() - def _build_resource(self): + def _build_resource(self) -> Resource: from opentelemetry.sdk.resources import Resource device_id, id_status = get_or_create_device_id() @@ -219,6 +223,8 @@ def log_error(self, exc: BaseException) -> None: def _emit(self, event_name: str, attrs: dict[str, Any]) -> None: from opentelemetry._logs import LogRecord + if self._logger is None: + return filtered = _filter_allowlist(event_name, attrs) record = LogRecord( timestamp=time.time_ns(), diff --git a/src/winml/modelkit/telemetry/utils.py b/src/winml/modelkit/telemetry/utils.py index 5d79334ca..cbe9412b3 100644 --- a/src/winml/modelkit/telemetry/utils.py +++ b/src/winml/modelkit/telemetry/utils.py @@ -19,7 +19,11 @@ import re import traceback from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from types import TracebackType _PACKAGE_ROOT = "winml/modelkit" @@ -199,7 +203,12 @@ def __enter__(self) -> _ExclusiveFileLock: self._fd = fd return self - def __exit__(self, exc_type, exc, tb) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: import msvcrt if self._fd is None: diff --git a/src/winml/modelkit/utils/constants.py b/src/winml/modelkit/utils/constants.py index 7b45d153a..4e85aaca6 100644 --- a/src/winml/modelkit/utils/constants.py +++ b/src/winml/modelkit/utils/constants.py @@ -7,7 +7,7 @@ from __future__ import annotations import sys -from typing import Literal, TypeAlias, get_args +from typing import Literal, TypeAlias, cast, get_args if sys.platform == "win32": @@ -137,7 +137,9 @@ def normalize_ep_name(ep: EPNameOrAlias | None) -> EPName | None: # the prior membership check narrowed ``ep_lower`` so the alias mapping is # total in this branch. ep_lower = ep.lower() - canonical = EP_ALIASES.get(ep_lower) # type: ignore[arg-type] + # ep_lower is an arbitrary lowercased string; cast to the key type for the + # lookup (.get tolerates non-alias keys, returning None). + canonical = EP_ALIASES.get(cast("EPAlias", ep_lower)) if canonical is not None: return canonical diff --git a/src/winml/modelkit/utils/data_utils.py b/src/winml/modelkit/utils/data_utils.py index bccf7acd1..2903eab8b 100644 --- a/src/winml/modelkit/utils/data_utils.py +++ b/src/winml/modelkit/utils/data_utils.py @@ -1,64 +1,67 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -"""Data utilities for input preparation and padding.""" - -from __future__ import annotations - -from typing import Any, Literal - -import torch - - -def pad_inputs( - source: dict[str, Any], - expected: dict[str, list[int]], - mode: Literal["left", "right"] = "right", -) -> dict[str, Any]: - """Filter *source* to keys in *expected* and pad undersized tensors. - - For each name in *expected*, if *source* has a tensor for it, pad any - dimension smaller than the ONNX expected shape (skips batch dim). - Non-tensor values are passed through. Missing names are skipped. - - Args: - source: Input tensors keyed by name. - expected: ONNX expected shapes keyed by input name. - mode: Padding side — ``"right"`` (default, pad at end) or - ``"left"`` (pad at start). - - Returns: - Filtered and padded tensors matching *expected* keys. - """ - if mode not in ("right", "left"): - raise ValueError(f"mode must be 'right' or 'left', got {mode!r}") - - result: dict[str, Any] = {} - for name, expected_shape in expected.items(): - val = source.get(name) - if val is None: - continue - if isinstance(val, torch.Tensor): - # TODO: support dynamic shape ONNX models (None in expected_shape) - ndim = min(len(val.shape), len(expected_shape)) - # torch.nn.functional.pad takes pairs (low, high) from the LAST - # dim backwards. Skip batch dim (dim 0). - pad: list[int] = [] - for dim in reversed(range(1, ndim)): - exp = expected_shape[dim] - # Dynamic ONNX dims may be None or a string symbol; emit a - # (0, 0) pair so later pairs stay aligned with their dim index. - if not isinstance(exp, int): - pad.extend([0, 0]) - continue - deficit = max(exp - val.shape[dim], 0) - if mode == "right": - pad.extend([0, deficit]) - else: # left - pad.extend([deficit, 0]) - if any(p > 0 for p in pad): - val = torch.nn.functional.pad(val, pad) - result[name] = val - return result +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Data utilities for input preparation and padding.""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch + + +def pad_inputs( + source: dict[str, Any], + expected: dict[str, list[int]], + mode: Literal["left", "right"] = "right", +) -> dict[str, Any]: + """Filter *source* to keys in *expected* and pad undersized tensors. + + For each name in *expected*, if *source* has a tensor for it, pad any + dimension smaller than the ONNX expected shape (skips batch dim). + Non-tensor values are passed through. Missing names are skipped. + + Args: + source: Input tensors keyed by name. + expected: ONNX expected shapes keyed by input name. + mode: Padding side — ``"right"`` (default, pad at end) or + ``"left"`` (pad at start). + + Returns: + Filtered and padded tensors matching *expected* keys. + """ + if mode not in ("right", "left"): + raise ValueError(f"mode must be 'right' or 'left', got {mode!r}") + + result: dict[str, Any] = {} + for name, expected_shape in expected.items(): + val = source.get(name) + if val is None: + continue + if isinstance(val, torch.Tensor): + # TODO: support dynamic shape ONNX models (None in expected_shape) + ndim = min(len(val.shape), len(expected_shape)) + # torch.nn.functional.pad takes pairs (low, high) from the LAST + # dim backwards. Skip batch dim (dim 0). + pad: list[int] = [] + for dim in reversed(range(1, ndim)): + exp = expected_shape[dim] + # Dynamic ONNX dims may be None or a string symbol; emit a + # (0, 0) pair so later pairs stay aligned with their dim index. + if not isinstance(exp, int): + # Forward-looking: expected is typed list[int] today, but ONNX + # dynamic dims (None / str symbol) are a planned input (see TODO + # above). Keep the guard until the signature widens for them. + pad.extend([0, 0]) # type: ignore[unreachable] + continue + deficit = max(exp - val.shape[dim], 0) + if mode == "right": + pad.extend([0, deficit]) + else: # left + pad.extend([deficit, 0]) + if any(p > 0 for p in pad): + val = torch.nn.functional.pad(val, pad) + result[name] = val + return result diff --git a/src/winml/modelkit/utils/hub_utils.py b/src/winml/modelkit/utils/hub_utils.py index 1de403c8d..cfb7cc279 100644 --- a/src/winml/modelkit/utils/hub_utils.py +++ b/src/winml/modelkit/utils/hub_utils.py @@ -1,344 +1,348 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""HuggingFace Hub utilities for model detection and configuration loading. - -This module provides intelligent detection of HuggingFace Hub models vs local models, -and handles the appropriate metadata storage and configuration loading strategies. -""" - -import logging -import re -from pathlib import Path -from typing import Any - - -logger = logging.getLogger(__name__) - - -def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: - """Comprehensive Hub model detection with metadata extraction. - - Args: - model_name_or_path: Model identifier or path - - Returns: - Tuple of (is_hub_model, metadata_dict) - """ - # Quick rejection for obvious local paths - if Path(model_name_or_path).exists(): - return False, {"type": "local", "path": model_name_or_path} - - # Check for local path indicators - if any(model_name_or_path.startswith(prefix) for prefix in ["./", "../", "/", "~/"]): - return False, {"type": "local", "path": model_name_or_path} - - # Check for Windows absolute paths - if re.match(r"^[A-Za-z]:[\\/]", model_name_or_path): - return False, {"type": "local", "path": model_name_or_path} - - # Parse potential Hub model format - # Supports: model-name, org/model, org/model@revision - hub_pattern = r"^(?:([^/@]+)/)?([^/@]+)(?:@(.+))?$" - match = re.match(hub_pattern, model_name_or_path) - - if not match: - return False, {"type": "invalid"} - - org, model, revision = match.groups() - full_model_id = f"{org}/{model}" if org else model - - # Try to verify with Hub API - try: - from huggingface_hub import HfApi - - api = HfApi() - model_info = api.model_info(full_model_id, revision=revision) - - # Extract comprehensive metadata - metadata = { - "type": "hub", - "model_id": model_info.modelId, - "sha": model_info.sha, - "revision": revision or "main", - "tags": model_info.tags if hasattr(model_info, "tags") else [], - "pipeline_tag": model_info.pipeline_tag - if hasattr(model_info, "pipeline_tag") - else None, - "library_name": model_info.library_name - if hasattr(model_info, "library_name") - else None, - "author": model_info.author if hasattr(model_info, "author") else None, - "last_modified": str(model_info.lastModified) - if hasattr(model_info, "lastModified") - else None, - "private": model_info.private if hasattr(model_info, "private") else False, - "gated": model_info.gated if hasattr(model_info, "gated") else False, - } - - # Try to get model card info if available - try: - from huggingface_hub import ModelCard - - card = ModelCard.load(full_model_id) - if hasattr(card.data, "base_model"): - metadata["base_model"] = card.data.base_model - if hasattr(card.data, "license"): - metadata["license"] = card.data.license - if hasattr(card.data, "language"): - metadata["language"] = card.data.language - if hasattr(card.data, "task_categories"): - metadata["task_categories"] = card.data.task_categories - except Exception: - pass - - return True, metadata - - except Exception as e: - # Could not verify with Hub - might be private or offline - # Use heuristics to guess - if len(model_name_or_path.split("/")) <= 2 and "\\" not in model_name_or_path: - return True, { - "type": "hub_unverified", - "model_id": full_model_id, - "revision": revision or "main", - "error": str(e), - } - return False, {"type": "local", "path": model_name_or_path} - - -def inject_hub_metadata(onnx_model: Any, model_name_or_path: str, metadata: dict) -> None: - """Inject HuggingFace Hub metadata into ONNX model. - - Args: - onnx_model: ONNX model proto - model_name_or_path: Original model identifier - metadata: Hub metadata dictionary - """ - from datetime import datetime, timezone - - # Clear any existing HF metadata - # We need to remove items by filtering, not reassigning - hf_props_to_remove = [] - for i, prop in enumerate(onnx_model.metadata_props): - if prop.key.startswith("hf_"): - hf_props_to_remove.append(i) - - # Remove in reverse order to maintain indices - for i in reversed(hf_props_to_remove): - del onnx_model.metadata_props[i] - - # Add required metadata - def add_prop(key: str, value: Any) -> None: - if value is not None: - prop = onnx_model.metadata_props.add() - prop.key = key - prop.value = str(value) - - # Required fields - add_prop("hf_hub_id", metadata.get("model_id")) - add_prop("hf_hub_revision", metadata.get("sha", "")[:8]) - add_prop("hf_model_type", "hub") - - # Get ModelExport version - try: - from ..version import __version__ - - export_version = __version__ - except ImportError: - export_version = "unknown" - - add_prop("hf_export_version", export_version) - add_prop("hf_export_timestamp", datetime.now(timezone.utc).isoformat()) - - # Optional fields - for key in ["pipeline_tag", "library_name", "base_model", "private", "gated"]: - if key in metadata: - add_prop(f"hf_{key}", metadata[key]) - - # Producer information - onnx_model.producer_name = "ModelExport-HTP" - onnx_model.producer_version = export_version - onnx_model.domain = "com.modelexport.htp" - - # Add doc string for human readability - onnx_model.doc_string = ( - f"Exported from HuggingFace model: {metadata.get('model_id')}\n" - f"Revision: {metadata.get('sha', 'unknown')[:8]}\n" - f"Export timestamp: {datetime.now(timezone.utc).isoformat()}\n" - f"ModelExport version: {export_version}" - ) - - -def save_local_model_configs(model_name_or_path: str, output_dir: Path, metadata: dict) -> None: - """Save configuration files for local/in-house models. - - Args: - model_name_or_path: Path to local model - output_dir: Directory to save configs - metadata: Local model metadata - """ - # Check if the path exists first - if not Path(model_name_or_path).exists(): - logger.info(f"Local model path {model_name_or_path} does not exist, skipping config copy") - return - - try: - from transformers import AutoConfig - - # Save config - config = AutoConfig.from_pretrained(model_name_or_path) - config.save_pretrained(output_dir) - logger.info(f"Saved config.json to {output_dir}") - - # Track what components were saved - components_saved = [] - - # Try AutoProcessor (for multimodal) - try: - from transformers import AutoProcessor - - processor = AutoProcessor.from_pretrained(model_name_or_path) - processor.save_pretrained(output_dir) - components_saved.append("processor") - except Exception: - pass - - # Try AutoTokenizer (for text models) - only if processor wasn't saved - if "processor" not in components_saved: - try: - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - tokenizer.save_pretrained(output_dir) - components_saved.append("tokenizer") - except Exception: - pass - - # Try AutoImageProcessor (for vision) - try: - from transformers import AutoImageProcessor - - image_processor = AutoImageProcessor.from_pretrained(model_name_or_path) - image_processor.save_pretrained(output_dir) - components_saved.append("image_processor") - except Exception: - pass - - # Try AutoFeatureExtractor (for audio) - try: - from transformers import AutoFeatureExtractor - - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path) - feature_extractor.save_pretrained(output_dir) - components_saved.append("feature_extractor") - except Exception: - pass - - if components_saved: - logger.info(f"Saved preprocessing components: {', '.join(components_saved)}") - - except Exception as e: - logger.warning(f"Could not save config for local model: {e}") - logger.warning("User will need to provide config manually for inference") - - -def load_hf_components_from_onnx(onnx_path: str) -> tuple[Any, Any]: - """Load HuggingFace config and preprocessing components from ONNX. - - Handles both: - 1. Hub models - loads from HF Hub using metadata - 2. Local models - loads from co-located config files - - Args: - onnx_path: Path to ONNX model - - Returns: - Tuple of (config, preprocessor) - """ - from pathlib import Path - - from transformers import ( - AutoConfig, - AutoFeatureExtractor, - AutoImageProcessor, - AutoProcessor, - AutoTokenizer, - ) - - from ..onnx import load_onnx - - # Load ONNX model and extract metadata - onnx_model = load_onnx(onnx_path, validate=False) - onnx_dir = Path(onnx_path).parent - - # Extract metadata - metadata = {} - for prop in onnx_model.metadata_props: - metadata[prop.key] = prop.value - - model_type = metadata.get("hf_model_type", "unknown") - - if model_type == "hub": - # Hub model: Load from HuggingFace Hub - hf_hub_id = metadata.get("hf_hub_id") - hf_revision = metadata.get("hf_hub_revision") - - if not hf_hub_id: - raise ValueError("ONNX model marked as Hub model but missing hf_hub_id metadata") - - # Load config from Hub - config = AutoConfig.from_pretrained(hf_hub_id, revision=hf_revision) - - # Try to load preprocessor from Hub - preprocessor = None - for loader_cls in [ - AutoProcessor, - AutoTokenizer, - AutoImageProcessor, - AutoFeatureExtractor, - ]: - try: - preprocessor = loader_cls.from_pretrained(hf_hub_id, revision=hf_revision) - break - except Exception: - continue - - return config, preprocessor - - if model_type == "local": - # Local model: Load from co-located files - config_path = onnx_dir / "config.json" - - if not config_path.exists(): - raise ValueError( - f"Local model but config.json not found at {config_path}. " - "The model may have been moved without its config files." - ) - - # Load config from local file - config = AutoConfig.from_pretrained(onnx_dir) - - # Try to load preprocessor from local files - preprocessor = None - for loader_cls in [ - AutoProcessor, - AutoTokenizer, - AutoImageProcessor, - AutoFeatureExtractor, - ]: - try: - preprocessor = loader_cls.from_pretrained(onnx_dir) - break - except Exception: - continue - - return config, preprocessor - - # Unknown or legacy model - raise ValueError( - f"ONNX model has unknown type '{model_type}'. " - "Was it exported with an older version of ModelExport? " - "Please re-export the model." - ) +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""HuggingFace Hub utilities for model detection and configuration loading. + +This module provides intelligent detection of HuggingFace Hub models vs local models, +and handles the appropriate metadata storage and configuration loading strategies. +""" + +import logging +import re +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + + +def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: + """Comprehensive Hub model detection with metadata extraction. + + Args: + model_name_or_path: Model identifier or path + + Returns: + Tuple of (is_hub_model, metadata_dict) + """ + # Quick rejection for obvious local paths + if Path(model_name_or_path).exists(): + return False, {"type": "local", "path": model_name_or_path} + + # Check for local path indicators + if any(model_name_or_path.startswith(prefix) for prefix in ["./", "../", "/", "~/"]): + return False, {"type": "local", "path": model_name_or_path} + + # Check for Windows absolute paths + if re.match(r"^[A-Za-z]:[\\/]", model_name_or_path): + return False, {"type": "local", "path": model_name_or_path} + + # Parse potential Hub model format + # Supports: model-name, org/model, org/model@revision + hub_pattern = r"^(?:([^/@]+)/)?([^/@]+)(?:@(.+))?$" + match = re.match(hub_pattern, model_name_or_path) + + if not match: + return False, {"type": "invalid"} + + org, model, revision = match.groups() + full_model_id = f"{org}/{model}" if org else model + + # Try to verify with Hub API + try: + from huggingface_hub import HfApi + + api = HfApi() + model_info = api.model_info(full_model_id, revision=revision) + + # Extract comprehensive metadata + metadata = { + "type": "hub", + "model_id": model_info.id, + "sha": model_info.sha, + "revision": revision or "main", + "tags": model_info.tags if hasattr(model_info, "tags") else [], + "pipeline_tag": model_info.pipeline_tag + if hasattr(model_info, "pipeline_tag") + else None, + "library_name": model_info.library_name + if hasattr(model_info, "library_name") + else None, + "author": model_info.author if hasattr(model_info, "author") else None, + "last_modified": str(model_info.lastModified) + if hasattr(model_info, "lastModified") + else None, + "private": model_info.private if hasattr(model_info, "private") else False, + "gated": model_info.gated if hasattr(model_info, "gated") else False, + } + + # Try to get model card info if available + try: + from huggingface_hub import ModelCard + + card = ModelCard.load(full_model_id) + if hasattr(card.data, "base_model"): + metadata["base_model"] = card.data.base_model + if hasattr(card.data, "license"): + metadata["license"] = card.data.license + if hasattr(card.data, "language"): + metadata["language"] = card.data.language + if hasattr(card.data, "task_categories"): + metadata["task_categories"] = card.data.task_categories + except Exception: + pass + + return True, metadata + + except Exception as e: + # Could not verify with Hub - might be private or offline + # Use heuristics to guess + if len(model_name_or_path.split("/")) <= 2 and "\\" not in model_name_or_path: + return True, { + "type": "hub_unverified", + "model_id": full_model_id, + "revision": revision or "main", + "error": str(e), + } + return False, {"type": "local", "path": model_name_or_path} + + +def inject_hub_metadata(onnx_model: Any, model_name_or_path: str, metadata: dict) -> None: + """Inject HuggingFace Hub metadata into ONNX model. + + Args: + onnx_model: ONNX model proto + model_name_or_path: Original model identifier + metadata: Hub metadata dictionary + """ + from datetime import datetime, timezone + + # Clear any existing HF metadata + # We need to remove items by filtering, not reassigning + hf_props_to_remove = [] + for i, prop in enumerate(onnx_model.metadata_props): + if prop.key.startswith("hf_"): + hf_props_to_remove.append(i) + + # Remove in reverse order to maintain indices + for i in reversed(hf_props_to_remove): + del onnx_model.metadata_props[i] + + # Add required metadata + def add_prop(key: str, value: Any) -> None: + if value is not None: + prop = onnx_model.metadata_props.add() + prop.key = key + prop.value = str(value) + + # Required fields + add_prop("hf_hub_id", metadata.get("model_id")) + add_prop("hf_hub_revision", metadata.get("sha", "")[:8]) + add_prop("hf_model_type", "hub") + + # Get ModelExport version + try: + from .. import __version__ + + export_version = __version__ + except ImportError: + export_version = "unknown" + + add_prop("hf_export_version", export_version) + add_prop("hf_export_timestamp", datetime.now(timezone.utc).isoformat()) + + # Optional fields + for key in ["pipeline_tag", "library_name", "base_model", "private", "gated"]: + if key in metadata: + add_prop(f"hf_{key}", metadata[key]) + + # Producer information + onnx_model.producer_name = "ModelExport-HTP" + onnx_model.producer_version = export_version + onnx_model.domain = "com.modelexport.htp" + + # Add doc string for human readability + onnx_model.doc_string = ( + f"Exported from HuggingFace model: {metadata.get('model_id')}\n" + f"Revision: {metadata.get('sha', 'unknown')[:8]}\n" + f"Export timestamp: {datetime.now(timezone.utc).isoformat()}\n" + f"ModelExport version: {export_version}" + ) + + +def save_local_model_configs(model_name_or_path: str, output_dir: Path, metadata: dict) -> None: + """Save configuration files for local/in-house models. + + Args: + model_name_or_path: Path to local model + output_dir: Directory to save configs + metadata: Local model metadata + """ + # Check if the path exists first + if not Path(model_name_or_path).exists(): + logger.info(f"Local model path {model_name_or_path} does not exist, skipping config copy") + return + + try: + from transformers import AutoConfig + + # Save config + config = AutoConfig.from_pretrained(model_name_or_path) + config.save_pretrained(output_dir) + logger.info(f"Saved config.json to {output_dir}") + + # Track what components were saved + components_saved = [] + + # Try AutoProcessor (for multimodal) + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_name_or_path) + processor.save_pretrained(output_dir) + components_saved.append("processor") + except Exception: + pass + + # Try AutoTokenizer (for text models) - only if processor wasn't saved + if "processor" not in components_saved: + try: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(output_dir) + components_saved.append("tokenizer") + except Exception: + pass + + # Try AutoImageProcessor (for vision) + try: + from transformers import AutoImageProcessor + + image_processor = AutoImageProcessor.from_pretrained(model_name_or_path) + image_processor.save_pretrained(output_dir) + components_saved.append("image_processor") + except Exception: + pass + + # Try AutoFeatureExtractor (for audio) + try: + from transformers import AutoFeatureExtractor + + feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path) + feature_extractor.save_pretrained(output_dir) + components_saved.append("feature_extractor") + except Exception: + pass + + if components_saved: + logger.info(f"Saved preprocessing components: {', '.join(components_saved)}") + + except Exception as e: + logger.warning(f"Could not save config for local model: {e}") + logger.warning("User will need to provide config manually for inference") + + +def load_hf_components_from_onnx(onnx_path: str) -> tuple[Any, Any]: + """Load HuggingFace config and preprocessing components from ONNX. + + Handles both: + 1. Hub models - loads from HF Hub using metadata + 2. Local models - loads from co-located config files + + Args: + onnx_path: Path to ONNX model + + Returns: + Tuple of (config, preprocessor) + """ + from pathlib import Path + + from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoImageProcessor, + AutoProcessor, + AutoTokenizer, + ) + + from ..onnx import load_onnx + + # Load ONNX model and extract metadata + onnx_model = load_onnx(onnx_path, validate=False) + onnx_dir = Path(onnx_path).parent + + # Extract metadata + metadata = {} + for prop in onnx_model.metadata_props: + metadata[prop.key] = prop.value + + model_type = metadata.get("hf_model_type", "unknown") + + if model_type == "hub": + # Hub model: Load from HuggingFace Hub + hf_hub_id = metadata.get("hf_hub_id") + hf_revision = metadata.get("hf_hub_revision") + + if not hf_hub_id: + raise ValueError("ONNX model marked as Hub model but missing hf_hub_id metadata") + + # Load config from Hub + config = AutoConfig.from_pretrained(hf_hub_id, revision=hf_revision) + + # Try to load preprocessor from Hub + preprocessor = None + # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. + hub_loaders: list[Any] = [ + AutoProcessor, + AutoTokenizer, + AutoImageProcessor, + AutoFeatureExtractor, + ] + for loader_cls in hub_loaders: + try: + preprocessor = loader_cls.from_pretrained(hf_hub_id, revision=hf_revision) + break + except Exception: + continue + + return config, preprocessor + + if model_type == "local": + # Local model: Load from co-located files + config_path = onnx_dir / "config.json" + + if not config_path.exists(): + raise ValueError( + f"Local model but config.json not found at {config_path}. " + "The model may have been moved without its config files." + ) + + # Load config from local file + config = AutoConfig.from_pretrained(onnx_dir) + + # Try to load preprocessor from local files + preprocessor = None + # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. + local_loaders: list[Any] = [ + AutoProcessor, + AutoTokenizer, + AutoImageProcessor, + AutoFeatureExtractor, + ] + for loader_cls in local_loaders: + try: + preprocessor = loader_cls.from_pretrained(onnx_dir) + break + except Exception: + continue + + return config, preprocessor + + # Unknown or legacy model + raise ValueError( + f"ONNX model has unknown type '{model_type}'. " + "Was it exported with an older version of ModelExport? " + "Please re-export the model." + ) diff --git a/src/winml/modelkit/utils/native_stderr.py b/src/winml/modelkit/utils/native_stderr.py index ed6b8a0bb..87d3aea02 100644 --- a/src/winml/modelkit/utils/native_stderr.py +++ b/src/winml/modelkit/utils/native_stderr.py @@ -21,6 +21,10 @@ import re import sys from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator logger = logging.getLogger(__name__) @@ -47,7 +51,7 @@ @contextmanager -def suppress_native_stderr(): +def suppress_native_stderr() -> Iterator[None]: """Redirect native stderr to devnull. No-op on non-Windows.""" if sys.platform != "win32": yield @@ -68,7 +72,7 @@ def suppress_native_stderr(): @contextmanager -def capture_native_stderr(level: int = logging.INFO): +def capture_native_stderr(level: int = logging.INFO) -> Iterator[None]: """Capture native stderr via pipe and re-emit through Python logging. No-op on non-Windows. diff --git a/src/winml/modelkit/utils/optimum_loader.py b/src/winml/modelkit/utils/optimum_loader.py index c7a3a5b1d..90f74c923 100644 --- a/src/winml/modelkit/utils/optimum_loader.py +++ b/src/winml/modelkit/utils/optimum_loader.py @@ -1,160 +1,160 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Optimum integration utilities for loading ONNX models with HuggingFace configurations. - -This module provides seamless integration with HuggingFace Optimum for inference -using exported ONNX models with preserved hierarchy and Hub metadata. -""" - -import logging -import shutil -import tempfile -from pathlib import Path -from typing import Any - -from .hub_utils import load_hf_components_from_onnx - - -logger = logging.getLogger(__name__) - - -class OptimumONNXModel: - """Wrapper for seamless Optimum integration with Hub metadata.""" - - @classmethod - def from_onnx( - cls, onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any - ) -> tuple[Any, Any]: - """Load Optimum model from ONNX with Hub metadata. - - Args: - onnx_path: Path to ONNX model - task: Task type or "auto" to detect - device: Device to run on - **kwargs: Additional arguments for ORTModel - - Returns: - Tuple of (model, preprocessor) - """ - # Load config and preprocessor - config, preprocessor = load_hf_components_from_onnx(onnx_path) - - # Auto-detect task if needed - if task == "auto": - task = cls._detect_task(config, onnx_path) - - # Get appropriate ORTModel class - ort_model_class = cls._get_ort_model_class(task) - - # Create temporary directory with required files - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Save config - config.save_pretrained(temp_path) - - # Save preprocessor if available - if preprocessor: - preprocessor.save_pretrained(temp_path) - - # Copy ONNX model - shutil.copy(onnx_path, temp_path / "model.onnx") - - # Load with Optimum - model = ort_model_class.from_pretrained( - temp_path, - provider="CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider", - **kwargs, - ) - - return model, preprocessor - - @staticmethod - def _detect_task(config: Any, onnx_path: str) -> str: - """Detect task from config and metadata.""" - from ..onnx import load_onnx - - # Try to get task from metadata - try: - onnx_model = load_onnx(onnx_path, validate=False) - for prop in onnx_model.metadata_props: - if prop.key == "hf_pipeline_tag": - return prop.value - except Exception: - pass - - # Check architectures - if hasattr(config, "architectures"): - arch = config.architectures[0] if config.architectures else "" - - task_mapping = { - "ForSequenceClassification": "text-classification", - "ForTokenClassification": "token-classification", - "ForQuestionAnswering": "question-answering", - "ForCausalLM": "text-generation", - "ForConditionalGeneration": "text2text-generation", - "ForImageClassification": "image-classification", - "ForObjectDetection": "object-detection", - "ForAudioClassification": "audio-classification", - } - - for pattern, task in task_mapping.items(): - if pattern in arch: - return task - - # Default - return "feature-extraction" - - @staticmethod - def _get_ort_model_class(task: str) -> type[Any]: - """Get appropriate ORTModel class for task.""" - from optimum.onnxruntime import ( - ORTModel, - ORTModelForAudioClassification, - ORTModelForCausalLM, - ORTModelForFeatureExtraction, - ORTModelForImageClassification, - ORTModelForQuestionAnswering, - ORTModelForSeq2SeqLM, - ORTModelForSequenceClassification, - ORTModelForTokenClassification, - ) - - task_to_model = { - "text-classification": ORTModelForSequenceClassification, - "token-classification": ORTModelForTokenClassification, - "question-answering": ORTModelForQuestionAnswering, - "text-generation": ORTModelForCausalLM, - "text2text-generation": ORTModelForSeq2SeqLM, - "translation": ORTModelForSeq2SeqLM, - "summarization": ORTModelForSeq2SeqLM, - "image-classification": ORTModelForImageClassification, - "audio-classification": ORTModelForAudioClassification, - "feature-extraction": ORTModelForFeatureExtraction, - } - - return task_to_model.get(task, ORTModel) - - -def load_optimum_model( - onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any -) -> tuple[Any, Any]: - """Convenience function to load an ONNX model for Optimum inference. - - Args: - onnx_path: Path to ONNX model exported with ModelExport - task: Task type (auto-detected if not specified) - device: Device to run on ('cpu' or 'cuda') - **kwargs: Additional arguments for ORTModel - - Returns: - Tuple of (model, preprocessor) - - Example: - >>> model, tokenizer = load_optimum_model("bert.onnx") - >>> inputs = tokenizer("Hello world!", return_tensors="pt") - >>> outputs = model(**inputs) - """ - return OptimumONNXModel.from_onnx(onnx_path, task, device, **kwargs) +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Optimum integration utilities for loading ONNX models with HuggingFace configurations. + +This module provides seamless integration with HuggingFace Optimum for inference +using exported ONNX models with preserved hierarchy and Hub metadata. +""" + +import logging +import shutil +import tempfile +from pathlib import Path +from typing import Any, cast + +from .hub_utils import load_hf_components_from_onnx + + +logger = logging.getLogger(__name__) + + +class OptimumONNXModel: + """Wrapper for seamless Optimum integration with Hub metadata.""" + + @classmethod + def from_onnx( + cls, onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any + ) -> tuple[Any, Any]: + """Load Optimum model from ONNX with Hub metadata. + + Args: + onnx_path: Path to ONNX model + task: Task type or "auto" to detect + device: Device to run on + **kwargs: Additional arguments for ORTModel + + Returns: + Tuple of (model, preprocessor) + """ + # Load config and preprocessor + config, preprocessor = load_hf_components_from_onnx(onnx_path) + + # Auto-detect task if needed + if task == "auto": + task = cls._detect_task(config, onnx_path) + + # Get appropriate ORTModel class + ort_model_class = cls._get_ort_model_class(task) + + # Create temporary directory with required files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Save config + config.save_pretrained(temp_path) + + # Save preprocessor if available + if preprocessor: + preprocessor.save_pretrained(temp_path) + + # Copy ONNX model + shutil.copy(onnx_path, temp_path / "model.onnx") + + # Load with Optimum + model = ort_model_class.from_pretrained( + temp_path, + provider="CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider", + **kwargs, + ) + + return model, preprocessor + + @staticmethod + def _detect_task(config: Any, onnx_path: str) -> str: + """Detect task from config and metadata.""" + from ..onnx import load_onnx + + # Try to get task from metadata + try: + onnx_model = load_onnx(onnx_path, validate=False) + for prop in onnx_model.metadata_props: + if prop.key == "hf_pipeline_tag": + return prop.value + except Exception: + pass + + # Check architectures + if hasattr(config, "architectures"): + arch = config.architectures[0] if config.architectures else "" + + task_mapping = { + "ForSequenceClassification": "text-classification", + "ForTokenClassification": "token-classification", + "ForQuestionAnswering": "question-answering", + "ForCausalLM": "text-generation", + "ForConditionalGeneration": "text2text-generation", + "ForImageClassification": "image-classification", + "ForObjectDetection": "object-detection", + "ForAudioClassification": "audio-classification", + } + + for pattern, task in task_mapping.items(): + if pattern in arch: + return task + + # Default + return "feature-extraction" + + @staticmethod + def _get_ort_model_class(task: str) -> type[Any]: + """Get appropriate ORTModel class for task.""" + from optimum.onnxruntime import ( + ORTModel, + ORTModelForAudioClassification, + ORTModelForCausalLM, + ORTModelForFeatureExtraction, + ORTModelForImageClassification, + ORTModelForQuestionAnswering, + ORTModelForSeq2SeqLM, + ORTModelForSequenceClassification, + ORTModelForTokenClassification, + ) + + task_to_model = { + "text-classification": ORTModelForSequenceClassification, + "token-classification": ORTModelForTokenClassification, + "question-answering": ORTModelForQuestionAnswering, + "text-generation": ORTModelForCausalLM, + "text2text-generation": ORTModelForSeq2SeqLM, + "translation": ORTModelForSeq2SeqLM, + "summarization": ORTModelForSeq2SeqLM, + "image-classification": ORTModelForImageClassification, + "audio-classification": ORTModelForAudioClassification, + "feature-extraction": ORTModelForFeatureExtraction, + } + + return cast("type[Any]", task_to_model.get(task, ORTModel)) + + +def load_optimum_model( + onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any +) -> tuple[Any, Any]: + """Convenience function to load an ONNX model for Optimum inference. + + Args: + onnx_path: Path to ONNX model exported with ModelExport + task: Task type (auto-detected if not specified) + device: Device to run on ('cpu' or 'cuda') + **kwargs: Additional arguments for ORTModel + + Returns: + Tuple of (model, preprocessor) + + Example: + >>> model, tokenizer = load_optimum_model("bert.onnx") + >>> inputs = tokenizer("Hello world!", return_tensors="pt") + >>> outputs = model(**inputs) + """ + return OptimumONNXModel.from_onnx(onnx_path, task, device, **kwargs) From c1cc3502a2c3451b9ed8811d192d4bf85633a3a8 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Thu, 18 Jun 2026 14:16:29 +0800 Subject: [PATCH 3/3] more --- .github/workflows/lint.yml | 3 +- src/winml/modelkit/utils/config_utils.py | 29 +- src/winml/modelkit/utils/console.py | 14 +- src/winml/modelkit/utils/data_utils.py | 134 ++-- src/winml/modelkit/utils/hub_utils.py | 696 ++++++++++----------- src/winml/modelkit/utils/native_stderr.py | 1 + src/winml/modelkit/utils/optimum_loader.py | 320 +++++----- 7 files changed, 603 insertions(+), 594 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6089eef89..dd3444bc9 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ concurrency: jobs: lint: runs-on: windows-latest - # Bumped from 5: combined mypy on 22 packages cold-starts at ~3-4 min on + # Bumped from 5: combined mypy on 23 packages cold-starts at ~3-4 min on # Windows runners; the original 5-min ceiling cancelled mid-run. timeout-minutes: 10 @@ -70,3 +70,4 @@ jobs: -p winml.modelkit.session -p winml.modelkit.sysinfo -p winml.modelkit.telemetry + -p winml.modelkit.utils diff --git a/src/winml/modelkit/utils/config_utils.py b/src/winml/modelkit/utils/config_utils.py index 911e20468..26e514fe0 100644 --- a/src/winml/modelkit/utils/config_utils.py +++ b/src/winml/modelkit/utils/config_utils.py @@ -11,7 +11,11 @@ import dataclasses import typing -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast + + +if TYPE_CHECKING: + from _typeshed import DataclassInstance T = TypeVar("T") @@ -45,17 +49,18 @@ def merge_config(base: T, overrides: dict[str, Any] | T | None) -> T: if overrides is None: return base - # Convert overrides to dict if it's a config object + # Convert overrides to a plain dict if it's a config object. + overrides_dict: dict[str, Any] if hasattr(overrides, "to_dict"): - overrides = overrides.to_dict() + overrides_dict = overrides.to_dict() elif dataclasses.is_dataclass(overrides) and not isinstance(overrides, type): - overrides = dataclasses.asdict(overrides) + overrides_dict = dataclasses.asdict(overrides) elif isinstance(overrides, dict): - overrides = dict(overrides) # Copy to avoid mutation + overrides_dict = dict(overrides) # Copy to avoid mutation else: raise TypeError(f"overrides must be dict or config, got {type(overrides)}") - return _merge_into(base, overrides) + return _merge_into(base, overrides_dict) def _merge_into(base: T, overrides: dict[str, Any]) -> T: @@ -66,9 +71,9 @@ def _merge_into(base: T, overrides: dict[str, Any]) -> T: if isinstance(base, dict): # Handle dict-like config (e.g., WinMLOptimizationConfig) - result = type(base)(**base) # type: ignore[call-arg] + result = type(base)(**base) result.update(overrides) - return result # type: ignore[return-value] + return result # Primitive or unknown type - just return override return base @@ -78,7 +83,9 @@ def _merge_dataclass(base: T, overrides: dict[str, Any]) -> T: """Merge overrides into a dataclass, handling nested configs.""" # Get current field values current = {} - for f in dataclasses.fields(base): + # base is a dataclass instance by precondition (caller checks is_dataclass); + # cast for the fields() reflection call without widening base's T elsewhere. + for f in dataclasses.fields(cast("DataclassInstance", base)): current[f.name] = getattr(base, f.name) # Apply overrides recursively @@ -117,7 +124,7 @@ def _merge_dataclass(base: T, overrides: dict[str, Any]) -> T: current[key] = value # Create new instance - return type(base)(**current) # type: ignore[return-value] + return type(base)(**current) def _get_field_type(obj: Any, field_name: str) -> type | None: @@ -151,5 +158,5 @@ def _get_field_type(obj: Any, field_name: str) -> type | None: # It's an Optional/Union with None — extract the non-None type for arg in args: if arg is not type(None): - return arg # type: ignore[return-value] + return cast("type", arg) return resolved if isinstance(resolved, type) else None diff --git a/src/winml/modelkit/utils/console.py b/src/winml/modelkit/utils/console.py index a54194411..e83293d9b 100644 --- a/src/winml/modelkit/utils/console.py +++ b/src/winml/modelkit/utils/console.py @@ -117,12 +117,12 @@ def print_io_specs_detail( for i, t in enumerate(inputs): name = t.name or "(unnamed)" - shape_str = str(list(t.shape)) if getattr(t, "shape", None) else "dynamic" + shape_str = str(list(t.shape)) if t.shape else "dynamic" dtype_str = getattr(t, "dtype", None) or "?" label = "Input: " if i == 0 else " " console.print(f" {label}[cyan]{name:<18}[/cyan] {shape_str:<14} [dim]{dtype_str}[/dim]") - for i, t in enumerate(outputs): - name = t.name or "(unnamed)" + for i, out_t in enumerate(outputs): + name = out_t.name or "(unnamed)" # Fix #3: OutputTensorSpec only has name — show name only label = "Output: " if i == 0 else " " console.print(f" {label}[cyan]{name}[/cyan]") @@ -307,7 +307,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._refresh_disabled: bool = False - def refresh(self) -> None: # type: ignore[override] + def refresh(self) -> None: if self._refresh_disabled: return try: @@ -320,15 +320,15 @@ def refresh(self) -> None: # type: ignore[override] logger.debug("Disabling Live refresh after OSError", exc_info=True) @_swallow_oserror - def start(self, refresh: bool = False) -> None: # type: ignore[override] + def start(self, refresh: bool = False) -> None: super().start(refresh=refresh) @_swallow_oserror - def stop(self) -> None: # type: ignore[override] + def stop(self) -> None: super().stop() @_swallow_oserror - def update(self, renderable: RenderableType, *, refresh: bool = False) -> None: # type: ignore[override] + def update(self, renderable: RenderableType, *, refresh: bool = False) -> None: super().update(renderable, refresh=refresh) diff --git a/src/winml/modelkit/utils/data_utils.py b/src/winml/modelkit/utils/data_utils.py index 2903eab8b..364ab1b92 100644 --- a/src/winml/modelkit/utils/data_utils.py +++ b/src/winml/modelkit/utils/data_utils.py @@ -1,67 +1,67 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -"""Data utilities for input preparation and padding.""" - -from __future__ import annotations - -from typing import Any, Literal - -import torch - - -def pad_inputs( - source: dict[str, Any], - expected: dict[str, list[int]], - mode: Literal["left", "right"] = "right", -) -> dict[str, Any]: - """Filter *source* to keys in *expected* and pad undersized tensors. - - For each name in *expected*, if *source* has a tensor for it, pad any - dimension smaller than the ONNX expected shape (skips batch dim). - Non-tensor values are passed through. Missing names are skipped. - - Args: - source: Input tensors keyed by name. - expected: ONNX expected shapes keyed by input name. - mode: Padding side — ``"right"`` (default, pad at end) or - ``"left"`` (pad at start). - - Returns: - Filtered and padded tensors matching *expected* keys. - """ - if mode not in ("right", "left"): - raise ValueError(f"mode must be 'right' or 'left', got {mode!r}") - - result: dict[str, Any] = {} - for name, expected_shape in expected.items(): - val = source.get(name) - if val is None: - continue - if isinstance(val, torch.Tensor): - # TODO: support dynamic shape ONNX models (None in expected_shape) - ndim = min(len(val.shape), len(expected_shape)) - # torch.nn.functional.pad takes pairs (low, high) from the LAST - # dim backwards. Skip batch dim (dim 0). - pad: list[int] = [] - for dim in reversed(range(1, ndim)): - exp = expected_shape[dim] - # Dynamic ONNX dims may be None or a string symbol; emit a - # (0, 0) pair so later pairs stay aligned with their dim index. - if not isinstance(exp, int): - # Forward-looking: expected is typed list[int] today, but ONNX - # dynamic dims (None / str symbol) are a planned input (see TODO - # above). Keep the guard until the signature widens for them. - pad.extend([0, 0]) # type: ignore[unreachable] - continue - deficit = max(exp - val.shape[dim], 0) - if mode == "right": - pad.extend([0, deficit]) - else: # left - pad.extend([deficit, 0]) - if any(p > 0 for p in pad): - val = torch.nn.functional.pad(val, pad) - result[name] = val - return result +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Data utilities for input preparation and padding.""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch + + +def pad_inputs( + source: dict[str, Any], + expected: dict[str, list[int]], + mode: Literal["left", "right"] = "right", +) -> dict[str, Any]: + """Filter *source* to keys in *expected* and pad undersized tensors. + + For each name in *expected*, if *source* has a tensor for it, pad any + dimension smaller than the ONNX expected shape (skips batch dim). + Non-tensor values are passed through. Missing names are skipped. + + Args: + source: Input tensors keyed by name. + expected: ONNX expected shapes keyed by input name. + mode: Padding side — ``"right"`` (default, pad at end) or + ``"left"`` (pad at start). + + Returns: + Filtered and padded tensors matching *expected* keys. + """ + if mode not in ("right", "left"): + raise ValueError(f"mode must be 'right' or 'left', got {mode!r}") + + result: dict[str, Any] = {} + for name, expected_shape in expected.items(): + val = source.get(name) + if val is None: + continue + if isinstance(val, torch.Tensor): + # TODO: support dynamic shape ONNX models (None in expected_shape) + ndim = min(len(val.shape), len(expected_shape)) + # torch.nn.functional.pad takes pairs (low, high) from the LAST + # dim backwards. Skip batch dim (dim 0). + pad: list[int] = [] + for dim in reversed(range(1, ndim)): + exp = expected_shape[dim] + # Dynamic ONNX dims may be None or a string symbol; emit a + # (0, 0) pair so later pairs stay aligned with their dim index. + if not isinstance(exp, int): + # Forward-looking: expected is typed list[int] today, but ONNX + # dynamic dims (None / str symbol) are a planned input (see TODO + # above). Keep the guard until the signature widens for them. + pad.extend([0, 0]) # type: ignore[unreachable] + continue + deficit = max(exp - val.shape[dim], 0) + if mode == "right": + pad.extend([0, deficit]) + else: # left + pad.extend([deficit, 0]) + if any(p > 0 for p in pad): + val = torch.nn.functional.pad(val, pad) + result[name] = val + return result diff --git a/src/winml/modelkit/utils/hub_utils.py b/src/winml/modelkit/utils/hub_utils.py index cfb7cc279..3da385a3b 100644 --- a/src/winml/modelkit/utils/hub_utils.py +++ b/src/winml/modelkit/utils/hub_utils.py @@ -1,348 +1,348 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""HuggingFace Hub utilities for model detection and configuration loading. - -This module provides intelligent detection of HuggingFace Hub models vs local models, -and handles the appropriate metadata storage and configuration loading strategies. -""" - -import logging -import re -from pathlib import Path -from typing import Any - - -logger = logging.getLogger(__name__) - - -def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: - """Comprehensive Hub model detection with metadata extraction. - - Args: - model_name_or_path: Model identifier or path - - Returns: - Tuple of (is_hub_model, metadata_dict) - """ - # Quick rejection for obvious local paths - if Path(model_name_or_path).exists(): - return False, {"type": "local", "path": model_name_or_path} - - # Check for local path indicators - if any(model_name_or_path.startswith(prefix) for prefix in ["./", "../", "/", "~/"]): - return False, {"type": "local", "path": model_name_or_path} - - # Check for Windows absolute paths - if re.match(r"^[A-Za-z]:[\\/]", model_name_or_path): - return False, {"type": "local", "path": model_name_or_path} - - # Parse potential Hub model format - # Supports: model-name, org/model, org/model@revision - hub_pattern = r"^(?:([^/@]+)/)?([^/@]+)(?:@(.+))?$" - match = re.match(hub_pattern, model_name_or_path) - - if not match: - return False, {"type": "invalid"} - - org, model, revision = match.groups() - full_model_id = f"{org}/{model}" if org else model - - # Try to verify with Hub API - try: - from huggingface_hub import HfApi - - api = HfApi() - model_info = api.model_info(full_model_id, revision=revision) - - # Extract comprehensive metadata - metadata = { - "type": "hub", - "model_id": model_info.id, - "sha": model_info.sha, - "revision": revision or "main", - "tags": model_info.tags if hasattr(model_info, "tags") else [], - "pipeline_tag": model_info.pipeline_tag - if hasattr(model_info, "pipeline_tag") - else None, - "library_name": model_info.library_name - if hasattr(model_info, "library_name") - else None, - "author": model_info.author if hasattr(model_info, "author") else None, - "last_modified": str(model_info.lastModified) - if hasattr(model_info, "lastModified") - else None, - "private": model_info.private if hasattr(model_info, "private") else False, - "gated": model_info.gated if hasattr(model_info, "gated") else False, - } - - # Try to get model card info if available - try: - from huggingface_hub import ModelCard - - card = ModelCard.load(full_model_id) - if hasattr(card.data, "base_model"): - metadata["base_model"] = card.data.base_model - if hasattr(card.data, "license"): - metadata["license"] = card.data.license - if hasattr(card.data, "language"): - metadata["language"] = card.data.language - if hasattr(card.data, "task_categories"): - metadata["task_categories"] = card.data.task_categories - except Exception: - pass - - return True, metadata - - except Exception as e: - # Could not verify with Hub - might be private or offline - # Use heuristics to guess - if len(model_name_or_path.split("/")) <= 2 and "\\" not in model_name_or_path: - return True, { - "type": "hub_unverified", - "model_id": full_model_id, - "revision": revision or "main", - "error": str(e), - } - return False, {"type": "local", "path": model_name_or_path} - - -def inject_hub_metadata(onnx_model: Any, model_name_or_path: str, metadata: dict) -> None: - """Inject HuggingFace Hub metadata into ONNX model. - - Args: - onnx_model: ONNX model proto - model_name_or_path: Original model identifier - metadata: Hub metadata dictionary - """ - from datetime import datetime, timezone - - # Clear any existing HF metadata - # We need to remove items by filtering, not reassigning - hf_props_to_remove = [] - for i, prop in enumerate(onnx_model.metadata_props): - if prop.key.startswith("hf_"): - hf_props_to_remove.append(i) - - # Remove in reverse order to maintain indices - for i in reversed(hf_props_to_remove): - del onnx_model.metadata_props[i] - - # Add required metadata - def add_prop(key: str, value: Any) -> None: - if value is not None: - prop = onnx_model.metadata_props.add() - prop.key = key - prop.value = str(value) - - # Required fields - add_prop("hf_hub_id", metadata.get("model_id")) - add_prop("hf_hub_revision", metadata.get("sha", "")[:8]) - add_prop("hf_model_type", "hub") - - # Get ModelExport version - try: - from .. import __version__ - - export_version = __version__ - except ImportError: - export_version = "unknown" - - add_prop("hf_export_version", export_version) - add_prop("hf_export_timestamp", datetime.now(timezone.utc).isoformat()) - - # Optional fields - for key in ["pipeline_tag", "library_name", "base_model", "private", "gated"]: - if key in metadata: - add_prop(f"hf_{key}", metadata[key]) - - # Producer information - onnx_model.producer_name = "ModelExport-HTP" - onnx_model.producer_version = export_version - onnx_model.domain = "com.modelexport.htp" - - # Add doc string for human readability - onnx_model.doc_string = ( - f"Exported from HuggingFace model: {metadata.get('model_id')}\n" - f"Revision: {metadata.get('sha', 'unknown')[:8]}\n" - f"Export timestamp: {datetime.now(timezone.utc).isoformat()}\n" - f"ModelExport version: {export_version}" - ) - - -def save_local_model_configs(model_name_or_path: str, output_dir: Path, metadata: dict) -> None: - """Save configuration files for local/in-house models. - - Args: - model_name_or_path: Path to local model - output_dir: Directory to save configs - metadata: Local model metadata - """ - # Check if the path exists first - if not Path(model_name_or_path).exists(): - logger.info(f"Local model path {model_name_or_path} does not exist, skipping config copy") - return - - try: - from transformers import AutoConfig - - # Save config - config = AutoConfig.from_pretrained(model_name_or_path) - config.save_pretrained(output_dir) - logger.info(f"Saved config.json to {output_dir}") - - # Track what components were saved - components_saved = [] - - # Try AutoProcessor (for multimodal) - try: - from transformers import AutoProcessor - - processor = AutoProcessor.from_pretrained(model_name_or_path) - processor.save_pretrained(output_dir) - components_saved.append("processor") - except Exception: - pass - - # Try AutoTokenizer (for text models) - only if processor wasn't saved - if "processor" not in components_saved: - try: - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - tokenizer.save_pretrained(output_dir) - components_saved.append("tokenizer") - except Exception: - pass - - # Try AutoImageProcessor (for vision) - try: - from transformers import AutoImageProcessor - - image_processor = AutoImageProcessor.from_pretrained(model_name_or_path) - image_processor.save_pretrained(output_dir) - components_saved.append("image_processor") - except Exception: - pass - - # Try AutoFeatureExtractor (for audio) - try: - from transformers import AutoFeatureExtractor - - feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path) - feature_extractor.save_pretrained(output_dir) - components_saved.append("feature_extractor") - except Exception: - pass - - if components_saved: - logger.info(f"Saved preprocessing components: {', '.join(components_saved)}") - - except Exception as e: - logger.warning(f"Could not save config for local model: {e}") - logger.warning("User will need to provide config manually for inference") - - -def load_hf_components_from_onnx(onnx_path: str) -> tuple[Any, Any]: - """Load HuggingFace config and preprocessing components from ONNX. - - Handles both: - 1. Hub models - loads from HF Hub using metadata - 2. Local models - loads from co-located config files - - Args: - onnx_path: Path to ONNX model - - Returns: - Tuple of (config, preprocessor) - """ - from pathlib import Path - - from transformers import ( - AutoConfig, - AutoFeatureExtractor, - AutoImageProcessor, - AutoProcessor, - AutoTokenizer, - ) - - from ..onnx import load_onnx - - # Load ONNX model and extract metadata - onnx_model = load_onnx(onnx_path, validate=False) - onnx_dir = Path(onnx_path).parent - - # Extract metadata - metadata = {} - for prop in onnx_model.metadata_props: - metadata[prop.key] = prop.value - - model_type = metadata.get("hf_model_type", "unknown") - - if model_type == "hub": - # Hub model: Load from HuggingFace Hub - hf_hub_id = metadata.get("hf_hub_id") - hf_revision = metadata.get("hf_hub_revision") - - if not hf_hub_id: - raise ValueError("ONNX model marked as Hub model but missing hf_hub_id metadata") - - # Load config from Hub - config = AutoConfig.from_pretrained(hf_hub_id, revision=hf_revision) - - # Try to load preprocessor from Hub - preprocessor = None - # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. - hub_loaders: list[Any] = [ - AutoProcessor, - AutoTokenizer, - AutoImageProcessor, - AutoFeatureExtractor, - ] - for loader_cls in hub_loaders: - try: - preprocessor = loader_cls.from_pretrained(hf_hub_id, revision=hf_revision) - break - except Exception: - continue - - return config, preprocessor - - if model_type == "local": - # Local model: Load from co-located files - config_path = onnx_dir / "config.json" - - if not config_path.exists(): - raise ValueError( - f"Local model but config.json not found at {config_path}. " - "The model may have been moved without its config files." - ) - - # Load config from local file - config = AutoConfig.from_pretrained(onnx_dir) - - # Try to load preprocessor from local files - preprocessor = None - # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. - local_loaders: list[Any] = [ - AutoProcessor, - AutoTokenizer, - AutoImageProcessor, - AutoFeatureExtractor, - ] - for loader_cls in local_loaders: - try: - preprocessor = loader_cls.from_pretrained(onnx_dir) - break - except Exception: - continue - - return config, preprocessor - - # Unknown or legacy model - raise ValueError( - f"ONNX model has unknown type '{model_type}'. " - "Was it exported with an older version of ModelExport? " - "Please re-export the model." - ) +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""HuggingFace Hub utilities for model detection and configuration loading. + +This module provides intelligent detection of HuggingFace Hub models vs local models, +and handles the appropriate metadata storage and configuration loading strategies. +""" + +import logging +import re +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + + +def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: + """Comprehensive Hub model detection with metadata extraction. + + Args: + model_name_or_path: Model identifier or path + + Returns: + Tuple of (is_hub_model, metadata_dict) + """ + # Quick rejection for obvious local paths + if Path(model_name_or_path).exists(): + return False, {"type": "local", "path": model_name_or_path} + + # Check for local path indicators + if any(model_name_or_path.startswith(prefix) for prefix in ["./", "../", "/", "~/"]): + return False, {"type": "local", "path": model_name_or_path} + + # Check for Windows absolute paths + if re.match(r"^[A-Za-z]:[\\/]", model_name_or_path): + return False, {"type": "local", "path": model_name_or_path} + + # Parse potential Hub model format + # Supports: model-name, org/model, org/model@revision + hub_pattern = r"^(?:([^/@]+)/)?([^/@]+)(?:@(.+))?$" + match = re.match(hub_pattern, model_name_or_path) + + if not match: + return False, {"type": "invalid"} + + org, model, revision = match.groups() + full_model_id = f"{org}/{model}" if org else model + + # Try to verify with Hub API + try: + from huggingface_hub import HfApi + + api = HfApi() + model_info = api.model_info(full_model_id, revision=revision) + + # Extract comprehensive metadata + metadata = { + "type": "hub", + "model_id": model_info.id, + "sha": model_info.sha, + "revision": revision or "main", + "tags": model_info.tags if hasattr(model_info, "tags") else [], + "pipeline_tag": model_info.pipeline_tag + if hasattr(model_info, "pipeline_tag") + else None, + "library_name": model_info.library_name + if hasattr(model_info, "library_name") + else None, + "author": model_info.author if hasattr(model_info, "author") else None, + "last_modified": str(model_info.lastModified) + if hasattr(model_info, "lastModified") + else None, + "private": model_info.private if hasattr(model_info, "private") else False, + "gated": model_info.gated if hasattr(model_info, "gated") else False, + } + + # Try to get model card info if available + try: + from huggingface_hub import ModelCard + + card = ModelCard.load(full_model_id) + if hasattr(card.data, "base_model"): + metadata["base_model"] = card.data.base_model + if hasattr(card.data, "license"): + metadata["license"] = card.data.license + if hasattr(card.data, "language"): + metadata["language"] = card.data.language + if hasattr(card.data, "task_categories"): + metadata["task_categories"] = card.data.task_categories + except Exception: + pass + + return True, metadata + + except Exception as e: + # Could not verify with Hub - might be private or offline + # Use heuristics to guess + if len(model_name_or_path.split("/")) <= 2 and "\\" not in model_name_or_path: + return True, { + "type": "hub_unverified", + "model_id": full_model_id, + "revision": revision or "main", + "error": str(e), + } + return False, {"type": "local", "path": model_name_or_path} + + +def inject_hub_metadata(onnx_model: Any, model_name_or_path: str, metadata: dict) -> None: + """Inject HuggingFace Hub metadata into ONNX model. + + Args: + onnx_model: ONNX model proto + model_name_or_path: Original model identifier + metadata: Hub metadata dictionary + """ + from datetime import datetime, timezone + + # Clear any existing HF metadata + # We need to remove items by filtering, not reassigning + hf_props_to_remove = [] + for i, prop in enumerate(onnx_model.metadata_props): + if prop.key.startswith("hf_"): + hf_props_to_remove.append(i) + + # Remove in reverse order to maintain indices + for i in reversed(hf_props_to_remove): + del onnx_model.metadata_props[i] + + # Add required metadata + def add_prop(key: str, value: Any) -> None: + if value is not None: + prop = onnx_model.metadata_props.add() + prop.key = key + prop.value = str(value) + + # Required fields + add_prop("hf_hub_id", metadata.get("model_id")) + add_prop("hf_hub_revision", metadata.get("sha", "")[:8]) + add_prop("hf_model_type", "hub") + + # Get ModelExport version + try: + from .. import __version__ + + export_version = __version__ + except ImportError: + export_version = "unknown" + + add_prop("hf_export_version", export_version) + add_prop("hf_export_timestamp", datetime.now(timezone.utc).isoformat()) + + # Optional fields + for key in ["pipeline_tag", "library_name", "base_model", "private", "gated"]: + if key in metadata: + add_prop(f"hf_{key}", metadata[key]) + + # Producer information + onnx_model.producer_name = "ModelExport-HTP" + onnx_model.producer_version = export_version + onnx_model.domain = "com.modelexport.htp" + + # Add doc string for human readability + onnx_model.doc_string = ( + f"Exported from HuggingFace model: {metadata.get('model_id')}\n" + f"Revision: {metadata.get('sha', 'unknown')[:8]}\n" + f"Export timestamp: {datetime.now(timezone.utc).isoformat()}\n" + f"ModelExport version: {export_version}" + ) + + +def save_local_model_configs(model_name_or_path: str, output_dir: Path, metadata: dict) -> None: + """Save configuration files for local/in-house models. + + Args: + model_name_or_path: Path to local model + output_dir: Directory to save configs + metadata: Local model metadata + """ + # Check if the path exists first + if not Path(model_name_or_path).exists(): + logger.info(f"Local model path {model_name_or_path} does not exist, skipping config copy") + return + + try: + from transformers import AutoConfig + + # Save config + config = AutoConfig.from_pretrained(model_name_or_path) + config.save_pretrained(output_dir) + logger.info(f"Saved config.json to {output_dir}") + + # Track what components were saved + components_saved = [] + + # Try AutoProcessor (for multimodal) + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_name_or_path) + processor.save_pretrained(output_dir) + components_saved.append("processor") + except Exception: + pass + + # Try AutoTokenizer (for text models) - only if processor wasn't saved + if "processor" not in components_saved: + try: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(output_dir) + components_saved.append("tokenizer") + except Exception: + pass + + # Try AutoImageProcessor (for vision) + try: + from transformers import AutoImageProcessor + + image_processor = AutoImageProcessor.from_pretrained(model_name_or_path) + image_processor.save_pretrained(output_dir) + components_saved.append("image_processor") + except Exception: + pass + + # Try AutoFeatureExtractor (for audio) + try: + from transformers import AutoFeatureExtractor + + feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path) + feature_extractor.save_pretrained(output_dir) + components_saved.append("feature_extractor") + except Exception: + pass + + if components_saved: + logger.info(f"Saved preprocessing components: {', '.join(components_saved)}") + + except Exception as e: + logger.warning(f"Could not save config for local model: {e}") + logger.warning("User will need to provide config manually for inference") + + +def load_hf_components_from_onnx(onnx_path: str) -> tuple[Any, Any]: + """Load HuggingFace config and preprocessing components from ONNX. + + Handles both: + 1. Hub models - loads from HF Hub using metadata + 2. Local models - loads from co-located config files + + Args: + onnx_path: Path to ONNX model + + Returns: + Tuple of (config, preprocessor) + """ + from pathlib import Path + + from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoImageProcessor, + AutoProcessor, + AutoTokenizer, + ) + + from ..onnx import load_onnx + + # Load ONNX model and extract metadata + onnx_model = load_onnx(onnx_path, validate=False) + onnx_dir = Path(onnx_path).parent + + # Extract metadata + metadata = {} + for prop in onnx_model.metadata_props: + metadata[prop.key] = prop.value + + model_type = metadata.get("hf_model_type", "unknown") + + if model_type == "hub": + # Hub model: Load from HuggingFace Hub + hf_hub_id = metadata.get("hf_hub_id") + hf_revision = metadata.get("hf_hub_revision") + + if not hf_hub_id: + raise ValueError("ONNX model marked as Hub model but missing hf_hub_id metadata") + + # Load config from Hub + config = AutoConfig.from_pretrained(hf_hub_id, revision=hf_revision) + + # Try to load preprocessor from Hub + preprocessor = None + # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. + hub_loaders: list[Any] = [ + AutoProcessor, + AutoTokenizer, + AutoImageProcessor, + AutoFeatureExtractor, + ] + for loader_cls in hub_loaders: + try: + preprocessor = loader_cls.from_pretrained(hf_hub_id, revision=hf_revision) + break + except Exception: + continue + + return config, preprocessor + + if model_type == "local": + # Local model: Load from co-located files + config_path = onnx_dir / "config.json" + + if not config_path.exists(): + raise ValueError( + f"Local model but config.json not found at {config_path}. " + "The model may have been moved without its config files." + ) + + # Load config from local file + config = AutoConfig.from_pretrained(onnx_dir) + + # Try to load preprocessor from local files + preprocessor = None + # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. + local_loaders: list[Any] = [ + AutoProcessor, + AutoTokenizer, + AutoImageProcessor, + AutoFeatureExtractor, + ] + for loader_cls in local_loaders: + try: + preprocessor = loader_cls.from_pretrained(onnx_dir) + break + except Exception: + continue + + return config, preprocessor + + # Unknown or legacy model + raise ValueError( + f"ONNX model has unknown type '{model_type}'. " + "Was it exported with an older version of ModelExport? " + "Please re-export the model." + ) diff --git a/src/winml/modelkit/utils/native_stderr.py b/src/winml/modelkit/utils/native_stderr.py index 87d3aea02..8802a3c20 100644 --- a/src/winml/modelkit/utils/native_stderr.py +++ b/src/winml/modelkit/utils/native_stderr.py @@ -23,6 +23,7 @@ from contextlib import contextmanager from typing import TYPE_CHECKING + if TYPE_CHECKING: from collections.abc import Iterator diff --git a/src/winml/modelkit/utils/optimum_loader.py b/src/winml/modelkit/utils/optimum_loader.py index 90f74c923..cf74dd36c 100644 --- a/src/winml/modelkit/utils/optimum_loader.py +++ b/src/winml/modelkit/utils/optimum_loader.py @@ -1,160 +1,160 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Optimum integration utilities for loading ONNX models with HuggingFace configurations. - -This module provides seamless integration with HuggingFace Optimum for inference -using exported ONNX models with preserved hierarchy and Hub metadata. -""" - -import logging -import shutil -import tempfile -from pathlib import Path -from typing import Any, cast - -from .hub_utils import load_hf_components_from_onnx - - -logger = logging.getLogger(__name__) - - -class OptimumONNXModel: - """Wrapper for seamless Optimum integration with Hub metadata.""" - - @classmethod - def from_onnx( - cls, onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any - ) -> tuple[Any, Any]: - """Load Optimum model from ONNX with Hub metadata. - - Args: - onnx_path: Path to ONNX model - task: Task type or "auto" to detect - device: Device to run on - **kwargs: Additional arguments for ORTModel - - Returns: - Tuple of (model, preprocessor) - """ - # Load config and preprocessor - config, preprocessor = load_hf_components_from_onnx(onnx_path) - - # Auto-detect task if needed - if task == "auto": - task = cls._detect_task(config, onnx_path) - - # Get appropriate ORTModel class - ort_model_class = cls._get_ort_model_class(task) - - # Create temporary directory with required files - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Save config - config.save_pretrained(temp_path) - - # Save preprocessor if available - if preprocessor: - preprocessor.save_pretrained(temp_path) - - # Copy ONNX model - shutil.copy(onnx_path, temp_path / "model.onnx") - - # Load with Optimum - model = ort_model_class.from_pretrained( - temp_path, - provider="CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider", - **kwargs, - ) - - return model, preprocessor - - @staticmethod - def _detect_task(config: Any, onnx_path: str) -> str: - """Detect task from config and metadata.""" - from ..onnx import load_onnx - - # Try to get task from metadata - try: - onnx_model = load_onnx(onnx_path, validate=False) - for prop in onnx_model.metadata_props: - if prop.key == "hf_pipeline_tag": - return prop.value - except Exception: - pass - - # Check architectures - if hasattr(config, "architectures"): - arch = config.architectures[0] if config.architectures else "" - - task_mapping = { - "ForSequenceClassification": "text-classification", - "ForTokenClassification": "token-classification", - "ForQuestionAnswering": "question-answering", - "ForCausalLM": "text-generation", - "ForConditionalGeneration": "text2text-generation", - "ForImageClassification": "image-classification", - "ForObjectDetection": "object-detection", - "ForAudioClassification": "audio-classification", - } - - for pattern, task in task_mapping.items(): - if pattern in arch: - return task - - # Default - return "feature-extraction" - - @staticmethod - def _get_ort_model_class(task: str) -> type[Any]: - """Get appropriate ORTModel class for task.""" - from optimum.onnxruntime import ( - ORTModel, - ORTModelForAudioClassification, - ORTModelForCausalLM, - ORTModelForFeatureExtraction, - ORTModelForImageClassification, - ORTModelForQuestionAnswering, - ORTModelForSeq2SeqLM, - ORTModelForSequenceClassification, - ORTModelForTokenClassification, - ) - - task_to_model = { - "text-classification": ORTModelForSequenceClassification, - "token-classification": ORTModelForTokenClassification, - "question-answering": ORTModelForQuestionAnswering, - "text-generation": ORTModelForCausalLM, - "text2text-generation": ORTModelForSeq2SeqLM, - "translation": ORTModelForSeq2SeqLM, - "summarization": ORTModelForSeq2SeqLM, - "image-classification": ORTModelForImageClassification, - "audio-classification": ORTModelForAudioClassification, - "feature-extraction": ORTModelForFeatureExtraction, - } - - return cast("type[Any]", task_to_model.get(task, ORTModel)) - - -def load_optimum_model( - onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any -) -> tuple[Any, Any]: - """Convenience function to load an ONNX model for Optimum inference. - - Args: - onnx_path: Path to ONNX model exported with ModelExport - task: Task type (auto-detected if not specified) - device: Device to run on ('cpu' or 'cuda') - **kwargs: Additional arguments for ORTModel - - Returns: - Tuple of (model, preprocessor) - - Example: - >>> model, tokenizer = load_optimum_model("bert.onnx") - >>> inputs = tokenizer("Hello world!", return_tensors="pt") - >>> outputs = model(**inputs) - """ - return OptimumONNXModel.from_onnx(onnx_path, task, device, **kwargs) +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Optimum integration utilities for loading ONNX models with HuggingFace configurations. + +This module provides seamless integration with HuggingFace Optimum for inference +using exported ONNX models with preserved hierarchy and Hub metadata. +""" + +import logging +import shutil +import tempfile +from pathlib import Path +from typing import Any, cast + +from .hub_utils import load_hf_components_from_onnx + + +logger = logging.getLogger(__name__) + + +class OptimumONNXModel: + """Wrapper for seamless Optimum integration with Hub metadata.""" + + @classmethod + def from_onnx( + cls, onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any + ) -> tuple[Any, Any]: + """Load Optimum model from ONNX with Hub metadata. + + Args: + onnx_path: Path to ONNX model + task: Task type or "auto" to detect + device: Device to run on + **kwargs: Additional arguments for ORTModel + + Returns: + Tuple of (model, preprocessor) + """ + # Load config and preprocessor + config, preprocessor = load_hf_components_from_onnx(onnx_path) + + # Auto-detect task if needed + if task == "auto": + task = cls._detect_task(config, onnx_path) + + # Get appropriate ORTModel class + ort_model_class = cls._get_ort_model_class(task) + + # Create temporary directory with required files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Save config + config.save_pretrained(temp_path) + + # Save preprocessor if available + if preprocessor: + preprocessor.save_pretrained(temp_path) + + # Copy ONNX model + shutil.copy(onnx_path, temp_path / "model.onnx") + + # Load with Optimum + model = ort_model_class.from_pretrained( + temp_path, + provider="CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider", + **kwargs, + ) + + return model, preprocessor + + @staticmethod + def _detect_task(config: Any, onnx_path: str) -> str: + """Detect task from config and metadata.""" + from ..onnx import load_onnx + + # Try to get task from metadata + try: + onnx_model = load_onnx(onnx_path, validate=False) + for prop in onnx_model.metadata_props: + if prop.key == "hf_pipeline_tag": + return prop.value + except Exception: + pass + + # Check architectures + if hasattr(config, "architectures"): + arch = config.architectures[0] if config.architectures else "" + + task_mapping = { + "ForSequenceClassification": "text-classification", + "ForTokenClassification": "token-classification", + "ForQuestionAnswering": "question-answering", + "ForCausalLM": "text-generation", + "ForConditionalGeneration": "text2text-generation", + "ForImageClassification": "image-classification", + "ForObjectDetection": "object-detection", + "ForAudioClassification": "audio-classification", + } + + for pattern, task in task_mapping.items(): + if pattern in arch: + return task + + # Default + return "feature-extraction" + + @staticmethod + def _get_ort_model_class(task: str) -> type[Any]: + """Get appropriate ORTModel class for task.""" + from optimum.onnxruntime import ( + ORTModel, + ORTModelForAudioClassification, + ORTModelForCausalLM, + ORTModelForFeatureExtraction, + ORTModelForImageClassification, + ORTModelForQuestionAnswering, + ORTModelForSeq2SeqLM, + ORTModelForSequenceClassification, + ORTModelForTokenClassification, + ) + + task_to_model = { + "text-classification": ORTModelForSequenceClassification, + "token-classification": ORTModelForTokenClassification, + "question-answering": ORTModelForQuestionAnswering, + "text-generation": ORTModelForCausalLM, + "text2text-generation": ORTModelForSeq2SeqLM, + "translation": ORTModelForSeq2SeqLM, + "summarization": ORTModelForSeq2SeqLM, + "image-classification": ORTModelForImageClassification, + "audio-classification": ORTModelForAudioClassification, + "feature-extraction": ORTModelForFeatureExtraction, + } + + return cast("type[Any]", task_to_model.get(task, ORTModel)) + + +def load_optimum_model( + onnx_path: str, task: str = "auto", device: str = "cpu", **kwargs: Any +) -> tuple[Any, Any]: + """Convenience function to load an ONNX model for Optimum inference. + + Args: + onnx_path: Path to ONNX model exported with ModelExport + task: Task type (auto-detected if not specified) + device: Device to run on ('cpu' or 'cuda') + **kwargs: Additional arguments for ORTModel + + Returns: + Tuple of (model, preprocessor) + + Example: + >>> model, tokenizer = load_optimum_model("bert.onnx") + >>> inputs = tokenizer("Hello world!", return_tensors="pt") + >>> outputs = model(**inputs) + """ + return OptimumONNXModel.from_onnx(onnx_path, task, device, **kwargs)