Skip to content

Commit 0940c6e

Browse files
committed
fixes inference model signature
1 parent 985d44f commit 0940c6e

3 files changed

Lines changed: 24 additions & 7 deletions

File tree

datamint/mlflow/flavors/datamint_flavor.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import datamint.mlflow.flavors
55
from mlflow import pyfunc
66
from .model import DatamintModel
7-
import logging
87
from collections.abc import Sequence
98
from dataclasses import asdict
109
from packaging.requirements import Requirement
@@ -13,8 +12,10 @@
1312
FLAVOR_NAME = 'datamint'
1413

1514

16-
def _process_signature(signature: ModelSignature) -> ModelSignature:
15+
def _process_signature(signature: ModelSignature | None,
16+
python_model: DatamintModel) -> ModelSignature:
1717
from mlflow.types import ParamSchema, ParamSpec
18+
from mlflow.models.signature import _infer_signature_from_type_hints
1819

1920
# Define inference parameters
2021
params_schema = ParamSchema(
@@ -23,7 +24,19 @@ def _process_signature(signature: ModelSignature) -> ModelSignature:
2324
]
2425
)
2526

26-
current_params_sig = signature.params
27+
if signature is not None:
28+
current_params_sig = signature.params
29+
else:
30+
type_hints = python_model.predict_type_hints
31+
# context is only loaded when input_example exists
32+
signature = _infer_signature_from_type_hints(
33+
python_model=python_model,
34+
context=None,
35+
type_hints=type_hints,
36+
input_example=None,
37+
)
38+
current_params_sig = signature.params
39+
2740
# append our params to the existing signature
2841
if current_params_sig is None:
2942
signature.params = params_schema
@@ -40,7 +53,11 @@ def _process_input_example(input_example: ModelInputExample | None) -> tuple[Mod
4053
datamint_params = {
4154
"mode": "default",
4255
}
43-
if input_example is None or not isinstance(input_example, tuple):
56+
if input_example is None:
57+
from datamint.entities.resource import LocalResource
58+
input_resource = LocalResource(raw_data=bytes()).model_dump(mode='json')
59+
return [input_resource], datamint_params
60+
if not isinstance(input_example, tuple):
4461
return (input_example, datamint_params)
4562
data_example, params_example = input_example
4663
# merge params_example with datamint_params, giving precedence to datamint_params in case of conflicts
@@ -114,7 +131,7 @@ def _get_req_name(req):
114131
datamint_model._clear_linked_models_cache()
115132

116133
if signature is not None:
117-
signature = _process_signature(signature)
134+
signature = _process_signature(signature, datamint_model)
118135
input_example = _process_input_example(input_example)
119136

120137
return mlflow.pyfunc.save_model(

datamint/mlflow/flavors/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def predict(self,
379379
Args:
380380
model_input: List of Resource objects to process
381381
params: Optional configuration dictionary with keys:
382-
- mode (str): Prediction mode (default: 'standard')
382+
- mode (str): Prediction mode (default: 'default')
383383
- confidence_threshold (float): Filter by confidence
384384
- batch_size (int): Batch size for processing
385385
- render_annotation (bool): Return rendered images

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "datamint"
33
description = "A library for interacting with the Datamint API, designed for efficient data management, processing and Deep Learning workflows."
4-
version = "2.10.4"
4+
version = "2.10.5"
55
dynamic = ["dependencies"]
66
requires-python = ">=3.10"
77
readme = "README.md"

0 commit comments

Comments
 (0)