Skip to content

Commit 9186819

Browse files
committed
Fixed logging when running the deployed model; fixed the model signature missing its inference-params schema
1 parent 4d89d2f commit 9186819

4 files changed

Lines changed: 74 additions & 4 deletions

File tree

datamint/mlflow/__init__.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
if mlflow_utils.is_tracking_uri_set():
1616
_LOGGER.warning("MLflow tracking URI is already set before patching get_tracking_uri.")
1717

18+
1819
@wraps(_original_get_tracking_uri)
1920
def _patched_get_tracking_uri(*args, **kwargs):
2021
"""Patched version of get_tracking_uri that ensures MLflow environment is set up first.
@@ -47,6 +48,56 @@ def _patched_get_tracking_uri(*args, **kwargs):
4748
# Replace the original function with our patched version
4849
mlflow_utils.get_tracking_uri = _patched_get_tracking_uri
4950

51+
_ALREADY_CONFIGURED_LOGGING = False


def _configure_mlflow_loggers():
    """Route ``datamint.mlflow`` log records through a Rich handler when serving.

    Intended to run inside an MLflow scoring server. The server is detected via
    the ``MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT`` environment variable
    (assumed to be set only by the MLflow scoring server — TODO confirm);
    outside of it the function is a no-op. Configuration is applied at most
    once per process.
    """
    global _ALREADY_CONFIGURED_LOGGING
    if _ALREADY_CONFIGURED_LOGGING:
        return

    import os

    # Guard before the heavier mlflow/rich imports so the common
    # (non-server) import path stays cheap and dependency-free.
    if 'MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT' not in os.environ:
        # Probably not running in an MLflow scoring server, so there is no
        # need to configure our mlflow loggers.
        return

    import logging.config

    from mlflow.utils.logging_utils import SuppressLogFilter, get_mlflow_log_level
    import rich.logging  # imported explicitly so a missing dependency fails here, not inside dictConfig

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "datamint_mlflow_handler": {
                    "class": "rich.logging.RichHandler",
                    "filters": ["suppress_in_thread"],
                },
            },
            "loggers": {
                'datamint.mlflow': {
                    "handlers": ["datamint_mlflow_handler"],
                    "level": get_mlflow_log_level(),
                    # Don't propagate to the root logger: records would
                    # otherwise be emitted twice.
                    "propagate": False,
                },
            },
            "filters": {
                "suppress_in_thread": {
                    "()": SuppressLogFilter,
                }
            },
        }
    )
    # Mark as configured only after dictConfig succeeded, so a transient
    # failure can be retried on a later call.
    _ALREADY_CONFIGURED_LOGGING = True
    _LOGGER.info("Configured MLflow loggers with RichHandler and level %s", get_mlflow_log_level())


try:
    _configure_mlflow_loggers()
except Exception:
    # Logging setup is best-effort; never let it break module import.
    # logger.exception preserves the traceback for debugging.
    _LOGGER.exception("Failed to configure MLflow loggers")
50101

51102
if TYPE_CHECKING:
52103
from .flavors.model import DatamintModel
@@ -62,4 +113,5 @@ def _patched_get_tracking_uri(*args, **kwargs):
62113
},
63114
)
64115

65-
__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured', 'DatamintModel']
116+
117+
__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured', 'DatamintModel']

datamint/mlflow/flavors/datamint_flavor.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,23 @@
1111

1212
FLAVOR_NAME = 'datamint'
1313

14-
_LOGGER = logging.getLogger(__name__)
1514

15+
def _process_signature(signature: ModelSignature | None) -> ModelSignature:
    """Ensure the model signature declares the ``mode`` inference parameter.

    MLflow only exposes ``params`` at inference time when they are present in
    the logged signature, so a missing params schema would make ``mode``
    unusable on the deployed model.

    Args:
        signature: The signature provided by the caller, or ``None`` if the
            model is being saved without one.

    Returns:
        A signature whose params schema contains ``mode`` (string, default
        ``"default"``). Any params already declared by the caller are kept
        instead of being silently dropped; the input ``signature`` object is
        mutated in place when one is given.
    """
    from mlflow.types import ParamSchema, ParamSpec

    mode_spec = ParamSpec("mode", "string", "default")  # default inference mode

    if signature is None:
        return ModelSignature(params=ParamSchema([mode_spec]))

    # Preserve caller-declared params; only append "mode" if absent.
    existing = list(signature.params.params) if signature.params is not None else []
    if not any(spec.name == "mode" for spec in existing):
        existing.append(mode_spec)
    signature.params = ParamSchema(existing)
    return signature
30+
1631

1732
def save_model(datamint_model: DatamintModel,
1833
path,
@@ -77,6 +92,8 @@ def _get_req_name(req):
7792

7893
datamint_model._clear_linked_models_cache()
7994

95+
signature = _process_signature(signature)
96+
8097
return mlflow.pyfunc.save_model(
8198
path=path,
8299
python_model=datamint_model,

datamint/mlflow/flavors/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def predict(self,
399399

400400
# Parse and validate mode
401401
mode = self._parse_mode(model_input=model_input, params=params)
402+
logger.info(f"Received prediction request with {len(model_input)} resources and params {params} with mode '{mode.value}'")
402403

403404
# Route to appropriate prediction method
404405
try:
@@ -408,7 +409,7 @@ def predict(self,
408409
mode = PredictionMode.DEFAULT
409410
else:
410411
raise NotImplementedError
411-
logger.debug(f"Routing to '{mode.value}' mode for {len(model_input)} resources")
412+
logger.info(f"Routing to '{mode.value}' mode for {len(model_input)} resources")
412413
result = self._route_prediction(model_input, mode, params)
413414

414415
# Apply common post-processing

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "datamint"
33
description = "A library for interacting with the Datamint API, designed for efficient data management, processing and Deep Learning workflows."
4-
version = "2.10.1"
4+
version = "2.10.2"
55
dynamic = ["dependencies"]
66
requires-python = ">=3.10"
77
readme = "README.md"

0 commit comments

Comments
 (0)