From 5af89435c70548dd96432a28a5fccd3c6fd082ed Mon Sep 17 00:00:00 2001 From: Shutong Li Date: Thu, 18 Dec 2025 13:58:37 -0800 Subject: [PATCH] Improve type annotations in checkpoint_args.py. PiperOrigin-RevId: 846404326 --- ffn/jax/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffn/jax/train.py b/ffn/jax/train.py index d0df882..13f8126 100644 --- a/ffn/jax/train.py +++ b/ffn/jax/train.py @@ -340,7 +340,7 @@ def _get_tf_writer(writers) -> metric_writers.SummaryWriter | None: def _get_ocp_args( train_iter: DataIterator, restore: bool = True -) -> DataIterator: +) -> DataIterator | ocp.args.CheckpointArgs: if isinstance(train_iter, tf.data.Iterator): return DatasetArgs(train_iter)