diff --git a/checkpoint/orbax/checkpoint/_src/path/step.py b/checkpoint/orbax/checkpoint/_src/path/step.py index 6fdc34eca..3b6ccb9b6 100644 --- a/checkpoint/orbax/checkpoint/_src/path/step.py +++ b/checkpoint/orbax/checkpoint/_src/path/step.py @@ -36,6 +36,7 @@ from orbax.checkpoint._src.metadata import step_metadata_serialization from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import gcs_utils +from orbax.checkpoint._src.path import storage_backend as storage_backend_lib # pylint: disable=unused-import from orbax.checkpoint._src.path import temporary_paths # Allowed checkpoint step naming using any non empty `step_prefix`. @@ -190,7 +191,10 @@ def build_step_path( base_path: epath.PathLike, name_format: NameFormat[Metadata], step: int ) -> epath.Path: """Returns `step` path under `base_path` for step `name_format`.""" - return epath.Path(base_path) / name_format.build_name(step) + prefix = storage_backend_lib.resolve_storage_backend( + str(base_path) + ).step_path_prefix + return epath.Path(base_path) / f'{prefix}{name_format.build_name(step)}' def build_step_metadatas( @@ -354,11 +358,15 @@ def _build_metadata( return None if step is not None: - # step already known, just check exists. if step_path.exists(): return Metadata(step=step, path=step_path) - # Regex: [prefix]*(step) + name = step_path.name + path_prefix = storage_backend_lib.resolve_storage_backend( + str(step_path.parent) + ).step_path_prefix + if path_prefix and name.startswith(path_prefix): + name = name[len(path_prefix) :] if self.step_format_fixed_length and self.step_format_fixed_length > 0: zero_present = rf'0\d{{{self.step_format_fixed_length-1}}}' zero_not_present = rf'[1-9]\d{{{self.step_format_fixed_length-1}}}\d*' @@ -367,7 +375,7 @@ def _build_metadata( zero_padded_step_group = r'(0|[1-9]\d*)' name_regex = f'^{step_prefix_with_underscore(self.step_prefix)}{zero_padded_step_group}$' - match = re.search(name_regex, step_path.name) + match = re.search(name_regex, name) if match is None: return None (step_,) = match.groups() @@ -403,9 +411,11 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]: os.path.join(path_prefix, self.step_prefix or '') ) ] - else: - prefix = step_prefix_with_underscore(self.step_prefix) - return [x for x in base_path.iterdir() if x.name.startswith(prefix)] + backend = storage_backend_lib.resolve_storage_backend(str(base_path)) + assets = backend.list_checkpoints(str(base_path)) + paths = [storage_backend_lib.checkpoint_metadata_to_path(a) for a in assets] + prefix = f'{backend.step_path_prefix}{step_prefix_with_underscore(self.step_prefix)}' + return [p for p in paths if p.name.startswith(prefix)] def _get_step_paths_and_total_steps( self, base_path: epath.PathLike, is_primary_host: bool @@ -505,7 +515,7 @@ def _find_all_with_single_host_load_and_broadcast( ) base_path = epath.Path(base_path) paths_to_step_dict: dict[epath.Path, int] = { - base_path / self.build_name(step): step + build_step_path(base_path, self, step): step for step in padded_step_list if step >= 0 } diff --git a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py index 961994c82..07bf35dba 100644 --- a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py +++ b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py @@ -83,6 +83,10 @@ class StorageBackend(abc.ABC): different file systems. """ + @property + def step_path_prefix(self) -> str: + return '' + @abc.abstractmethod def list_checkpoints( self, @@ -136,10 +140,7 @@ def delete_checkpoint( class LocalStorageBackend(StorageBackend): - """A LocalStorageBackend implementation for local file systems. - - # TODO(b/425293362): Implement this class. - """ + """A LocalStorageBackend implementation for local file systems.""" def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]: """Returns the final checkpoint path directly.""" @@ -152,10 +153,18 @@ def list_checkpoints( self, base_path: str | epath.PathLike, ) -> list[CheckpointPathMetadata]: - """Lists checkpoints for a given base path and version pattern.""" - raise NotImplementedError( - 'list_checkpoints is not yet implemented for LocalStorageBackend.' - ) + """Lists checkpoints for a given base path.""" + base = epath.Path(base_path) + if not base.exists(): + return [] + return [ + CheckpointPathMetadata( + path=str(child), + status=CheckpointPathMetadata.Status.COMMITTED, + version=None, + ) + for child in base.iterdir() + ] def delete_checkpoint( self, @@ -167,3 +176,18 @@ def delete_checkpoint( logging.info('Removed old checkpoint (%s)', checkpoint_path) except OSError: logging.exception('Failed to remove checkpoint (%s)', checkpoint_path) + + +def resolve_storage_backend( + path: str, +) -> StorageBackend: + """Returns a StorageBackend for the given path.""" + del path + return LocalStorageBackend() + + +def checkpoint_metadata_to_path( + asset: CheckpointPathMetadata, +) -> epath.Path: + """Converts a CheckpointPathMetadata to its checkpoint epath.Path.""" + return epath.Path(asset.path)