Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ concurrency:
jobs:
lint:
runs-on: windows-latest
# Bumped from 5: combined mypy on 16 packages cold-starts at ~3-4 min on
# Bumped from 5: combined mypy on 23 packages cold-starts at ~3-4 min on
# Windows runners; the original 5-min ceiling cancelled mid-run.
timeout-minutes: 10

Expand Down Expand Up @@ -64,3 +64,10 @@ jobs:
-p winml.modelkit.loader
-p winml.modelkit.onnx
-p winml.modelkit.optim
-p winml.modelkit.optracing
-p winml.modelkit.quant
-p winml.modelkit.serve
-p winml.modelkit.session
-p winml.modelkit.sysinfo
-p winml.modelkit.telemetry
-p winml.modelkit.utils
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,21 @@ module = [
"sklearn.*", # used in eval/metrics; no community stubs
"evaluate", # HF evaluate, used in eval/; no community stubs
"evaluate.*",
# QAIRT (Qualcomm AI Runtime) SDK — imported only inside compile_qairt_bin.py,
# which runs in a separate venv-winml subprocess where the SDK is installed.
# Not a dependency of the main/CI environment, so it has no stubs here.
"qairt",
"qairt.*",
]
ignore_missing_imports = true

# windowsml ships no py.typed marker, but its source is installed and usable —
# analyze it directly (PEP 561 opt-in) so its inline annotations are honored
# instead of collapsing every symbol to Any.
[[tool.mypy.overrides]]
module = [ "windowsml", "windowsml.*" ]
follow_untyped_imports = true

# Relaxed modules: tests and WIP code

[[tool.mypy.overrides]]
Expand Down
8 changes: 6 additions & 2 deletions src/winml/modelkit/optracing/qnn/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import os
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -29,6 +29,10 @@
from .viewer import find_qnn_sdk, run_qhas_viewer


if TYPE_CHECKING:
from collections.abc import Iterator


logger = logging.getLogger(__name__)


Expand All @@ -53,7 +57,7 @@ def _resolve_shape(shape: list, default_dim: int = 1) -> list[int]:


@contextlib.contextmanager
def _working_directory(path: Path):
def _working_directory(path: Path) -> Iterator[None]:
"""Temporarily change CWD and restore on exit.

QNN EP writes ``*_schematic.bin`` into the process CWD, so we
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/optracing/qnn/qhas_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ def _vtcm_ratio(op: dict) -> float | None:
total = vtcm_read + dram_read
if total == 0:
return None
return vtcm_read / total
return float(vtcm_read / total)
4 changes: 3 additions & 1 deletion src/winml/modelkit/quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100))
"""

from typing import Any

from .config import QuantizeResult, WinMLQuantizationConfig


Expand All @@ -31,7 +33,7 @@
}


def __getattr__(name: str):
def __getattr__(name: str) -> Any:
"""Lazy-load quantizer (imports onnxruntime.quantization)."""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
Expand Down
19 changes: 12 additions & 7 deletions src/winml/modelkit/serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@

import asyncio
import base64
import binascii
import importlib.resources
import json
import logging
import time
from collections import deque
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -57,6 +58,8 @@


if TYPE_CHECKING:
from collections.abc import AsyncIterator

from ..utils.constants import EPNameOrAlias

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -166,7 +169,7 @@ def create_app(
"""

@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
app.state.start_time = time.time()
# Raise the modelkit logger to INFO so the ring handler receives
# operational records during `winml serve`. Tests that build the app
Expand All @@ -187,6 +190,8 @@ async def lifespan(app: FastAPI):
logger.info("Multi-model server started (empty — load via POST /v1/models)")
app.state.manager = mgr
else:
if model_path is None:
raise ValueError("single-model mode requires a model_path")
engine = InferenceEngine()
engine.load(model_path, task=task, device=device, ep=ep)
app.state.manager = SingleModelManager(engine, idle_timeout_sec=idle_timeout_sec)
Expand Down Expand Up @@ -240,7 +245,7 @@ def _get_mgr() -> SingleModelManager | ModelSlotManager:
mgr = getattr(app.state, "manager", None)
if mgr is None:
raise HTTPException(status_code=503, detail="Model not loaded yet")
return mgr
return cast("SingleModelManager | ModelSlotManager", mgr)

def _get_start_time() -> float:
return getattr(app.state, "start_time", time.time())
Expand Down Expand Up @@ -733,13 +738,13 @@ async def cli_command(command: str, request: CliRequest) -> CliResponse:
# ---------------------------------------------------------------------------


def _manifest_from_engine(engine: InferenceEngine) -> dict:
def _manifest_from_engine(engine: InferenceEngine) -> dict[str, Any]:
"""Build manifest dict from engine, trying build_manifest.json first."""
if engine.model_path:
manifest_file = Path(engine.model_path) / "build_manifest.json"
if manifest_file.exists():
try:
return json.loads(manifest_file.read_text())
return cast("dict[str, Any]", json.loads(manifest_file.read_text()))
except (json.JSONDecodeError, OSError) as e:
logger.warning("Failed to load manifest: %s", e)

Expand All @@ -753,7 +758,7 @@ def _manifest_from_engine(engine: InferenceEngine) -> dict:


def _build_model_schema(
manifest: dict,
manifest: dict[str, Any],
engine: InferenceEngine | None = None,
task_override: str | None = None,
) -> dict[str, Any]:
Expand Down Expand Up @@ -828,7 +833,7 @@ def _decode_rest_inputs(
if field and field.type in BINARY_TYPES and isinstance(value, str):
try:
result[name] = base64.b64decode(value)
except (ValueError, base64.binascii.Error) as exc:
except (ValueError, binascii.Error) as exc:
raise ValueError(f"Invalid base64 for input '{name}': {exc}") from exc
return result

Expand Down
7 changes: 5 additions & 2 deletions src/winml/modelkit/serve/cli_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import tempfile
import time
from pathlib import Path
from typing import Any
from typing import Any, cast

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
Expand Down Expand Up @@ -212,7 +212,10 @@ def _extract_json_from_stdout(stdout: str) -> dict[str, Any] | list[Any] | None:
if start == -1:
break
try:
return json.loads(stdout[start : end + 1])
return cast(
"dict[str, Any] | list[Any] | None",
json.loads(stdout[start : end + 1]),
)
except json.JSONDecodeError:
continue
pos = end # try next (earlier) end_char
Expand Down
9 changes: 7 additions & 2 deletions src/winml/modelkit/session/ep_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ class WinMLEPRegistry:
available = registry.get_available_eps()
"""

# Set in __new__ before __init__ runs; declared here so mypy can resolve its
# type at the __init__ read site.
_initialized: bool

def __new__(cls) -> WinMLEPRegistry:
"""Singleton pattern."""
global _winml_ep_registry
Expand Down Expand Up @@ -269,7 +273,8 @@ def _load_ep_catalog(self) -> None:
"""Load EP catalog from WinML."""
from windowsml import EpCatalog

checked_eps: set[EPName] = set()
# Dedup of raw windowsml provider-name strings (pre-validation).
checked_eps: set[str] = set()
with EpCatalog() as catalog:
for provider in catalog.find_all_providers():
if provider.name in checked_eps:
Expand Down Expand Up @@ -402,4 +407,4 @@ def get_ort_available_providers(use_winml: bool = True) -> list[str]:
except Exception as e:
logger.debug("WinML discovery skipped: %s", e)

return ort.get_available_providers()
return cast("list[str]", ort.get_available_providers())
32 changes: 19 additions & 13 deletions src/winml/modelkit/session/monitor/_pdh.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,22 +178,24 @@ def _collect_once(self) -> dict[str, float | int | None]:
continue

if entry.fmt == _PDH_FMT_DOUBLE:
val = _PdhFmtDouble()
dval = _PdhFmtDouble()
s = _pdh.PdhGetFormattedCounterValue(
entry.handle,
_PDH_FMT_DOUBLE | _PDH_FMT_NOCAP100,
ctypes.byref(ct),
ctypes.byref(val),
ctypes.byref(dval),
)
values[entry.name] = (
val.doubleValue if _pdh_ok(s) and _pdh_ok(val.CStatus) else None
dval.doubleValue if _pdh_ok(s) and _pdh_ok(dval.CStatus) else None
)
else:
val = _PdhFmtLarge()
lval = _PdhFmtLarge()
s = _pdh.PdhGetFormattedCounterValue(
entry.handle, _PDH_FMT_LARGE, ctypes.byref(ct), ctypes.byref(val)
entry.handle, _PDH_FMT_LARGE, ctypes.byref(ct), ctypes.byref(lval)
)
values[entry.name] = (
lval.largeValue if _pdh_ok(s) and _pdh_ok(lval.CStatus) else None
)
values[entry.name] = val.largeValue if _pdh_ok(s) and _pdh_ok(val.CStatus) else None

return values

Expand Down Expand Up @@ -403,16 +405,17 @@ def __init__(
self._stop_event = threading.Event()
self._lock = threading.Lock()
self._util_samples: list[float] = []
self._memory_local_bytes: list[int] = []
self._memory_shared_bytes: list[int] = []
# PDH counters arrive as float|int (double vs large format), so store as float.
self._memory_local_bytes: list[float] = []
self._memory_shared_bytes: list[float] = []
self._cpu_samples: list[float] = []
self._ram_used_bytes: list[int] = []
self._ram_used_bytes: list[float] = []
# Per-engtype snapshots of the monotonic Running Time counter. Stored
# as dicts because an adapter exposes multiple engines (e.g. several
# Compute_* on an NPU; 3D + Compute_* on a GPU) and the total adapter
# time is the sum of per-engine deltas.
self._running_time_start_ns: dict[str, int] = {}
self._running_time_end_ns: dict[str, int] = {}
self._running_time_start_ns: dict[str, float] = {}
self._running_time_end_ns: dict[str, float] = {}

def start(self) -> None:
"""Resolve target device, register PDH counters, start background thread.
Expand Down Expand Up @@ -532,8 +535,11 @@ def _poll_loop(self) -> None:
# normalize so cpu_pct stays 0..100 across machines.
cpu_divisor = float(os.cpu_count() or 1)
while not self._stop_event.is_set():
query = self._query
if query is None:
break
try:
values = self._query._collect_once()
values = query._collect_once()
# util_* counters are per-engine ratios over the same sample
# window, so max reports the most-loaded engine on the adapter.
# Don't sum — that would exceed 100% and duplicate what the
Expand Down Expand Up @@ -739,7 +745,7 @@ def running_time_delta_ns(self) -> int:
start = self._running_time_start_ns.get(key)
if start is None:
continue
total += max(0, end - start)
total += int(max(0, end - start))
return total

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/session/monitor/_xrt_smi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, cast


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -109,7 +109,7 @@ def snapshot(self) -> dict[str, Any]:
return {}

with Path(tmp_path).open(encoding="utf-8") as f:
return json.load(f)
return cast("dict[str, Any]", json.load(f))

except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError) as exc:
logger.debug("xrt-smi snapshot failed: %s", exc)
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/session/qairt/compile_qairt_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def extract_input_specs(model_path: Path) -> list[dict]:
# This script runs inside the QAIRT SDK's venv; use vanilla onnx.load.
model = onnx.load(str(model_path))

dtype_map = {
dtype_map: dict[int, type] = {
onnx.TensorProto.FLOAT: np.float32,
onnx.TensorProto.FLOAT16: np.float16,
onnx.TensorProto.INT8: np.int8,
Expand Down
6 changes: 3 additions & 3 deletions src/winml/modelkit/session/qairt/qairt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import os
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from ...utils.python_env import ensure_venv
from ..session import SessionState, WinMLSession
Expand Down Expand Up @@ -197,8 +197,8 @@ def _wrap_bin_to_onnx(self) -> None:

qnn_version = qnn_json_obj["info"]["buildId"]
for qnn_graph in qnn_json_obj["info"]["graphs"]:
qnn_input_tensor_dic = {}
qnn_output_tensor_dic = {}
qnn_input_tensor_dic: dict[str, Any] = {}
qnn_output_tensor_dic: dict[str, Any] = {}
graph_name = gen_qnn_ctx_onnx_model.parse_qnn_graph(
qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic
)
Expand Down
20 changes: 16 additions & 4 deletions src/winml/modelkit/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ def run(
if self._session is None:
self.compile()

# compile() populates self._session or raises; bind a non-None local so
# the narrowing survives into the lambda / comprehension below (mypy drops
# self-attribute narrowing inside nested scopes).
session = self._session
if session is None:
raise InferenceError(
message="Session not available after compile",
context={},
)

if self._state == SessionState.ERROR:
raise InferenceError(
message="Session in error state",
Expand All @@ -340,16 +350,16 @@ def run(
self._validate_inputs(inputs)

# Prepare inputs (convert to numpy, enforce dtype)
ort_inputs = self._prepare_inputs(inputs, self._session)
ort_inputs = self._prepare_inputs(inputs, session)

# Run inference (with optional perf tracking)
output_names = [o.name for o in self._session.get_outputs()]
output_names = [o.name for o in session.get_outputs()]
if self._perf_stats:
outputs = self._perf_stats.record(
lambda: self._session.run(output_names, ort_inputs)
lambda: session.run(output_names, ort_inputs)
)
else:
outputs = self._session.run(output_names, ort_inputs)
outputs = session.run(output_names, ort_inputs)

# Build result dict
return dict(zip(output_names, outputs, strict=True))
Expand Down Expand Up @@ -396,6 +406,8 @@ def _build_session_options(self, device: str) -> tuple[ort.SessionOptions, str,

resolved_device, _ = resolve_device(device, ep=self._ep)
resolved_ep = normalize_ep_name(self._ep) if self._ep else resolve_eps(resolved_device)[0]
if resolved_ep is None:
raise ValueError(f"Unknown execution provider: {self._ep!r}")
device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper())

opts = self._session_options_factory()
Expand Down
Loading
Loading