@@ -41,6 +41,7 @@ def __init__(self, *args,
4141 log_model_at_end_only : bool = True ,
4242 additional_metadata : dict [str , Any ] | None = None ,
4343 extra_pip_requirements : list [str ] | None = None ,
44+ log_model_metrics : bool = True ,
4445 ** kwargs ):
4546 """
4647 MLFlowModelCheckpoint is a custom callback for PyTorch Lightning that integrates with MLFlow to log and register models.
@@ -52,6 +53,8 @@ def __init__(self, *args,
5253 log_model_at_end_only (bool): If True, only log the model to MLFlow at the end of the training instead of after every checkpoint save.
5354 additional_metadata (dict[str, Any] | None): Additional metadata to log with the model as a JSON file.
5455 extra_pip_requirements (list[str] | None): Additional pip requirements to include with the MLFlow model.
56+ log_model_metrics (bool): If True, automatically log test metrics to the MLflow LoggedModel entity
57+ after testing. Requires MLflow 3.x with LoggedModel support. Defaults to True.
5558 **kwargs: Keyword arguments for ModelCheckpoint.
5659 """
5760 # Ensure MLflow is configured when callback is initialized
@@ -75,7 +78,9 @@ def __init__(self, *args,
7578 self .register_model_on = register_model_on
7679 self .registered_model_info = None
7780 self .log_model_at_end_only = log_model_at_end_only
81+ self .log_model_metrics = log_model_metrics
7882 self ._last_model_uri = None
83+ self ._last_model_id : str | None = None
7984 self .last_saved_model_info = None
8085 self ._inferred_signature = None
8186 self ._input_example = None
@@ -237,6 +242,7 @@ def log_model_to_mlflow(self,
237242
238243 model .to (device = orig_device ) # Move the model back to its original device
239244 self ._last_model_uri = modelinfo .model_uri
245+ self ._last_model_id = getattr (modelinfo , 'model_id' , None )
240246 self .last_saved_model_info = modelinfo
241247
242248 # Log additional metadata after the model is saved
@@ -349,6 +355,7 @@ def _restore_model_uri(self, trainer: L.Trainer) -> None:
349355 """
350356 logger = _get_MLFlowLogger (trainer )
351357 self ._last_model_uri = None
358+ self ._last_model_id = None
352359 self .last_saved_model_info = None
353360 if logger is None :
354361 _LOGGER .warning ("No MLFlowLogger found. Cannot restore model URI." )
@@ -372,6 +379,7 @@ def _restore_model_uri(self, trainer: L.Trainer) -> None:
372379 return
373380 # get the most recent one
374381 self ._last_model_uri = retrieved_logged_models [0 ].model_uri
382+ self ._last_model_id = getattr (retrieved_logged_models [0 ], 'model_id' , None )
375383 try :
376384 self .last_saved_model_info = mlflow .models .get_model_info (self ._last_model_uri )
377385 except mlflow .exceptions .MlflowException as e :
@@ -388,9 +396,42 @@ def on_predict_start(self, trainer, pl_module):
388396 self ._restore_model_uri (trainer )
389397 return super ().on_predict_start (trainer , pl_module )
390398
def _log_test_metrics_to_model(self, trainer: L.Trainer) -> None:
    """Log test metrics from ``trainer.callback_metrics`` to the MLflow LoggedModel.

    Only metrics whose key starts with ``test/`` or ``test_`` are considered.
    Each value is converted with ``float()``; values that cannot be converted
    are skipped. The resulting metrics are logged to the LoggedModel identified
    by ``self._last_model_id``.

    This method is best-effort: it returns early when no model id or no test
    metrics are available, and a failure of the MLflow call is logged as a
    warning instead of being raised.

    Args:
        trainer: The trainer whose ``callback_metrics`` mapping is read.
    """
    model_id = self._last_model_id
    if model_id is None:
        _LOGGER.debug("No model_id available. Skipping model metrics logging.")
        return

    metrics: dict[str, float] = {}
    for key, value in trainer.callback_metrics.items():
        if not key.startswith(("test/", "test_")):
            continue
        try:
            metrics[key] = float(value)
        # RuntimeError is included because converting a tensor with more than
        # one element raises it on recent PyTorch versions (older versions
        # raise ValueError); one malformed metric must not abort the test hook.
        except (TypeError, ValueError, RuntimeError):
            _LOGGER.debug(f"Skipping non-numeric metric '{key}': {value}")

    if not metrics:
        _LOGGER.info("No test metrics found in callback_metrics to log.")
        return

    # Deliberately broad: metrics logging must never break the test run, and
    # the ``model_id`` keyword may be unsupported by the installed MLflow.
    try:
        mlflow.log_metrics(metrics, model_id=model_id)
        _LOGGER.info(f"Logged {len(metrics)} test metrics to model {model_id}.")
    except Exception as e:
        _LOGGER.warning(f"Failed to log test metrics to model: {e}")
391429 def on_test_end (self , trainer : L .Trainer , pl_module : L .LightningModule ) -> None :
392430 super ().on_test_end (trainer , pl_module )
393431
432+ if self .log_model_metrics :
433+ self ._log_test_metrics_to_model (trainer )
434+
394435 if self .register_model_on == 'test' and self .register_model_name :
395436 self ._update_signature (trainer )
396437 self .register_model (trainer )
0 commit comments