1+ """API handler for model inference endpoints (MLflow DataMint server)."""
2+ from typing import Any , Literal
3+ from collections .abc import Callable , Generator
4+ import json
5+ import logging
6+ import time
7+
8+ import httpx
9+
10+ from ..entity_base_api import EntityBaseApi , ApiConfig
11+ from datamint .entities .inferencejob import InferenceJob
12+
13+ logger = logging .getLogger (__name__ )
14+
15+ _TERMINAL_STATUSES = frozenset ({'completed' , 'failed' , 'cancelled' , 'error' })
16+
17+
18+ class InferenceApi (EntityBaseApi [InferenceJob ]):
19+ """API handler for model inference endpoints.
20+
21+ Provides methods to submit inference jobs, poll their status,
22+ cancel running jobs, and use specialised prediction endpoints
23+ (image, frame, slice, volume).
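
    Example:
        Illustrative sketch only; ``config`` is assumed to be a valid
        ``ApiConfig``, and the model name and resource ID are hypothetical
        placeholders::

            api = InferenceApi(config)
            job = api.submit('my-segmentation-model', resource_id='abc123')
            job = api.wait(job, timeout=600)
            print(job.status)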
24+ """
25+
26+ def __init__ (self ,
27+ config : ApiConfig ,
28+ client : httpx .Client | None = None ) -> None :
29+ super ().__init__ (config , InferenceJob , 'datamint/api/v1/model-inference' , client )
30+
31+ # ------------------------------------------------------------------
32+ # Helpers
33+ # ------------------------------------------------------------------
34+
35+ def _parse_job_response (self , data : dict ) -> InferenceJob :
36+ """Normalise a job-status response into an ``InferenceJob`` entity."""
37+ if 'job_id' in data :
38+ data ['id' ] = data .pop ('job_id' )
39+ return self ._init_entity_obj (** data )
40+
41+ @staticmethod
42+ def _build_common_payload (
43+ model_name : str ,
44+ model_version : int | None = None ,
45+ model_alias : str | None = None ,
46+ resource_id : str | None = None ,
47+ file_path : str | None = None ,
48+ save_results : bool = False ,
49+ params : dict [str , Any ] | None = None ,
50+ ) -> dict [str , Any ]:
51+ """Build the payload keys shared by every inference request."""
        payload: dict[str, Any] = {"model_name": model_name}
        if model_version is not None:
            payload["model_version"] = model_version
        if model_alias is not None:
            payload["model_alias"] = model_alias
        if resource_id is not None:
            payload["resource_id"] = resource_id
        if file_path is not None:
            payload["file_path"] = file_path
        if save_results:
            payload["save_results"] = save_results
        if params:
            payload["params"] = params
        return payload

    # ------------------------------------------------------------------
    # Generic inference
    # ------------------------------------------------------------------

    def submit(
        self,
        model_name: str,
        *,
        model_version: int | None = None,
        model_alias: str | None = None,
        resource_id: str | None = None,
        resource_ids: list[str] | None = None,
        file_path: str | None = None,
        file_paths: list[str] | None = None,
        save_results: bool = False,
        params: dict[str, Any] | None = None,
    ) -> InferenceJob:
        """Submit an inference job for background processing.

        Args:
            model_name: Name of the registered model.
            model_version: Specific model version number.
            model_alias: Model alias (e.g. ``'champion'``).
            resource_id: Single resource ID from DataMint API.
            resource_ids: List of resource IDs.
            file_path: Local file path.
            file_paths: List of local file paths.
            save_results: Whether to save results to the API.
            params: Additional parameters forwarded to the model.

        Returns:
            The created ``InferenceJob`` (with initial status).
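
        Example:
            Illustrative sketch, assuming ``api`` is an ``InferenceApi``
            instance; the model name and resource IDs are hypothetical::

                job = api.submit(
                    'my-classifier',
                    resource_ids=['res-1', 'res-2'],
                    save_results=True,
                )
                job = api.wait(job)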
99+ """
100+ payload = self ._build_common_payload (
101+ model_name ,
102+ model_version = model_version ,
103+ model_alias = model_alias ,
104+ resource_id = resource_id ,
105+ file_path = file_path ,
106+ save_results = save_results ,
107+ params = params ,
108+ )
109+ if resource_ids is not None :
110+ payload ["resource_ids" ] = resource_ids
111+ if file_paths is not None :
112+ payload ["file_paths" ] = file_paths
113+
114+ response = self ._make_request ('POST' , f'/{ self .endpoint_base } ' , json = payload )
115+ data = response .json ()
116+ return self .get_status (data ['job_id' ])
117+
118+ # ------------------------------------------------------------------
119+ # Status / cancel
120+ # ------------------------------------------------------------------
121+
122+ def get_status (self , job_id : str ) -> InferenceJob :
123+ """Get the current status of an inference job.
124+
125+ Args:
126+ job_id: The job identifier.
127+
128+ Returns:
129+ An ``InferenceJob`` populated with the latest status.
130+ """
131+ response = self ._make_request ('GET' , f'/{ self .endpoint_base } /status/{ job_id } ' )
132+ return self ._parse_job_response (response .json ())
133+
134+ def get_by_id (self , entity_id : str ) -> InferenceJob :
135+ """Alias for ``get_status`` to satisfy ``EntityBaseApi`` interface."""
136+ return self .get_status (entity_id )
137+
138+ def stream_status (self , job_id : str ) -> Generator [dict [str , Any ], None , None ]:
139+ """Stream status updates for an inference job via Server-Sent Events.
140+
141+ Yields dictionaries parsed from SSE ``data:`` lines until the
142+ stream is closed by the server.
143+
144+ Args:
145+ job_id: The job identifier.
146+
147+ Yields:
148+ Parsed JSON dictionaries for each SSE event.
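
        Example:
            Illustrative sketch, assuming ``api`` is an ``InferenceApi``
            instance and ``'job-123'`` is a hypothetical job ID::

                for event in api.stream_status('job-123'):
                    print(event.get('status'))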
149+ """
150+ with self ._stream_request ('GET' , f'/{ self .endpoint_base } /status/{ job_id } /stream' ) as resp :
151+ for line in resp .iter_lines ():
152+ if line .startswith ('data:' ):
153+ payload = line [len ('data:' ):].strip ()
154+ if payload :
155+ yield json .loads (payload )
156+
157+ def wait (
158+ self ,
159+ job : str | InferenceJob ,
160+ * ,
161+ on_status : Callable [[InferenceJob ], None ] | None = None ,
162+ poll_interval : float = 2.0 ,
163+ timeout : float | None = None ,
164+ ) -> InferenceJob :
165+ """Block until an inference job reaches a terminal state.
166+
167+ First attempts to follow the SSE stream. If the stream is
168+ unavailable or drops early the method falls back to polling
169+ ``get_status`` at *poll_interval* seconds.
170+
171+ Args:
172+ job: Job ID string or ``InferenceJob`` entity.
173+ on_status: Optional callback invoked with an updated
174+ ``InferenceJob`` each time a status update is received.
175+ poll_interval: Seconds between polls when falling back to
176+ polling mode. Default ``2.0``.
177+ timeout: Maximum seconds to wait. ``None`` means wait
178+ indefinitely. Raises ``TimeoutError`` on expiry.
179+
180+ Returns:
181+ The ``InferenceJob`` in its terminal state.
182+
183+ Raises:
184+ TimeoutError: If *timeout* is set and the job has not
185+ finished within that duration.
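
        Example:
            Illustrative sketch, assuming ``api`` is an ``InferenceApi``
            instance and ``job`` was returned by ``submit``::

                final = api.wait(
                    job,
                    on_status=lambda j: print('update:', j.status),
                    timeout=600,
                )
                print(final.status)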
186+ """
187+ job_id = self ._entid (job ) if not isinstance (job , str ) else job
188+ deadline = (time .monotonic () + timeout ) if timeout is not None else None
189+
190+ def _check_timeout () -> None :
191+ if deadline is not None and time .monotonic () >= deadline :
192+ raise TimeoutError (
193+ f"Inference job { job_id } did not finish within { timeout } s"
194+ )
195+
196+ # --- Try SSE stream first ---
197+ try :
198+ for event in self .stream_status (job_id ):
199+ _check_timeout ()
200+ status_str = event .get ('status' , '' )
201+ current_job = self ._parse_job_response (event )
202+ if on_status is not None :
203+ on_status (current_job )
204+ if status_str .lower () in _TERMINAL_STATUSES :
205+ return current_job
206+ except Exception as e :
207+ logger .warning (f"SSE stream ended or failed ({ e } ); falling back to polling" )
208+
209+ # --- Polling fallback ---
210+ while True :
211+ _check_timeout ()
212+ current_job = self .get_status (job_id )
213+ if on_status is not None :
214+ on_status (current_job )
215+ if current_job .status .lower () in _TERMINAL_STATUSES :
216+ return current_job
217+ time .sleep (poll_interval )

    def cancel(self, job: str | InferenceJob) -> bool:
        """Cancel a running inference job.

        Args:
            job: Job ID string or ``InferenceJob`` entity.

        Returns:
            ``True`` if the cancellation was acknowledged.
        """
        job_id = job if isinstance(job, str) else self._entid(job)
        response = self._make_request('POST', f'/{self.endpoint_base}/cancel/{job_id}')
        return response.json().get('success', False)

    # ------------------------------------------------------------------
    # Specialised prediction endpoints
    # ------------------------------------------------------------------

    def predict_image(
        self,
        model_name: str,
        *,
        model_version: int | None = None,
        model_alias: str | None = None,
        resource_id: str | None = None,
        file_path: str | None = None,
        save_results: bool = False,
        params: dict[str, Any] | None = None,
    ) -> InferenceJob:
        """Submit an image prediction job.

        Args:
            model_name: Name of the registered model.
            model_version: Specific model version number.
            model_alias: Model alias (e.g. ``'champion'``).
            resource_id: Resource ID from DataMint API.
            file_path: Local file path.
            save_results: Whether to save results.
            params: Additional parameters.

        Returns:
            The created ``InferenceJob``.
        """
        payload = self._build_common_payload(
            model_name,
            model_version=model_version,
            model_alias=model_alias,
            resource_id=resource_id,
            file_path=file_path,
            save_results=save_results,
            params=params,
        )
        response = self._make_request('POST', f'/{self.endpoint_base}/predict-image', json=payload)
        data = response.json()
        return self.get_status(data['job_id'])

    def predict_frame(
        self,
        model_name: str,
        frame_index: int,
        *,
        model_version: int | None = None,
        model_alias: str | None = None,
        resource_id: str | None = None,
        file_path: str | None = None,
        save_results: bool = False,
        params: dict[str, Any] | None = None,
    ) -> InferenceJob:
        """Submit a frame-specific prediction job (for video resources).

        Args:
            model_name: Name of the registered model.
            frame_index: Frame index to process.
            model_version: Specific model version number.
            model_alias: Model alias.
            resource_id: Resource ID from DataMint API.
            file_path: Local file path.
            save_results: Whether to save results.
            params: Additional parameters.

        Returns:
            The created ``InferenceJob``.
        """
        payload = self._build_common_payload(
            model_name,
            model_version=model_version,
            model_alias=model_alias,
            resource_id=resource_id,
            file_path=file_path,
            save_results=save_results,
            params=params,
        )
        payload["frame_index"] = frame_index
        response = self._make_request('POST', f'/{self.endpoint_base}/predict-frame', json=payload)
        data = response.json()
        return self.get_status(data['job_id'])

    def predict_slice(
        self,
        model_name: str,
        slice_index: int,
        axis: Literal['axial', 'sagittal', 'coronal'],
        *,
        model_version: int | None = None,
        model_alias: str | None = None,
        resource_id: str | None = None,
        file_path: str | None = None,
        save_results: bool = False,
        params: dict[str, Any] | None = None,
    ) -> InferenceJob:
        """Submit a slice-specific prediction job for 3D volumes.

        Args:
            model_name: Name of the registered model.
            slice_index: Slice index to process.
            axis: Anatomical axis (``'axial'``, ``'sagittal'``, or ``'coronal'``).
            model_version: Specific model version number.
            model_alias: Model alias.
            resource_id: Resource ID from DataMint API.
            file_path: Local file path.
            save_results: Whether to save results.
            params: Additional parameters.

        Returns:
            The created ``InferenceJob``.
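
        Example:
            Illustrative sketch, assuming ``api`` is an ``InferenceApi``
            instance; the model name and resource ID are hypothetical::

                job = api.predict_slice(
                    'my-3d-segmenter',
                    slice_index=42,
                    axis='axial',
                    resource_id='abc123',
                )
                job = api.wait(job)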
343+ """
344+ payload = self ._build_common_payload (
345+ model_name ,
346+ model_version = model_version ,
347+ model_alias = model_alias ,
348+ resource_id = resource_id ,
349+ file_path = file_path ,
350+ save_results = save_results ,
351+ params = params ,
352+ )
353+ payload ["slice_index" ] = slice_index
354+ payload ["axis" ] = axis
355+ response = self ._make_request ('POST' , f'/{ self .endpoint_base } /predict-slice' , json = payload )
356+ data = response .json ()
357+ return self .get_status (data ['job_id' ])
358+
359+ def predict_volume (
360+ self ,
361+ model_name : str ,
362+ * ,
363+ model_version : int | None = None ,
364+ model_alias : str | None = None ,
365+ resource_id : str | None = None ,
366+ file_path : str | None = None ,
367+ save_results : bool = False ,
368+ params : dict [str , Any ] | None = None ,
369+ ) -> InferenceJob :
370+ """Submit a volume prediction job.
371+
372+ Args:
373+ model_name: Name of the registered model.
374+ model_version: Specific model version number.
375+ model_alias: Model alias.
376+ resource_id: Resource ID from DataMint API.
377+ file_path: Local file path.
378+ save_results: Whether to save results.
379+ params: Additional parameters.
380+
381+ Returns:
382+ The created ``InferenceJob``.
383+ """
384+ payload = self ._build_common_payload (
385+ model_name ,
386+ model_version = model_version ,
387+ model_alias = model_alias ,
388+ resource_id = resource_id ,
389+ file_path = file_path ,
390+ save_results = save_results ,
391+ params = params ,
392+ )
393+ response = self ._make_request ('POST' , f'/{ self .endpoint_base } /predict-volume' , json = payload )
394+ data = response .json ()
395+ return self .get_status (data ['job_id' ])