Skip to content

Commit 9466c95

Browse files
committed
Add support for logging test metrics to MLflow LoggedModel
1 parent 15a8647 commit 9466c95

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

datamint/mlflow/lightning/callbacks/modelcheckpoint.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(self, *args,
4141
log_model_at_end_only: bool = True,
4242
additional_metadata: dict[str, Any] | None = None,
4343
extra_pip_requirements: list[str] | None = None,
44+
log_model_metrics: bool = True,
4445
**kwargs):
4546
"""
4647
MLFlowModelCheckpoint is a custom callback for PyTorch Lightning that integrates with MLFlow to log and register models.
@@ -52,6 +53,8 @@ def __init__(self, *args,
5253
log_model_at_end_only (bool): If True, only log the model to MLFlow at the end of the training instead of after every checkpoint save.
5354
additional_metadata (dict[str, Any] | None): Additional metadata to log with the model as a JSON file.
5455
extra_pip_requirements (list[str] | None): Additional pip requirements to include with the MLFlow model.
56+
log_model_metrics (bool): If True, automatically log test metrics to the MLflow LoggedModel entity
57+
after testing. Requires MLflow 3.x with LoggedModel support. Defaults to True.
5558
**kwargs: Keyword arguments for ModelCheckpoint.
5659
"""
5760
# Ensure MLflow is configured when callback is initialized
@@ -75,7 +78,9 @@ def __init__(self, *args,
7578
self.register_model_on = register_model_on
7679
self.registered_model_info = None
7780
self.log_model_at_end_only = log_model_at_end_only
81+
self.log_model_metrics = log_model_metrics
7882
self._last_model_uri = None
83+
self._last_model_id: str | None = None
7984
self.last_saved_model_info = None
8085
self._inferred_signature = None
8186
self._input_example = None
@@ -237,6 +242,7 @@ def log_model_to_mlflow(self,
237242

238243
model.to(device=orig_device) # Move the model back to its original device
239244
self._last_model_uri = modelinfo.model_uri
245+
self._last_model_id = getattr(modelinfo, 'model_id', None)
240246
self.last_saved_model_info = modelinfo
241247

242248
# Log additional metadata after the model is saved
@@ -349,6 +355,7 @@ def _restore_model_uri(self, trainer: L.Trainer) -> None:
349355
"""
350356
logger = _get_MLFlowLogger(trainer)
351357
self._last_model_uri = None
358+
self._last_model_id = None
352359
self.last_saved_model_info = None
353360
if logger is None:
354361
_LOGGER.warning("No MLFlowLogger found. Cannot restore model URI.")
@@ -372,6 +379,7 @@ def _restore_model_uri(self, trainer: L.Trainer) -> None:
372379
return
373380
# get the most recent one
374381
self._last_model_uri = retrieved_logged_models[0].model_uri
382+
self._last_model_id = getattr(retrieved_logged_models[0], 'model_id', None)
375383
try:
376384
self.last_saved_model_info = mlflow.models.get_model_info(self._last_model_uri)
377385
except mlflow.exceptions.MlflowException as e:
@@ -388,9 +396,42 @@ def on_predict_start(self, trainer, pl_module):
388396
self._restore_model_uri(trainer)
389397
return super().on_predict_start(trainer, pl_module)
390398

399+
def _log_test_metrics_to_model(self, trainer: L.Trainer) -> None:
    """Attach test-stage metrics to the MLflow LoggedModel entity.

    Scans ``trainer.callback_metrics`` for entries whose keys start with
    ``test/`` or ``test_``, coerces each value to ``float`` (entries that
    cannot be coerced are skipped with a debug message), and logs the
    resulting dict against ``self._last_model_id`` via
    ``mlflow.log_metrics``. A missing model id or an empty metric set
    makes this a no-op; logging failures are downgraded to warnings.

    Args:
        trainer: The Lightning trainer whose ``callback_metrics`` are read.
    """
    model_id = self._last_model_id
    if model_id is None:
        _LOGGER.debug("No model_id available. Skipping model metrics logging.")
        return

    test_metrics: dict[str, float] = {}
    for key, value in trainer.callback_metrics.items():
        if key.startswith(("test/", "test_")):
            try:
                # NOTE(review): float() on a multi-element tensor raises
                # RuntimeError, which is not caught here — assumed callback
                # metrics are scalars; confirm against callers.
                test_metrics[key] = float(value)
            except (TypeError, ValueError):
                _LOGGER.debug(f"Skipping non-numeric metric '{key}': {value}")

    if not test_metrics:
        _LOGGER.info("No test metrics found in callback_metrics to log.")
        return

    try:
        # Best effort: metric logging must never fail the test run itself.
        mlflow.log_metrics(test_metrics, model_id=model_id)
        _LOGGER.info(f"Logged {len(test_metrics)} test metrics to model {model_id}.")
    except Exception as e:
        _LOGGER.warning(f"Failed to log test metrics to model: {e}")
428+
391429
def on_test_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
    """Lightning hook invoked once the test loop finishes.

    After the parent hook runs, optionally pushes test metrics to the
    MLflow LoggedModel (when ``self.log_model_metrics`` is set) and then
    optionally registers the model (when ``register_model_on == 'test'``
    and a registered-model name is configured).
    """
    super().on_test_end(trainer, pl_module)

    # Attach test metrics to the LoggedModel entity when enabled.
    if self.log_model_metrics:
        self._log_test_metrics_to_model(trainer)

    # Register the model in the MLflow registry if configured for 'test'.
    should_register = self.register_model_on == 'test' and self.register_model_name
    if should_register:
        self._update_signature(trainer)
        self.register_model(trainer)

0 commit comments

Comments
 (0)