From cb64db46260a3bf6c1e03f0153dacdd4da9a8ebf Mon Sep 17 00:00:00 2001
From: Yang Chen
Date: Fri, 29 May 2026 14:59:33 -0700
Subject: [PATCH] Add custom retry callback to the Pathways elastic retry manager.

PiperOrigin-RevId: 923614318
---
 pathwaysutils/elastic/manager.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py
index ae3dc81..451bd32 100644
--- a/pathwaysutils/elastic/manager.py
+++ b/pathwaysutils/elastic/manager.py
@@ -197,6 +197,7 @@ def elastic_retry(
       timeout: float | None = None,
       pre_callback: Callable[..., Any] | None = None,
       on_elastic_event_callback: Callable[..., Any] | None = None,
+      should_retry: Callable[[jax.errors.JaxRuntimeError], bool] | None = None,
   ) -> Callable[[_F], _F]:
     """Retries a function with elasticity fault tolerance.
 
@@ -233,6 +234,7 @@ def elastic_retry(
       pre_callback: A callback to call before the function is attempted.
       on_elastic_event_callback: A callback to call after an elastic failure
         occurs.
+      should_retry: Custom callback to determine if an error should be retried.
 
     Returns:
       A decorator that retries the wrapped function.
@@ -299,7 +301,10 @@ def attempt_execution(retry_index: int) -> Any:
             if on_elastic_event_callback is not None:
               on_elastic_event_callback()
           except jax.errors.JaxRuntimeError as error:
-            if not elastic.is_error_due_to_slice_down(error):
+            should_retry_error = elastic.is_error_due_to_slice_down(error)
+            if should_retry is not None and should_retry(error):
+              should_retry_error = True
+            if not should_retry_error:
               raise
 
             if self.new_slice_event.is_set():