Commit 3b5e604

Inference API implemented
1 parent a05000f commit 3b5e604

3 files changed

Lines changed: 464 additions & 1 deletion

Lines changed: 395 additions & 0 deletions
@@ -0,0 +1,395 @@
"""API handler for model inference endpoints (MLflow DataMint server)."""
from typing import Any, Literal
from collections.abc import Callable, Generator
import json
import logging
import time

import httpx

from ..entity_base_api import EntityBaseApi, ApiConfig
from datamint.entities.inferencejob import InferenceJob

logger = logging.getLogger(__name__)

_TERMINAL_STATUSES = frozenset({'completed', 'failed', 'cancelled', 'error'})


class InferenceApi(EntityBaseApi[InferenceJob]):
    """API handler for model inference endpoints.

    Provides methods to submit inference jobs, poll their status,
    cancel running jobs, and use specialised prediction endpoints
    (image, frame, slice, volume).
    """

    def __init__(self,
                 config: ApiConfig,
                 client: httpx.Client | None = None) -> None:
        super().__init__(config, InferenceJob, 'datamint/api/v1/model-inference', client)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _parse_job_response(self, data: dict) -> InferenceJob:
        """Normalise a job-status response into an ``InferenceJob`` entity."""
        if 'job_id' in data:
            data['id'] = data.pop('job_id')
        return self._init_entity_obj(**data)

    @staticmethod
    def _build_common_payload(
        model_name: str,
        model_version: int | None = None,
        model_alias: str | None = None,
        resource_id: str | None = None,
        file_path: str | None = None,
        save_results: bool = False,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Build the payload keys shared by every inference request."""
        payload: dict[str, Any] = {"model_name": model_name}
        if model_version is not None:
            payload["model_version"] = model_version
        if model_alias is not None:
            payload["model_alias"] = model_alias
        if resource_id is not None:
            payload["resource_id"] = resource_id
        if file_path is not None:
            payload["file_path"] = file_path
        if save_results:
            payload["save_results"] = save_results
        if params:
            payload["params"] = params
        return payload

    # ------------------------------------------------------------------
    # Generic inference
    # ------------------------------------------------------------------

    def submit(
        self,
        model_name: str,
        *,
        model_version: int | None = None,
        model_alias: str | None = None,
        resource_id: str | None = None,
        resource_ids: list[str] | None = None,
        file_path: str | None = None,
        file_paths: list[str] | None = None,
        save_results: bool = False,
        params: dict[str, Any] | None = None,
    ) -> InferenceJob:
        """Submit an inference job for background processing.

        Args:
            model_name: Name of the registered model.
            model_version: Specific model version number.
            model_alias: Model alias (e.g. ``'champion'``).
            resource_id: Single resource ID from DataMint API.
            resource_ids: List of resource IDs.
            file_path: Local file path.
            file_paths: List of local file paths.
            save_results: Whether to save results to the API.
            params: Additional parameters forwarded to the model.

        Returns:
            The created ``InferenceJob`` (with initial status).
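
        Example:
            Illustrative sketch only (``api`` is an ``InferenceApi`` instance;
            the model name and resource ID below are placeholders)::

                job = api.submit('chest-ct-classifier',
                                 resource_id='res-123',
                                 save_results=True)
                job = api.wait(job, timeout=600)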
"""
100+
payload = self._build_common_payload(
101+
model_name,
102+
model_version=model_version,
103+
model_alias=model_alias,
104+
resource_id=resource_id,
105+
file_path=file_path,
106+
save_results=save_results,
107+
params=params,
108+
)
109+
if resource_ids is not None:
110+
payload["resource_ids"] = resource_ids
111+
if file_paths is not None:
112+
payload["file_paths"] = file_paths
113+
114+
response = self._make_request('POST', f'/{self.endpoint_base}', json=payload)
115+
data = response.json()
116+
return self.get_status(data['job_id'])
117+
118+

    # ------------------------------------------------------------------
    # Status / cancel
    # ------------------------------------------------------------------

    def get_status(self, job_id: str) -> InferenceJob:
        """Get the current status of an inference job.

        Args:
            job_id: The job identifier.

        Returns:
            An ``InferenceJob`` populated with the latest status.
        """
        response = self._make_request('GET', f'/{self.endpoint_base}/status/{job_id}')
        return self._parse_job_response(response.json())

    def get_by_id(self, entity_id: str) -> InferenceJob:
        """Alias for ``get_status`` to satisfy the ``EntityBaseApi`` interface."""
        return self.get_status(entity_id)

    def stream_status(self, job_id: str) -> Generator[dict[str, Any], None, None]:
        """Stream status updates for an inference job via Server-Sent Events.

        Yields dictionaries parsed from SSE ``data:`` lines until the
        stream is closed by the server.

        Args:
            job_id: The job identifier.

        Yields:
            Parsed JSON dictionaries for each SSE event.
        """
        with self._stream_request('GET', f'/{self.endpoint_base}/status/{job_id}/stream') as resp:
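            # Each status update arrives as a ``data: <json>`` SSE line;
            # anything else (blank keep-alive lines, other SSE fields) is skipped.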
            for line in resp.iter_lines():
                if line.startswith('data:'):
                    payload = line[len('data:'):].strip()
                    if payload:
                        yield json.loads(payload)

    def wait(
        self,
        job: str | InferenceJob,
        *,
        on_status: Callable[[InferenceJob], None] | None = None,
        poll_interval: float = 2.0,
        timeout: float | None = None,
    ) -> InferenceJob:
        """Block until an inference job reaches a terminal state.

        First attempts to follow the SSE stream. If the stream is
        unavailable or drops early, the method falls back to polling
        ``get_status`` every *poll_interval* seconds.

        Args:
            job: Job ID string or ``InferenceJob`` entity.
            on_status: Optional callback invoked with an updated
                ``InferenceJob`` each time a status update is received.
            poll_interval: Seconds between polls when falling back to
                polling mode. Default ``2.0``.
            timeout: Maximum seconds to wait. ``None`` means wait
                indefinitely. Raises ``TimeoutError`` on expiry.

        Returns:
            The ``InferenceJob`` in its terminal state.

        Raises:
            TimeoutError: If *timeout* is set and the job has not
                finished within that duration.
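
        Example:
            Illustrative sketch only (``job`` as returned by ``submit``)::

                done = api.wait(
                    job,
                    on_status=lambda j: print(j.status),
                    poll_interval=5.0,
                    timeout=3600,
                )
                print(done.status)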
"""
187+
job_id = self._entid(job) if not isinstance(job, str) else job
188+
deadline = (time.monotonic() + timeout) if timeout is not None else None
189+
190+
def _check_timeout() -> None:
191+
if deadline is not None and time.monotonic() >= deadline:
192+
raise TimeoutError(
193+
f"Inference job {job_id} did not finish within {timeout}s"
194+
)
195+
196+
# --- Try SSE stream first ---
197+
try:
198+
for event in self.stream_status(job_id):
199+
_check_timeout()
200+
status_str = event.get('status', '')
201+
current_job = self._parse_job_response(event)
202+
if on_status is not None:
203+
on_status(current_job)
204+
if status_str.lower() in _TERMINAL_STATUSES:
205+
return current_job
206+
except Exception as e:
207+
logger.warning(f"SSE stream ended or failed ({e}); falling back to polling")
208+
209+
# --- Polling fallback ---
210+
while True:
211+
_check_timeout()
212+
current_job = self.get_status(job_id)
213+
if on_status is not None:
214+
on_status(current_job)
215+
if current_job.status.lower() in _TERMINAL_STATUSES:
216+
return current_job
217+
time.sleep(poll_interval)
218+
219+
def cancel(self, job: str | InferenceJob) -> bool:
220+
"""Cancel a running inference job.
221+
222+
Args:
223+
job: Job ID string or ``InferenceJob`` entity.
224+
225+
Returns:
226+
``True`` if the cancellation was acknowledged.
227+
"""
228+
job_id = self._entid(job)
229+
response = self._make_request('POST', f'/{self.endpoint_base}/cancel/{job_id}')
230+
return response.json().get('success', False)
231+
232+
# ------------------------------------------------------------------
233+
# Specialised prediction endpoints
234+
# ------------------------------------------------------------------
235+
236+
def predict_image(
237+
self,
238+
model_name: str,
239+
*,
240+
model_version: int | None = None,
241+
model_alias: str | None = None,
242+
resource_id: str | None = None,
243+
file_path: str | None = None,
244+
save_results: bool = False,
245+
params: dict[str, Any] | None = None,
246+
) -> InferenceJob:
247+
"""Submit an image prediction job.
248+
249+
Args:
250+
model_name: Name of the registered model.
251+
model_version: Specific model version number.
252+
model_alias: Model alias (e.g. ``'champion'``).
253+
resource_id: Resource ID from DataMint API.
254+
file_path: Local file path.
255+
save_results: Whether to save results.
256+
params: Additional parameters.
257+
258+
Returns:
259+
The created ``InferenceJob``.
260+
"""
261+
payload = self._build_common_payload(
262+
model_name,
263+
model_version=model_version,
264+
model_alias=model_alias,
265+
resource_id=resource_id,
266+
file_path=file_path,
267+
save_results=save_results,
268+
params=params,
269+
)
270+
response = self._make_request('POST', f'/{self.endpoint_base}/predict-image', json=payload)
271+
data = response.json()
272+
return self.get_status(data['job_id'])
273+
274+
def predict_frame(
275+
self,
276+
model_name: str,
277+
frame_index: int,
278+
*,
279+
model_version: int | None = None,
280+
model_alias: str | None = None,
281+
resource_id: str | None = None,
282+
file_path: str | None = None,
283+
save_results: bool = False,
284+
params: dict[str, Any] | None = None,
285+
) -> InferenceJob:
286+
"""Submit a frame-specific prediction job (for video resources).
287+
288+
Args:
289+
model_name: Name of the registered model.
290+
frame_index: Frame index to process.
291+
model_version: Specific model version number.
292+
model_alias: Model alias.
293+
resource_id: Resource ID from DataMint API.
294+
file_path: Local file path.
295+
save_results: Whether to save results.
296+
params: Additional parameters.
297+
298+
Returns:
299+
The created ``InferenceJob``.
300+
"""
301+
payload = self._build_common_payload(
302+
model_name,
303+
model_version=model_version,
304+
model_alias=model_alias,
305+
resource_id=resource_id,
306+
file_path=file_path,
307+
save_results=save_results,
308+
params=params,
309+
)
310+
payload["frame_index"] = frame_index
311+
response = self._make_request('POST', f'/{self.endpoint_base}/predict-frame', json=payload)
312+
data = response.json()
313+
return self.get_status(data['job_id'])
314+
315+
def predict_slice(
316+
self,
317+
model_name: str,
318+
slice_index: int,
319+
axis: Literal['axial', 'sagittal', 'coronal'],
320+
*,
321+
model_version: int | None = None,
322+
model_alias: str | None = None,
323+
resource_id: str | None = None,
324+
file_path: str | None = None,
325+
save_results: bool = False,
326+
params: dict[str, Any] | None = None,
327+
) -> InferenceJob:
328+
"""Submit a slice-specific prediction job for 3D volumes.
329+
330+
Args:
331+
model_name: Name of the registered model.
332+
slice_index: Slice index to process.
333+
axis: Anatomical axis (``'axial'``, ``'sagittal'``, or ``'coronal'``).
334+
model_version: Specific model version number.
335+
model_alias: Model alias.
336+
resource_id: Resource ID from DataMint API.
337+
file_path: Local file path.
338+
save_results: Whether to save results.
339+
params: Additional parameters.
340+
341+
Returns:
342+
The created ``InferenceJob``.
343+
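
        Example:
            Illustrative sketch only (placeholder model name and resource ID);
            run the model on slice 40 along the axial axis::

                job = api.predict_slice('organ-seg', 40, 'axial',
                                        resource_id='res-ct-001')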
"""
344+
payload = self._build_common_payload(
345+
model_name,
346+
model_version=model_version,
347+
model_alias=model_alias,
348+
resource_id=resource_id,
349+
file_path=file_path,
350+
save_results=save_results,
351+
params=params,
352+
)
353+
payload["slice_index"] = slice_index
354+
payload["axis"] = axis
355+
response = self._make_request('POST', f'/{self.endpoint_base}/predict-slice', json=payload)
356+
data = response.json()
357+
return self.get_status(data['job_id'])
358+
359+
def predict_volume(
360+
self,
361+
model_name: str,
362+
*,
363+
model_version: int | None = None,
364+
model_alias: str | None = None,
365+
resource_id: str | None = None,
366+
file_path: str | None = None,
367+
save_results: bool = False,
368+
params: dict[str, Any] | None = None,
369+
) -> InferenceJob:
370+
"""Submit a volume prediction job.
371+
372+
Args:
373+
model_name: Name of the registered model.
374+
model_version: Specific model version number.
375+
model_alias: Model alias.
376+
resource_id: Resource ID from DataMint API.
377+
file_path: Local file path.
378+
save_results: Whether to save results.
379+
params: Additional parameters.
380+
381+
Returns:
382+
The created ``InferenceJob``.
383+
"""
384+
payload = self._build_common_payload(
385+
model_name,
386+
model_version=model_version,
387+
model_alias=model_alias,
388+
resource_id=resource_id,
389+
file_path=file_path,
390+
save_results=save_results,
391+
params=params,
392+
)
393+
response = self._make_request('POST', f'/{self.endpoint_base}/predict-volume', json=payload)
394+
data = response.json()
395+
return self.get_status(data['job_id'])
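
A minimal end-to-end sketch of the new inference endpoints (illustrative only: the ``ApiConfig`` construction, the model name 'organ-seg', and the resource IDs are placeholders):

    config = ApiConfig(...)          # fill in server URL / credentials
    api = InferenceApi(config)

    # Batch submission against existing DataMint resources.
    job = api.submit('organ-seg', resource_ids=['res-1', 'res-2'], save_results=True)

    # Follow the raw SSE status stream (job.id assumed to hold the job identifier),
    # or simply block with api.wait(job).
    for event in api.stream_status(job.id):
        print(event)

    # A job that is no longer needed can be cancelled.
    api.cancel(job)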
