Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def elastic_retry(
timeout: float | None = None,
pre_callback: Callable[..., Any] | None = None,
on_elastic_event_callback: Callable[..., Any] | None = None,
should_retry: Callable[[jax.errors.JaxRuntimeError], bool] | None = None,
) -> Callable[[_F], _F]:
"""Retries a function with elasticity fault tolerance.

Expand Down Expand Up @@ -233,6 +234,7 @@ def elastic_retry(
pre_callback: A callback to call before the function is attempted.
on_elastic_event_callback: A callback to call after an elastic failure
occurs.
should_retry: Custom callback to determine if an error should be retried.

Returns:
A decorator that retries the wrapped function.
Expand Down Expand Up @@ -299,7 +301,10 @@ def attempt_execution(retry_index: int) -> Any:
if on_elastic_event_callback is not None:
on_elastic_event_callback()
except jax.errors.JaxRuntimeError as error:
if not elastic.is_error_due_to_slice_down(error):
should_retry_error = elastic.is_error_due_to_slice_down(error)
if should_retry is not None and should_retry(error):
should_retry_error = True
if not should_retry_error:
raise

if self.new_slice_event.is_set():
Expand Down
Loading