Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions checkpoint/orbax/checkpoint/_src/path/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import storage_backend as storage_backend_lib # pylint: disable=unused-import
from orbax.checkpoint._src.path import temporary_paths

# Allowed checkpoint step naming using any non empty `step_prefix`.
Expand Down Expand Up @@ -190,7 +191,10 @@ def build_step_path(
base_path: epath.PathLike, name_format: NameFormat[Metadata], step: int
) -> epath.Path:
"""Returns `step` path under `base_path` for step `name_format`."""
return epath.Path(base_path) / name_format.build_name(step)
prefix = storage_backend_lib.resolve_storage_backend(
str(base_path)
).step_path_prefix
return epath.Path(base_path) / f'{prefix}{name_format.build_name(step)}'


def build_step_metadatas(
Expand Down Expand Up @@ -354,11 +358,15 @@ def _build_metadata(
return None

if step is not None:
# step already known, just check exists.
if step_path.exists():
return Metadata(step=step, path=step_path)

# Regex: [prefix]*(step)
name = step_path.name
path_prefix = storage_backend_lib.resolve_storage_backend(
str(step_path.parent)
).step_path_prefix
if path_prefix and name.startswith(path_prefix):
name = name[len(path_prefix) :]
if self.step_format_fixed_length and self.step_format_fixed_length > 0:
zero_present = rf'0\d{{{self.step_format_fixed_length-1}}}'
zero_not_present = rf'[1-9]\d{{{self.step_format_fixed_length-1}}}\d*'
Expand All @@ -367,7 +375,7 @@ def _build_metadata(
zero_padded_step_group = r'(0|[1-9]\d*)'
name_regex = f'^{step_prefix_with_underscore(self.step_prefix)}{zero_padded_step_group}$'

match = re.search(name_regex, step_path.name)
match = re.search(name_regex, name)
if match is None:
return None
(step_,) = match.groups()
Expand Down Expand Up @@ -403,9 +411,11 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
os.path.join(path_prefix, self.step_prefix or '')
)
]
else:
prefix = step_prefix_with_underscore(self.step_prefix)
return [x for x in base_path.iterdir() if x.name.startswith(prefix)]
backend = storage_backend_lib.resolve_storage_backend(str(base_path))
assets = backend.list_checkpoints(str(base_path))
paths = [storage_backend_lib.checkpoint_metadata_to_path(a) for a in assets]
prefix = f'{backend.step_path_prefix}{step_prefix_with_underscore(self.step_prefix)}'
return [p for p in paths if p.name.startswith(prefix)]

def _get_step_paths_and_total_steps(
self, base_path: epath.PathLike, is_primary_host: bool
Expand Down Expand Up @@ -505,7 +515,7 @@ def _find_all_with_single_host_load_and_broadcast(
)
base_path = epath.Path(base_path)
paths_to_step_dict: dict[epath.Path, int] = {
base_path / self.build_name(step): step
build_step_path(base_path, self, step): step
for step in padded_step_list
if step >= 0
}
Expand Down
40 changes: 32 additions & 8 deletions checkpoint/orbax/checkpoint/_src/path/storage_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class StorageBackend(abc.ABC):
different file systems.
"""

@property
def step_path_prefix(self) -> str:
  """Prefix that this backend puts in front of step directory names.

  The base implementation returns the empty string, i.e. no prefix.
  """
  no_prefix = ''
  return no_prefix

@abc.abstractmethod
def list_checkpoints(
self,
Expand Down Expand Up @@ -136,10 +140,7 @@ def delete_checkpoint(


class LocalStorageBackend(StorageBackend):
"""A LocalStorageBackend implementation for local file systems.

# TODO(b/425293362): Implement this class.
"""
"""A LocalStorageBackend implementation for local file systems."""

def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]:
"""Returns the final checkpoint path directly."""
Expand All @@ -152,10 +153,18 @@ def list_checkpoints(
self,
base_path: str | epath.PathLike,
) -> list[CheckpointPathMetadata]:
"""Lists checkpoints for a given base path and version pattern."""
raise NotImplementedError(
'list_checkpoints is not yet implemented for LocalStorageBackend.'
)
"""Lists checkpoints for a given base path."""
base = epath.Path(base_path)
if not base.exists():
return []
return [
CheckpointPathMetadata(
path=str(child),
status=CheckpointPathMetadata.Status.COMMITTED,
version=None,
)
for child in base.iterdir()
]

def delete_checkpoint(
self,
Expand All @@ -167,3 +176,18 @@ def delete_checkpoint(
logging.info('Removed old checkpoint (%s)', checkpoint_path)
except OSError:
logging.exception('Failed to remove checkpoint (%s)', checkpoint_path)


def resolve_storage_backend(
    path: str,
) -> StorageBackend:
  """Picks the StorageBackend that serves `path`.

  Args:
    path: Path to resolve a backend for. It is not inspected here.

  Returns:
    A new `LocalStorageBackend` instance, regardless of `path`.
  """
  del path  # Unused: the same backend type is returned for every path.
  backend = LocalStorageBackend()
  return backend


def checkpoint_metadata_to_path(
    asset: CheckpointPathMetadata,
) -> epath.Path:
  """Builds the `epath.Path` of the checkpoint described by `asset`.

  Args:
    asset: Checkpoint metadata; only its `path` attribute is read.

  Returns:
    An `epath.Path` wrapping `asset.path`.
  """
  checkpoint_location = asset.path
  return epath.Path(checkpoint_location)
Loading