44import datamint .mlflow .flavors
55from mlflow import pyfunc
66from .model import DatamintModel
7- import logging
87from collections .abc import Sequence
98from dataclasses import asdict
109from packaging .requirements import Requirement
1312FLAVOR_NAME = 'datamint'
1413
1514
16- def _process_signature (signature : ModelSignature ) -> ModelSignature :
15+ def _process_signature (signature : ModelSignature | None ,
16+ python_model : DatamintModel ) -> ModelSignature :
1717 from mlflow .types import ParamSchema , ParamSpec
18+ from mlflow .models .signature import _infer_signature_from_type_hints
1819
1920 # Define inference parameters
2021 params_schema = ParamSchema (
@@ -23,7 +24,19 @@ def _process_signature(signature: ModelSignature) -> ModelSignature:
2324 ]
2425 )
2526
26- current_params_sig = signature .params
27+ if signature is not None :
28+ current_params_sig = signature .params
29+ else :
30+ type_hints = python_model .predict_type_hints
31+ # context is only loaded when input_example exists
32+ signature = _infer_signature_from_type_hints (
33+ python_model = python_model ,
34+ context = None ,
35+ type_hints = type_hints ,
36+ input_example = None ,
37+ )
38+ current_params_sig = signature .params
39+
2740 # append our params to the existing signature
2841 if current_params_sig is None :
2942 signature .params = params_schema
@@ -40,7 +53,11 @@ def _process_input_example(input_example: ModelInputExample | None) -> tuple[Mod
4053 datamint_params = {
4154 "mode" : "default" ,
4255 }
43- if input_example is None or not isinstance (input_example , tuple ):
56+ if input_example is None :
57+ from datamint .entities .resource import LocalResource
58+ input_resource = LocalResource (raw_data = bytes ()).model_dump (mode = 'json' )
59+ return [input_resource ], datamint_params
60+ if not isinstance (input_example , tuple ):
4461 return (input_example , datamint_params )
4562 data_example , params_example = input_example
4663 # merge params_example with datamint_params, giving precedence to datamint_params in case of conflicts
@@ -114,7 +131,7 @@ def _get_req_name(req):
114131 datamint_model ._clear_linked_models_cache ()
115132
116133 if signature is not None :
117- signature = _process_signature (signature )
134+ signature = _process_signature (signature , datamint_model )
118135 input_example = _process_input_example (input_example )
119136
120137 return mlflow .pyfunc .save_model (
0 commit comments