Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ def create_with_handlers(
) -> CheckpointablesOptions:
registry = registration.local_registry(include_global_registry=True)
for handler in handlers:
registry.add(handler, None)
registry.add(handler, checkpointable_name=None)
for name, handler in named_handlers.items():
registry.add(handler, name)
registry.add(handler, checkpointable_name=name)
return cls(registry=registry)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
given checkpointable will be used.
"""

from typing import Type
from typing import Sequence, Type

from orbax.checkpoint.experimental.v1._src.handlers import json_handler
from orbax.checkpoint.experimental.v1._src.handlers import proto_handler
Expand All @@ -34,23 +34,45 @@
def _try_register_handler(
handler_type: Type[handler_types.CheckpointableHandler],
name: str | None = None,
secondary_typestrs: Sequence[str] | None = None,
):
"""Tries to register handler globally with name and secondary typestrs."""
try:
registration.global_registry().add(handler_type, name)
registration.global_registry().add(
handler_type,
checkpointable_name=name,
secondary_typestrs=secondary_typestrs,
)
except registration.AlreadyExistsError:
pass


_try_register_handler(proto_handler.ProtoHandler)
_try_register_handler(json_handler.JsonHandler)
_try_register_handler(
proto_handler.ProtoHandler,
secondary_typestrs=[
'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler',
],
)
_try_register_handler(
json_handler.JsonHandler,
secondary_typestrs=[
'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler',
],
)
_try_register_handler(
stateful_checkpointable_handler.StatefulCheckpointableHandler
)
_try_register_handler(
json_handler.MetricsHandler,
checkpoint_layout.METRICS_CHECKPOINTABLE_KEY,
)
_try_register_handler(pytree_handler.PyTreeHandler)
_try_register_handler(
pytree_handler.PyTreeHandler,
secondary_typestrs=[
'orbax.checkpoint._src.handlers.pytree_checkpoint_handler.PyTreeCheckpointHandler',
'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler',
],
)
_try_register_handler(
pytree_handler.PyTreeHandler, checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
)
Loading
Loading