diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 19fb415d9..dd3444bc9 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,7 +13,7 @@ concurrency: jobs: lint: runs-on: windows-latest - # Bumped from 5: combined mypy on 16 packages cold-starts at ~3-4 min on + # Bumped from 5: combined mypy on 23 packages cold-starts at ~3-4 min on # Windows runners; the original 5-min ceiling cancelled mid-run. timeout-minutes: 10 @@ -64,3 +64,10 @@ jobs: -p winml.modelkit.loader -p winml.modelkit.onnx -p winml.modelkit.optim + -p winml.modelkit.optracing + -p winml.modelkit.quant + -p winml.modelkit.serve + -p winml.modelkit.session + -p winml.modelkit.sysinfo + -p winml.modelkit.telemetry + -p winml.modelkit.utils diff --git a/pyproject.toml b/pyproject.toml index bf6afe894..46fa96dc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -506,9 +506,21 @@ module = [ "sklearn.*", # used in eval/metrics; no community stubs "evaluate", # HF evaluate, used in eval/; no community stubs "evaluate.*", + # QAIRT (Qualcomm AI Runtime) SDK — imported only inside compile_qairt_bin.py, + # which runs in a separate venv-winml subprocess where the SDK is installed. + # Not a dependency of the main/CI environment, so it has no stubs here. + "qairt", + "qairt.*", ] ignore_missing_imports = true +# windowsml ships no py.typed marker, but its source is installed and usable — +# analyze it directly (PEP 561 opt-in) so its inline annotations are honored +# instead of collapsing every symbol to Any. +[[tool.mypy.overrides]] +module = [ "windowsml", "windowsml.*" ] +follow_untyped_imports = true + # Relaxed modules: tests and WIP code [[tool.mypy.overrides]] diff --git a/src/winml/modelkit/optracing/qnn/profiler.py b/src/winml/modelkit/optracing/qnn/profiler.py index 12bb44403..0df9e22f2 100644 --- a/src/winml/modelkit/optracing/qnn/profiler.py +++ b/src/winml/modelkit/optracing/qnn/profiler.py @@ -19,7 +19,7 @@ import logging import os from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np @@ -29,6 +29,10 @@ from .viewer import find_qnn_sdk, run_qhas_viewer +if TYPE_CHECKING: + from collections.abc import Iterator + + logger = logging.getLogger(__name__) @@ -53,7 +57,7 @@ def _resolve_shape(shape: list, default_dim: int = 1) -> list[int]: @contextlib.contextmanager -def _working_directory(path: Path): +def _working_directory(path: Path) -> Iterator[None]: """Temporarily change CWD and restore on exit. QNN EP writes ``*_schematic.bin`` into the process CWD, so we diff --git a/src/winml/modelkit/optracing/qnn/qhas_parser.py b/src/winml/modelkit/optracing/qnn/qhas_parser.py index 9bc267ccc..77a19c0ea 100644 --- a/src/winml/modelkit/optracing/qnn/qhas_parser.py +++ b/src/winml/modelkit/optracing/qnn/qhas_parser.py @@ -110,4 +110,4 @@ def _vtcm_ratio(op: dict) -> float | None: total = vtcm_read + dram_read if total == 0: return None - return vtcm_read / total + return float(vtcm_read / total) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 2e6c2c279..bc8e6ee06 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -16,6 +16,8 @@ result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) """ +from typing import Any + from .config import QuantizeResult, WinMLQuantizationConfig @@ -31,7 +33,7 @@ } -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy-load quantizer (imports onnxruntime.quantization).""" if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] diff --git a/src/winml/modelkit/serve/app.py b/src/winml/modelkit/serve/app.py index d19ee96bd..b38f96f6b 100644 --- a/src/winml/modelkit/serve/app.py +++ b/src/winml/modelkit/serve/app.py @@ -23,6 +23,7 @@ import asyncio import base64 +import binascii import importlib.resources import json import logging @@ -30,7 +31,7 @@ from collections import deque from contextlib import asynccontextmanager from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware @@ -57,6 +58,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from ..utils.constants import EPNameOrAlias logger = logging.getLogger(__name__) @@ -166,7 +169,7 @@ def create_app( """ @asynccontextmanager - async def lifespan(app: FastAPI): + async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.state.start_time = time.time() # Raise the modelkit logger to INFO so the ring handler receives # operational records during `winml serve`. Tests that build the app @@ -187,6 +190,8 @@ async def lifespan(app: FastAPI): logger.info("Multi-model server started (empty — load via POST /v1/models)") app.state.manager = mgr else: + if model_path is None: + raise ValueError("single-model mode requires a model_path") engine = InferenceEngine() engine.load(model_path, task=task, device=device, ep=ep) app.state.manager = SingleModelManager(engine, idle_timeout_sec=idle_timeout_sec) @@ -240,7 +245,7 @@ def _get_mgr() -> SingleModelManager | ModelSlotManager: mgr = getattr(app.state, "manager", None) if mgr is None: raise HTTPException(status_code=503, detail="Model not loaded yet") - return mgr + return cast("SingleModelManager | ModelSlotManager", mgr) def _get_start_time() -> float: return getattr(app.state, "start_time", time.time()) @@ -733,13 +738,13 @@ async def cli_command(command: str, request: CliRequest) -> CliResponse: # --------------------------------------------------------------------------- -def _manifest_from_engine(engine: InferenceEngine) -> dict: +def _manifest_from_engine(engine: InferenceEngine) -> dict[str, Any]: """Build manifest dict from engine, trying build_manifest.json first.""" if engine.model_path: manifest_file = Path(engine.model_path) / "build_manifest.json" if manifest_file.exists(): try: - return json.loads(manifest_file.read_text()) + return cast("dict[str, Any]", json.loads(manifest_file.read_text())) except (json.JSONDecodeError, OSError) as e: logger.warning("Failed to load manifest: %s", e) @@ -753,7 +758,7 @@ def _manifest_from_engine(engine: InferenceEngine) -> dict: def _build_model_schema( - manifest: dict, + manifest: dict[str, Any], engine: InferenceEngine | None = None, task_override: str | None = None, ) -> dict[str, Any]: @@ -828,7 +833,7 @@ def _decode_rest_inputs( if field and field.type in BINARY_TYPES and isinstance(value, str): try: result[name] = base64.b64decode(value) - except (ValueError, base64.binascii.Error) as exc: + except (ValueError, binascii.Error) as exc: raise ValueError(f"Invalid base64 for input '{name}': {exc}") from exc return result diff --git a/src/winml/modelkit/serve/cli_api.py b/src/winml/modelkit/serve/cli_api.py index c47a5cf56..2bb6d8053 100644 --- a/src/winml/modelkit/serve/cli_api.py +++ b/src/winml/modelkit/serve/cli_api.py @@ -27,7 +27,7 @@ import tempfile import time from pathlib import Path -from typing import Any +from typing import Any, cast from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -212,7 +212,10 @@ def _extract_json_from_stdout(stdout: str) -> dict[str, Any] | list[Any] | None: if start == -1: break try: - return json.loads(stdout[start : end + 1]) + return cast( + "dict[str, Any] | list[Any] | None", + json.loads(stdout[start : end + 1]), + ) except json.JSONDecodeError: continue pos = end # try next (earlier) end_char diff --git a/src/winml/modelkit/session/ep_registry.py b/src/winml/modelkit/session/ep_registry.py index 244da2a7a..d78112f28 100644 --- a/src/winml/modelkit/session/ep_registry.py +++ b/src/winml/modelkit/session/ep_registry.py @@ -228,6 +228,10 @@ class WinMLEPRegistry: available = registry.get_available_eps() """ + # Set in __new__ before __init__ runs; declared here so mypy can resolve its + # type at the __init__ read site. + _initialized: bool + def __new__(cls) -> WinMLEPRegistry: """Singleton pattern.""" global _winml_ep_registry @@ -269,7 +273,8 @@ def _load_ep_catalog(self) -> None: """Load EP catalog from WinML.""" from windowsml import EpCatalog - checked_eps: set[EPName] = set() + # Dedup of raw windowsml provider-name strings (pre-validation). + checked_eps: set[str] = set() with EpCatalog() as catalog: for provider in catalog.find_all_providers(): if provider.name in checked_eps: @@ -402,4 +407,4 @@ def get_ort_available_providers(use_winml: bool = True) -> list[str]: except Exception as e: logger.debug("WinML discovery skipped: %s", e) - return ort.get_available_providers() + return cast("list[str]", ort.get_available_providers()) diff --git a/src/winml/modelkit/session/monitor/_pdh.py b/src/winml/modelkit/session/monitor/_pdh.py index 0e0364dc4..60100be71 100644 --- a/src/winml/modelkit/session/monitor/_pdh.py +++ b/src/winml/modelkit/session/monitor/_pdh.py @@ -178,22 +178,24 @@ def _collect_once(self) -> dict[str, float | int | None]: continue if entry.fmt == _PDH_FMT_DOUBLE: - val = _PdhFmtDouble() + dval = _PdhFmtDouble() s = _pdh.PdhGetFormattedCounterValue( entry.handle, _PDH_FMT_DOUBLE | _PDH_FMT_NOCAP100, ctypes.byref(ct), - ctypes.byref(val), + ctypes.byref(dval), ) values[entry.name] = ( - val.doubleValue if _pdh_ok(s) and _pdh_ok(val.CStatus) else None + dval.doubleValue if _pdh_ok(s) and _pdh_ok(dval.CStatus) else None ) else: - val = _PdhFmtLarge() + lval = _PdhFmtLarge() s = _pdh.PdhGetFormattedCounterValue( - entry.handle, _PDH_FMT_LARGE, ctypes.byref(ct), ctypes.byref(val) + entry.handle, _PDH_FMT_LARGE, ctypes.byref(ct), ctypes.byref(lval) + ) + values[entry.name] = ( + lval.largeValue if _pdh_ok(s) and _pdh_ok(lval.CStatus) else None ) - values[entry.name] = val.largeValue if _pdh_ok(s) and _pdh_ok(val.CStatus) else None return values @@ -403,16 +405,17 @@ def __init__( self._stop_event = threading.Event() self._lock = threading.Lock() self._util_samples: list[float] = [] - self._memory_local_bytes: list[int] = [] - self._memory_shared_bytes: list[int] = [] + # PDH counters arrive as float|int (double vs large format), so store as float. + self._memory_local_bytes: list[float] = [] + self._memory_shared_bytes: list[float] = [] self._cpu_samples: list[float] = [] - self._ram_used_bytes: list[int] = [] + self._ram_used_bytes: list[float] = [] # Per-engtype snapshots of the monotonic Running Time counter. Stored # as dicts because an adapter exposes multiple engines (e.g. several # Compute_* on an NPU; 3D + Compute_* on a GPU) and the total adapter # time is the sum of per-engine deltas. - self._running_time_start_ns: dict[str, int] = {} - self._running_time_end_ns: dict[str, int] = {} + self._running_time_start_ns: dict[str, float] = {} + self._running_time_end_ns: dict[str, float] = {} def start(self) -> None: """Resolve target device, register PDH counters, start background thread. @@ -532,8 +535,11 @@ def _poll_loop(self) -> None: # normalize so cpu_pct stays 0..100 across machines. cpu_divisor = float(os.cpu_count() or 1) while not self._stop_event.is_set(): + query = self._query + if query is None: + break try: - values = self._query._collect_once() + values = query._collect_once() # util_* counters are per-engine ratios over the same sample # window, so max reports the most-loaded engine on the adapter. # Don't sum — that would exceed 100% and duplicate what the @@ -739,7 +745,7 @@ def running_time_delta_ns(self) -> int: start = self._running_time_start_ns.get(key) if start is None: continue - total += max(0, end - start) + total += int(max(0, end - start)) return total @staticmethod diff --git a/src/winml/modelkit/session/monitor/_xrt_smi.py b/src/winml/modelkit/session/monitor/_xrt_smi.py index aa11e15e5..c9fd5306a 100644 --- a/src/winml/modelkit/session/monitor/_xrt_smi.py +++ b/src/winml/modelkit/session/monitor/_xrt_smi.py @@ -21,7 +21,7 @@ import tempfile from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, cast logger = logging.getLogger(__name__) @@ -109,7 +109,7 @@ def snapshot(self) -> dict[str, Any]: return {} with Path(tmp_path).open(encoding="utf-8") as f: - return json.load(f) + return cast("dict[str, Any]", json.load(f)) except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError) as exc: logger.debug("xrt-smi snapshot failed: %s", exc) diff --git a/src/winml/modelkit/session/qairt/compile_qairt_bin.py b/src/winml/modelkit/session/qairt/compile_qairt_bin.py index 9399361d2..e9225309c 100644 --- a/src/winml/modelkit/session/qairt/compile_qairt_bin.py +++ b/src/winml/modelkit/session/qairt/compile_qairt_bin.py @@ -52,7 +52,7 @@ def extract_input_specs(model_path: Path) -> list[dict]: # This script runs inside the QAIRT SDK's venv; use vanilla onnx.load. model = onnx.load(str(model_path)) - dtype_map = { + dtype_map: dict[int, type] = { onnx.TensorProto.FLOAT: np.float32, onnx.TensorProto.FLOAT16: np.float16, onnx.TensorProto.INT8: np.int8, diff --git a/src/winml/modelkit/session/qairt/qairt_session.py b/src/winml/modelkit/session/qairt/qairt_session.py index 47dfe2b9f..ce6bf60d9 100644 --- a/src/winml/modelkit/session/qairt/qairt_session.py +++ b/src/winml/modelkit/session/qairt/qairt_session.py @@ -11,7 +11,7 @@ import os import subprocess from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from ...utils.python_env import ensure_venv from ..session import SessionState, WinMLSession @@ -197,8 +197,8 @@ def _wrap_bin_to_onnx(self) -> None: qnn_version = qnn_json_obj["info"]["buildId"] for qnn_graph in qnn_json_obj["info"]["graphs"]: - qnn_input_tensor_dic = {} - qnn_output_tensor_dic = {} + qnn_input_tensor_dic: dict[str, Any] = {} + qnn_output_tensor_dic: dict[str, Any] = {} graph_name = gen_qnn_ctx_onnx_model.parse_qnn_graph( qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic ) diff --git a/src/winml/modelkit/session/session.py b/src/winml/modelkit/session/session.py index 9f689084b..2ec224b91 100644 --- a/src/winml/modelkit/session/session.py +++ b/src/winml/modelkit/session/session.py @@ -327,6 +327,16 @@ def run( if self._session is None: self.compile() + # compile() populates self._session or raises; bind a non-None local so + # the narrowing survives into the lambda / comprehension below (mypy drops + # self-attribute narrowing inside nested scopes). + session = self._session + if session is None: + raise InferenceError( + message="Session not available after compile", + context={}, + ) + if self._state == SessionState.ERROR: raise InferenceError( message="Session in error state", @@ -340,16 +350,16 @@ def run( self._validate_inputs(inputs) # Prepare inputs (convert to numpy, enforce dtype) - ort_inputs = self._prepare_inputs(inputs, self._session) + ort_inputs = self._prepare_inputs(inputs, session) # Run inference (with optional perf tracking) - output_names = [o.name for o in self._session.get_outputs()] + output_names = [o.name for o in session.get_outputs()] if self._perf_stats: outputs = self._perf_stats.record( - lambda: self._session.run(output_names, ort_inputs) + lambda: session.run(output_names, ort_inputs) ) else: - outputs = self._session.run(output_names, ort_inputs) + outputs = session.run(output_names, ort_inputs) # Build result dict return dict(zip(output_names, outputs, strict=True)) @@ -396,6 +406,8 @@ def _build_session_options(self, device: str) -> tuple[ort.SessionOptions, str, resolved_device, _ = resolve_device(device, ep=self._ep) resolved_ep = normalize_ep_name(self._ep) if self._ep else resolve_eps(resolved_device)[0] + if resolved_ep is None: + raise ValueError(f"Unknown execution provider: {self._ep!r}") device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper()) opts = self._session_options_factory() diff --git a/src/winml/modelkit/telemetry/click_group.py b/src/winml/modelkit/telemetry/click_group.py index cb391d368..20f14363c 100644 --- a/src/winml/modelkit/telemetry/click_group.py +++ b/src/winml/modelkit/telemetry/click_group.py @@ -39,7 +39,9 @@ class ActionGroup(click.Group): """Click group that auto-instruments every registered command.""" - def resolve_command(self, ctx, args): + def resolve_command( + self, ctx: click.Context, args: list[str] + ) -> tuple[str | None, click.Command | None, list[str]]: """Wrap the resolved subcommand with telemetry instrumentation.""" cmd_name, cmd, remaining = super().resolve_command(ctx, args) if cmd is None: @@ -106,7 +108,8 @@ def wrapped_invoke(ctx: click.Context) -> Any: success=success, ) - cmd.invoke = wrapped_invoke + # Intentional per-instance monkeypatch to wrap this command's invoke(). + cmd.invoke = wrapped_invoke # type: ignore[method-assign] setattr(cmd, _INSTRUMENTED_ATTR, True) return cmd diff --git a/src/winml/modelkit/telemetry/consent.py b/src/winml/modelkit/telemetry/consent.py index 61c6a9343..74cc66800 100644 --- a/src/winml/modelkit/telemetry/consent.py +++ b/src/winml/modelkit/telemetry/consent.py @@ -20,7 +20,7 @@ import sys import tempfile from pathlib import Path -from typing import Literal +from typing import Literal, cast from .utils import _resolve_user_home @@ -122,7 +122,9 @@ def _read_stored_consent() -> Consent | None: and stored_version < _CONSENT_VERSION ): return None - return value # type: ignore[return-value] + # value was validated against ("enabled", "disabled") above, but it comes + # from an untyped JSON dict so mypy still sees it as Any. + return cast("Consent", value) def _write_stored_consent(value: Consent) -> None: @@ -134,7 +136,9 @@ def _write_stored_consent(value: Consent) -> None: if _CONFIG_PATH is None: return data = _load_config() - tele = data.get("telemetry") if isinstance(data.get("telemetry"), dict) else {} + tele = data.get("telemetry") + if not isinstance(tele, dict): + tele = {} tele["consent"] = value tele["consent_version"] = _CONSENT_VERSION data["telemetry"] = tele diff --git a/src/winml/modelkit/telemetry/deviceid/_store.py b/src/winml/modelkit/telemetry/deviceid/_store.py index c75e4d2c5..152a5a719 100644 --- a/src/winml/modelkit/telemetry/deviceid/_store.py +++ b/src/winml/modelkit/telemetry/deviceid/_store.py @@ -12,6 +12,8 @@ from __future__ import annotations +from typing import cast + _REGISTRY_KEY = r"SOFTWARE\Microsoft\DeveloperTools\.modelkit" @@ -27,7 +29,7 @@ def read_key(name: str) -> str | None: return None if value_type != winreg.REG_SZ: return None - return value # already str for REG_SZ + return cast("str", value) # already str for REG_SZ def write_key(name: str, value: str) -> None: diff --git a/src/winml/modelkit/telemetry/library/exporter.py b/src/winml/modelkit/telemetry/library/exporter.py index 7a8600931..1e80757e9 100644 --- a/src/winml/modelkit/telemetry/library/exporter.py +++ b/src/winml/modelkit/telemetry/library/exporter.py @@ -23,6 +23,7 @@ from collections.abc import Sequence from opentelemetry.sdk._logs import ReadableLogRecord + from opentelemetry.sdk.resources import Resource _LOGGER = logging.getLogger(__name__) @@ -226,7 +227,8 @@ def shutdown(self) -> None: def _to_envelope(self, ld: ReadableLogRecord) -> dict: record = ld.log_record - timestamp = _ns_to_datetime(record.timestamp) + # OTel types timestamp as Optional; fall back to now if a record omits it. + timestamp = _ns_to_datetime(record.timestamp or time.time_ns()) data = dict(record.attributes or {}) ext = _resource_to_ext(ld.resource) return _build_envelope( @@ -266,7 +268,7 @@ def _ns_to_datetime(ts_ns: int) -> datetime: return datetime.fromtimestamp(ts_ns / 1_000_000_000, tz=timezone.utc) -def _resource_to_ext(resource) -> dict: +def _resource_to_ext(resource: Resource | None) -> dict: """Translate OpenTelemetry Resource attributes to CS 4.0 ext.* slots. Attribute name → CS slot mapping: diff --git a/src/winml/modelkit/telemetry/telemetry.py b/src/winml/modelkit/telemetry/telemetry.py index ddc74eb8b..ae7bd5cbd 100644 --- a/src/winml/modelkit/telemetry/telemetry.py +++ b/src/winml/modelkit/telemetry/telemetry.py @@ -28,6 +28,10 @@ if TYPE_CHECKING: + from opentelemetry._logs import Logger + from opentelemetry.sdk._logs import LoggerProvider + from opentelemetry.sdk.resources import Resource + from ..utils.constants import EPNameOrAlias @@ -88,8 +92,8 @@ class Telemetry: """ def __init__(self) -> None: - self._logger = None # set when enabled; None when disabled - self._provider = None + self._logger: Logger | None = None # set when enabled; None when disabled + self._provider: LoggerProvider | None = None self._disabled = True # set to False only after successful init self._app_instance_id = str(uuid.uuid4()) # Kept in the event schema for forward-compat: today @@ -147,7 +151,7 @@ def _try_init(self) -> None: self._disabled = True _clear_cache_quietly() - def _build_resource(self): + def _build_resource(self) -> Resource: from opentelemetry.sdk.resources import Resource device_id, id_status = get_or_create_device_id() @@ -219,6 +223,8 @@ def log_error(self, exc: BaseException) -> None: def _emit(self, event_name: str, attrs: dict[str, Any]) -> None: from opentelemetry._logs import LogRecord + if self._logger is None: + return filtered = _filter_allowlist(event_name, attrs) record = LogRecord( timestamp=time.time_ns(), diff --git a/src/winml/modelkit/telemetry/utils.py b/src/winml/modelkit/telemetry/utils.py index 5d79334ca..cbe9412b3 100644 --- a/src/winml/modelkit/telemetry/utils.py +++ b/src/winml/modelkit/telemetry/utils.py @@ -19,7 +19,11 @@ import re import traceback from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from types import TracebackType _PACKAGE_ROOT = "winml/modelkit" @@ -199,7 +203,12 @@ def __enter__(self) -> _ExclusiveFileLock: self._fd = fd return self - def __exit__(self, exc_type, exc, tb) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: import msvcrt if self._fd is None: diff --git a/src/winml/modelkit/utils/config_utils.py b/src/winml/modelkit/utils/config_utils.py index 911e20468..26e514fe0 100644 --- a/src/winml/modelkit/utils/config_utils.py +++ b/src/winml/modelkit/utils/config_utils.py @@ -11,7 +11,11 @@ import dataclasses import typing -from typing import Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast + + +if TYPE_CHECKING: + from _typeshed import DataclassInstance T = TypeVar("T") @@ -45,17 +49,18 @@ def merge_config(base: T, overrides: dict[str, Any] | T | None) -> T: if overrides is None: return base - # Convert overrides to dict if it's a config object + # Convert overrides to a plain dict if it's a config object. + overrides_dict: dict[str, Any] if hasattr(overrides, "to_dict"): - overrides = overrides.to_dict() + overrides_dict = overrides.to_dict() elif dataclasses.is_dataclass(overrides) and not isinstance(overrides, type): - overrides = dataclasses.asdict(overrides) + overrides_dict = dataclasses.asdict(overrides) elif isinstance(overrides, dict): - overrides = dict(overrides) # Copy to avoid mutation + overrides_dict = dict(overrides) # Copy to avoid mutation else: raise TypeError(f"overrides must be dict or config, got {type(overrides)}") - return _merge_into(base, overrides) + return _merge_into(base, overrides_dict) def _merge_into(base: T, overrides: dict[str, Any]) -> T: @@ -66,9 +71,9 @@ def _merge_into(base: T, overrides: dict[str, Any]) -> T: if isinstance(base, dict): # Handle dict-like config (e.g., WinMLOptimizationConfig) - result = type(base)(**base) # type: ignore[call-arg] + result = type(base)(**base) result.update(overrides) - return result # type: ignore[return-value] + return result # Primitive or unknown type - just return override return base @@ -78,7 +83,9 @@ def _merge_dataclass(base: T, overrides: dict[str, Any]) -> T: """Merge overrides into a dataclass, handling nested configs.""" # Get current field values current = {} - for f in dataclasses.fields(base): + # base is a dataclass instance by precondition (caller checks is_dataclass); + # cast for the fields() reflection call without widening base's T elsewhere. + for f in dataclasses.fields(cast("DataclassInstance", base)): current[f.name] = getattr(base, f.name) # Apply overrides recursively @@ -117,7 +124,7 @@ def _merge_dataclass(base: T, overrides: dict[str, Any]) -> T: current[key] = value # Create new instance - return type(base)(**current) # type: ignore[return-value] + return type(base)(**current) def _get_field_type(obj: Any, field_name: str) -> type | None: @@ -151,5 +158,5 @@ def _get_field_type(obj: Any, field_name: str) -> type | None: # It's an Optional/Union with None — extract the non-None type for arg in args: if arg is not type(None): - return arg # type: ignore[return-value] + return cast("type", arg) return resolved if isinstance(resolved, type) else None diff --git a/src/winml/modelkit/utils/console.py b/src/winml/modelkit/utils/console.py index a54194411..e83293d9b 100644 --- a/src/winml/modelkit/utils/console.py +++ b/src/winml/modelkit/utils/console.py @@ -117,12 +117,12 @@ def print_io_specs_detail( for i, t in enumerate(inputs): name = t.name or "(unnamed)" - shape_str = str(list(t.shape)) if getattr(t, "shape", None) else "dynamic" + shape_str = str(list(t.shape)) if t.shape else "dynamic" dtype_str = getattr(t, "dtype", None) or "?" label = "Input: " if i == 0 else " " console.print(f" {label}[cyan]{name:<18}[/cyan] {shape_str:<14} [dim]{dtype_str}[/dim]") - for i, t in enumerate(outputs): - name = t.name or "(unnamed)" + for i, out_t in enumerate(outputs): + name = out_t.name or "(unnamed)" # Fix #3: OutputTensorSpec only has name — show name only label = "Output: " if i == 0 else " " console.print(f" {label}[cyan]{name}[/cyan]") @@ -307,7 +307,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._refresh_disabled: bool = False - def refresh(self) -> None: # type: ignore[override] + def refresh(self) -> None: if self._refresh_disabled: return try: @@ -320,15 +320,15 @@ def refresh(self) -> None: # type: ignore[override] logger.debug("Disabling Live refresh after OSError", exc_info=True) @_swallow_oserror - def start(self, refresh: bool = False) -> None: # type: ignore[override] + def start(self, refresh: bool = False) -> None: super().start(refresh=refresh) @_swallow_oserror - def stop(self) -> None: # type: ignore[override] + def stop(self) -> None: super().stop() @_swallow_oserror - def update(self, renderable: RenderableType, *, refresh: bool = False) -> None: # type: ignore[override] + def update(self, renderable: RenderableType, *, refresh: bool = False) -> None: super().update(renderable, refresh=refresh) diff --git a/src/winml/modelkit/utils/constants.py b/src/winml/modelkit/utils/constants.py index 7b45d153a..4e85aaca6 100644 --- a/src/winml/modelkit/utils/constants.py +++ b/src/winml/modelkit/utils/constants.py @@ -7,7 +7,7 @@ from __future__ import annotations import sys -from typing import Literal, TypeAlias, get_args +from typing import Literal, TypeAlias, cast, get_args if sys.platform == "win32": @@ -137,7 +137,9 @@ def normalize_ep_name(ep: EPNameOrAlias | None) -> EPName | None: # the prior membership check narrowed ``ep_lower`` so the alias mapping is # total in this branch. ep_lower = ep.lower() - canonical = EP_ALIASES.get(ep_lower) # type: ignore[arg-type] + # ep_lower is an arbitrary lowercased string; cast to the key type for the + # lookup (.get tolerates non-alias keys, returning None). + canonical = EP_ALIASES.get(cast("EPAlias", ep_lower)) if canonical is not None: return canonical diff --git a/src/winml/modelkit/utils/data_utils.py b/src/winml/modelkit/utils/data_utils.py index bccf7acd1..364ab1b92 100644 --- a/src/winml/modelkit/utils/data_utils.py +++ b/src/winml/modelkit/utils/data_utils.py @@ -51,7 +51,10 @@ def pad_inputs( # Dynamic ONNX dims may be None or a string symbol; emit a # (0, 0) pair so later pairs stay aligned with their dim index. if not isinstance(exp, int): - pad.extend([0, 0]) + # Forward-looking: expected is typed list[int] today, but ONNX + # dynamic dims (None / str symbol) are a planned input (see TODO + # above). Keep the guard until the signature widens for them. + pad.extend([0, 0]) # type: ignore[unreachable] continue deficit = max(exp - val.shape[dim], 0) if mode == "right": diff --git a/src/winml/modelkit/utils/hub_utils.py b/src/winml/modelkit/utils/hub_utils.py index 1de403c8d..3da385a3b 100644 --- a/src/winml/modelkit/utils/hub_utils.py +++ b/src/winml/modelkit/utils/hub_utils.py @@ -59,7 +59,7 @@ def is_hub_model(model_name_or_path: str) -> tuple[bool, dict]: # Extract comprehensive metadata metadata = { "type": "hub", - "model_id": model_info.modelId, + "model_id": model_info.id, "sha": model_info.sha, "revision": revision or "main", "tags": model_info.tags if hasattr(model_info, "tags") else [], @@ -143,7 +143,7 @@ def add_prop(key: str, value: Any) -> None: # Get ModelExport version try: - from ..version import __version__ + from .. import __version__ export_version = __version__ except ImportError: @@ -293,12 +293,14 @@ def load_hf_components_from_onnx(onnx_path: str) -> tuple[Any, Any]: # Try to load preprocessor from Hub preprocessor = None - for loader_cls in [ + # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. + hub_loaders: list[Any] = [ AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoFeatureExtractor, - ]: + ] + for loader_cls in hub_loaders: try: preprocessor = loader_cls.from_pretrained(hf_hub_id, revision=hf_revision) break @@ -322,12 +324,14 @@ def load_hf_components_from_onnx(onnx_path: str) -> tuple[Any, Any]: # Try to load preprocessor from local files preprocessor = None - for loader_cls in [ + # Heterogeneous transformers Auto-loaders sharing a from_pretrained classmethod. + local_loaders: list[Any] = [ AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoFeatureExtractor, - ]: + ] + for loader_cls in local_loaders: try: preprocessor = loader_cls.from_pretrained(onnx_dir) break diff --git a/src/winml/modelkit/utils/native_stderr.py b/src/winml/modelkit/utils/native_stderr.py index ed6b8a0bb..8802a3c20 100644 --- a/src/winml/modelkit/utils/native_stderr.py +++ b/src/winml/modelkit/utils/native_stderr.py @@ -21,6 +21,11 @@ import re import sys from contextlib import contextmanager +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Iterator logger = logging.getLogger(__name__) @@ -47,7 +52,7 @@ @contextmanager -def suppress_native_stderr(): +def suppress_native_stderr() -> Iterator[None]: """Redirect native stderr to devnull. No-op on non-Windows.""" if sys.platform != "win32": yield @@ -68,7 +73,7 @@ def suppress_native_stderr(): @contextmanager -def capture_native_stderr(level: int = logging.INFO): +def capture_native_stderr(level: int = logging.INFO) -> Iterator[None]: """Capture native stderr via pipe and re-emit through Python logging. No-op on non-Windows. diff --git a/src/winml/modelkit/utils/optimum_loader.py b/src/winml/modelkit/utils/optimum_loader.py index c7a3a5b1d..cf74dd36c 100644 --- a/src/winml/modelkit/utils/optimum_loader.py +++ b/src/winml/modelkit/utils/optimum_loader.py @@ -12,7 +12,7 @@ import shutil import tempfile from pathlib import Path -from typing import Any +from typing import Any, cast from .hub_utils import load_hf_components_from_onnx @@ -135,7 +135,7 @@ def _get_ort_model_class(task: str) -> type[Any]: "feature-extraction": ORTModelForFeatureExtraction, } - return task_to_model.get(task, ORTModel) + return cast("type[Any]", task_to_model.get(task, ORTModel)) def load_optimum_model(