diff --git a/httpx_retries/retry.py b/httpx_retries/retry.py index 5415cae..1e3dcaf 100644 --- a/httpx_retries/retry.py +++ b/httpx_retries/retry.py @@ -4,11 +4,11 @@ import random import sys import time -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from email.utils import parsedate_to_datetime from enum import Enum from http import HTTPStatus -from typing import Final +from typing import Final, TypeAlias import httpx @@ -30,6 +30,9 @@ class HTTPMethod(str, Enum): CONNECT = "CONNECT" +StatusForcelist: TypeAlias = frozenset[HTTPStatus | int] | Callable[[int], bool] + + class Retry: """ A class to encapsulate retry logic and configuration. @@ -49,7 +52,8 @@ class Retry: when deciding how long to wait before retrying. allowed_methods (Iterable[http.HTTPMethod, str], optional): The HTTP methods that can be retried. Defaults to ["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]. - status_forcelist (Iterable[http.HTTPStatus, int], optional): The HTTP status codes that can be retried. + status_forcelist (Iterable[http.HTTPStatus, int] | Callable[[int], bool], optional): The HTTP status codes that + can be retried, or a predicate that accepts a status code and returns whether it can be retried. Defaults to [429, 502, 503, 504]. retry_on_exceptions (Iterable[type[httpx.HTTPError]], optional): The HTTP exceptions that can be retried. Defaults to [httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError]. @@ -92,7 +96,7 @@ def __init__( self, total: int = 10, allowed_methods: Iterable[HTTPMethod | str] | None = None, - status_forcelist: Iterable[HTTPStatus | int] | None = None, + status_forcelist: Iterable[HTTPStatus | int] | Callable[[int], bool] | None = None, retry_on_exceptions: Iterable[type[Exception]] | None = None, backoff_factor: float = 0.0, respect_retry_after_header: bool = True, @@ -130,17 +134,29 @@ def __init__( self.allowed_methods: frozenset[str] = frozenset( method.upper() for method in (allowed_methods or self.RETRYABLE_METHODS) ) - self.status_forcelist = frozenset((status_forcelist or self.RETRYABLE_STATUS_CODES)) + self.status_forcelist: StatusForcelist = self._prepare_status_forcelist(status_forcelist) self.retryable_exceptions = ( self.RETRYABLE_EXCEPTIONS if retry_on_exceptions is None else tuple(retry_on_exceptions) ) + @staticmethod + def _prepare_status_forcelist( + status_forcelist: Iterable[HTTPStatus | int] | Callable[[int], bool] | None, + ) -> StatusForcelist: + if callable(status_forcelist): + return status_forcelist + + return frozenset((status_forcelist or Retry.RETRYABLE_STATUS_CODES)) + def is_retryable_method(self, method: str) -> bool: """Check if a method is retryable.""" return method.upper() in self.allowed_methods def is_retryable_status_code(self, status_code: int) -> bool: """Check if a status code is retryable.""" + if callable(self.status_forcelist): + return self.status_forcelist(status_code) + return status_code in self.status_forcelist def is_retryable_exception(self, exception: Exception) -> bool: diff --git a/tests/test_retry.py b/tests/test_retry.py index aa7a55c..c17a060 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -33,8 +33,8 @@ def test_retry_custom_initialization() -> None: assert retry.max_backoff_wait == 30 assert HTTPMethod.GET in retry.allowed_methods assert HTTPMethod.POST in retry.allowed_methods - assert HTTPStatus.INTERNAL_SERVER_ERROR in retry.status_forcelist - assert HTTPStatus.BAD_GATEWAY in retry.status_forcelist + assert retry.is_retryable_status_code(HTTPStatus.INTERNAL_SERVER_ERROR) + assert retry.is_retryable_status_code(HTTPStatus.BAD_GATEWAY) def test_is_retryable_method() -> None: @@ -110,6 +110,15 @@ def test_custom_retry_status_codes_non_standard() -> None: assert retry.is_retryable_status_code(502) is False +def test_custom_retry_status_codes_predicate() -> None: + retry = Retry(status_forcelist=lambda status_code: status_code >= 500) + + assert retry.is_retryable_status_code(500) is True + assert retry.is_retryable_status_code(599) is True + assert retry.is_retryable_status_code(429) is False + assert retry.is_retry("GET", 503, False) is True + + def test_is_exhausted() -> None: retry = Retry(total=3) assert retry.is_exhausted() is False