Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions httpx_retries/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ class HTTPMethod(str, Enum):
CONNECT = "CONNECT"


class _UnsetType:
__slots__ = ()


_UNSET: Final[_UnsetType] = _UnsetType()


class Retry:
"""
A class to encapsulate retry logic and configuration.
Expand Down Expand Up @@ -283,23 +290,47 @@ async def asleep(self, response: httpx.Response | Exception) -> None:
await asyncio.sleep(time_to_sleep)
self.elapsed_sleep += time_to_sleep

def copy_with(
self,
total: int | _UnsetType = _UNSET,
allowed_methods: Iterable[HTTPMethod | str] | None | _UnsetType = _UNSET,
status_forcelist: Iterable[HTTPStatus | int] | None | _UnsetType = _UNSET,
retry_on_exceptions: Iterable[type[Exception]] | None | _UnsetType = _UNSET,
backoff_factor: float | _UnsetType = _UNSET,
respect_retry_after_header: bool | _UnsetType = _UNSET,
max_backoff_wait: float | _UnsetType = _UNSET,
backoff_jitter: float | _UnsetType = _UNSET,
attempts_made: int | _UnsetType = _UNSET,
total_timeout: float | None | _UnsetType = _UNSET,
elapsed_sleep: float | _UnsetType = _UNSET,
validate_response: Callable[[httpx.Response], None | Awaitable[None]] | None | _UnsetType = _UNSET,
) -> "Retry":
"""Return a new Retry with selected fields overridden."""
return self.__class__(
total=self.total if isinstance(total, _UnsetType) else total,
allowed_methods=self.allowed_methods if isinstance(allowed_methods, _UnsetType) else allowed_methods,
status_forcelist=self.status_forcelist if isinstance(status_forcelist, _UnsetType) else status_forcelist,
retry_on_exceptions=self.retryable_exceptions
if isinstance(retry_on_exceptions, _UnsetType)
else retry_on_exceptions,
backoff_factor=self.backoff_factor if isinstance(backoff_factor, _UnsetType) else backoff_factor,
respect_retry_after_header=self.respect_retry_after_header
if isinstance(respect_retry_after_header, _UnsetType)
else respect_retry_after_header,
max_backoff_wait=self.max_backoff_wait if isinstance(max_backoff_wait, _UnsetType) else max_backoff_wait,
backoff_jitter=self.backoff_jitter if isinstance(backoff_jitter, _UnsetType) else backoff_jitter,
attempts_made=self.attempts_made if isinstance(attempts_made, _UnsetType) else attempts_made,
total_timeout=self.total_timeout if isinstance(total_timeout, _UnsetType) else total_timeout,
elapsed_sleep=self.elapsed_sleep if isinstance(elapsed_sleep, _UnsetType) else elapsed_sleep,
Comment on lines +310 to +324
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we're using the sentinel pattern here, the more performant check is x is _UNSET instead of isinstance(x, _UnsetType)

Copy link
Copy Markdown
Contributor Author

@vgavro vgavro May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's totally for type-checkers on using Sentinel, it's either - we'll do as typechekers wants us to do with _UnsetType = _UNSET pattern and using isinstance() (because, formally, you can pass OTHER _UnsetType, not only _UNSET sentinel), or adding mypy/pyright/ty ignores there.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you prefer typing ignores for performance - i'll fix, let me know.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a fair argument. I haven't measured the performance impact, I doubt it is worth changing.

validate_response=self.validate_response
if isinstance(validate_response, _UnsetType)
else validate_response,
)

def increment(self) -> "Retry":
"""Return a new Retry instance with the attempt count incremented."""
logger.debug("increment retry=%s new_attempts_made=%s", self, self.attempts_made + 1)
return self.__class__(
total=self.total,
max_backoff_wait=self.max_backoff_wait,
backoff_factor=self.backoff_factor,
respect_retry_after_header=self.respect_retry_after_header,
allowed_methods=self.allowed_methods,
status_forcelist=self.status_forcelist,
retry_on_exceptions=self.retryable_exceptions,
backoff_jitter=self.backoff_jitter,
attempts_made=self.attempts_made + 1,
total_timeout=self.total_timeout,
elapsed_sleep=self.elapsed_sleep,
validate_response=self.validate_response,
)
return self.copy_with(attempts_made=self.attempts_made + 1)

def __repr__(self) -> str:
return f"<Retry(total={self.total}, attempts_made={self.attempts_made})>"
38 changes: 23 additions & 15 deletions httpx_retries/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,14 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:

logger.debug("handle_request started request=%s", request)

if self.retry.is_retryable_method(request.method):
retry: Retry = request.extensions.setdefault("retry", self.retry)

if retry.is_retryable_method(request.method):
if retry.validate_response is not None and inspect.iscoroutinefunction(retry.validate_response):
raise TypeError("validate_response must be a sync function when using a sync transport")

send_method = partial(self._sync_transport.handle_request)
response = self._retry_operation(request, send_method)
response = self._retry_operation(request, send_method, retry)
else:
response = self._sync_transport.handle_request(request)

Expand All @@ -112,9 +117,11 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:

logger.debug("handle_async_request started request=%s", request)

if self.retry.is_retryable_method(request.method):
retry: Retry = request.extensions.setdefault("retry", self.retry)

if retry.is_retryable_method(request.method):
send_method = partial(self._async_transport.handle_async_request)
response = await self._retry_operation_async(request, send_method)
response = await self._retry_operation_async(request, send_method, retry)
else:
response = await self._async_transport.handle_async_request(request)

Expand All @@ -126,11 +133,8 @@ def _retry_operation(
self,
request: httpx.Request,
send_method: Callable[..., httpx.Response],
retry: Retry,
) -> httpx.Response:
if self.retry.validate_response is not None and inspect.iscoroutinefunction(self.retry.validate_response):
raise TypeError("validate_response must be a sync function when using a sync transport")

retry = self.retry
response: httpx.Response | Exception | None = None

while True:
Expand All @@ -151,26 +155,28 @@ def _retry_operation(
continue

if retry.is_exhausted():
response.extensions["retry"] = retry
return response

if not retry.is_retryable_status_code(response.status_code):
if self.retry.validate_response is not None:
if retry.validate_response is not None:
# normally set by httpx _after_ calling this function, but we want the request in the validator
response.request = request
try:
self.retry.validate_response(response)
retry.validate_response(response)
except Exception as e:
if retry.is_exhausted() or not retry.is_retryable_exception(e):
raise
continue
response.extensions["retry"] = retry
return response

async def _retry_operation_async(
self,
request: httpx.Request,
send_method: Callable[..., Coroutine[Any, Any, httpx.Response]],
retry: Retry,
) -> httpx.Response:
retry = self.retry
response: httpx.Response | Exception | None = None

while True:
Expand All @@ -193,19 +199,21 @@ async def _retry_operation_async(
continue

if retry.is_exhausted():
response.extensions["retry"] = retry
return response

if not retry.is_retryable_status_code(response.status_code):
if self.retry.validate_response is not None:
if retry.validate_response is not None:
# normally set by httpx _after_ calling this function, but we want the request in the validator
response.request = request
try:
if inspect.iscoroutinefunction(self.retry.validate_response):
await self.retry.validate_response(response)
if inspect.iscoroutinefunction(retry.validate_response):
await retry.validate_response(response)
else:
self.retry.validate_response(response)
retry.validate_response(response)
except Exception as e:
if retry.is_exhausted() or not retry.is_retryable_exception(e):
raise
continue
response.extensions["retry"] = retry
return response
67 changes: 67 additions & 0 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
import inspect
import logging
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import httpx
Expand Down Expand Up @@ -550,3 +552,68 @@ def test_is_exhausted_below_total_timeout() -> None:
retry = Retry(total=10, total_timeout=5)
retry.elapsed_sleep = 4.9
assert retry.is_exhausted() is False


def test_copy_with_overrides_fields() -> None:
retry = Retry(total=10, backoff_factor=0.5, total_timeout=30.0)
copy = retry.copy_with(total=3, backoff_factor=1.0)
assert copy.total == 3
assert copy.backoff_factor == 1.0
assert copy.total_timeout == 30.0 # unchanged


def test_copy_with_no_args_equals_original() -> None:
retry = Retry(total=5, backoff_factor=0.2, max_backoff_wait=60.0)
copy = retry.copy_with()
assert copy.total == retry.total
assert copy.backoff_factor == retry.backoff_factor
assert copy.max_backoff_wait == retry.max_backoff_wait
assert copy.allowed_methods == retry.allowed_methods
assert copy.status_forcelist == retry.status_forcelist
assert copy.retryable_exceptions == retry.retryable_exceptions


def test_copy_with_can_set_total_timeout_to_none() -> None:
retry = Retry(total=5, total_timeout=10.0)
copy = retry.copy_with(total_timeout=None)
assert copy.total_timeout is None


def test_copy_with_preserves_subclass() -> None:
class CustomRetry(Retry):
pass

retry = CustomRetry(total=5)
copy = retry.copy_with(total=3)
assert type(copy) is CustomRetry


def test_copy_with_and_init_have_same_parameters() -> None:
init_params = set(inspect.signature(Retry.__init__).parameters) - {"self"}
copy_with_params = set(inspect.signature(Retry.copy_with).parameters) - {"self"}
assert init_params == copy_with_params


def test_copy_with_roundtrips_all_fields() -> None:
init_params = set(inspect.signature(Retry.__init__).parameters) - {"self"}
kwargs: dict[str, Any] = {
"total": 3,
"allowed_methods": {"GET", "POST"},
"status_forcelist": {500, 503},
"retry_on_exceptions": [httpx.TimeoutException],
"backoff_factor": 0.5,
"respect_retry_after_header": False,
"max_backoff_wait": 60.0,
"backoff_jitter": 0.5,
"attempts_made": 2,
"total_timeout": 30.0,
"elapsed_sleep": 1.5,
"validate_response": lambda _: None,
}
assert set(kwargs) == init_params

original = Retry(**kwargs)
copy = original.copy_with(**kwargs)

for attr, value in vars(original).items():
assert getattr(copy, attr) == value, f"{attr} mismatch"
64 changes: 62 additions & 2 deletions tests/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,15 @@ def send_method(request: httpx.Request) -> httpx.Response:
response = Mock(spec=httpx.Response)
response.status_code = status_code
response.headers = httpx.Headers()
response.extensions = {}
response.close = Mock()

responses.append(response)
return response

transport._retry_operation(request=httpx.Request("GET", "https://example.com"), send_method=send_method)
transport._retry_operation(
request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry
)

assert all(r.close.called for r in responses[:-1])

Expand Down Expand Up @@ -522,6 +525,60 @@ async def test_async_from_base_transport() -> None:
assert response.status_code == 200


def test_retry_extension_overrides_transport(mock_responses: MockResponse) -> None:
mock_sleep, status_code_sequences = mock_responses
status_code_sequences["https://example.com/fail"] = status_codes([(429, None)])
transport = RetryTransport(retry=Retry(total=10))

request = httpx.Request("GET", "https://example.com/fail", extensions={"retry": Retry(total=2)})
with httpx.Client(transport=transport) as client:
response = client.send(request)

assert response.status_code == 429
assert mock_sleep.call_count == 2
assert response.extensions["retry"].attempts_made == 2


def test_retry_extension_set_from_transport_when_absent(mock_responses: MockResponse) -> None:
mock_sleep, _ = mock_responses
transport = RetryTransport(retry=Retry(total=3))

request = httpx.Request("GET", "https://example.com")
with httpx.Client(transport=transport) as client:
response = client.send(request)

assert request.extensions["retry"] is transport.retry
assert response.extensions["retry"] is transport.retry


@pytest.mark.asyncio
async def test_async_retry_extension_overrides_transport(mock_async_responses: AsyncMockResponse) -> None:
mock_asleep, status_code_sequences = mock_async_responses
status_code_sequences["https://example.com/fail"] = astatus_codes([(429, None)])
transport = RetryTransport(retry=Retry(total=10))

request = httpx.Request("GET", "https://example.com/fail", extensions={"retry": Retry(total=2)})
async with httpx.AsyncClient(transport=transport) as client:
response = await client.send(request)

assert response.status_code == 429
assert mock_asleep.call_count == 2
assert response.extensions["retry"].attempts_made == 2


@pytest.mark.asyncio
async def test_async_retry_extension_set_from_transport_when_absent(mock_async_responses: AsyncMockResponse) -> None:
mock_asleep, _ = mock_async_responses
transport = RetryTransport(retry=Retry(total=3))

request = httpx.Request("GET", "https://example.com")
async with httpx.AsyncClient(transport=transport) as client:
response = await client.send(request)

assert request.extensions["retry"] is transport.retry
assert response.extensions["retry"] is transport.retry


def test_retry_after_capped_by_total_timeout(mock_responses: MockResponse) -> None:
mock_sleep, status_code_sequences = mock_responses
status_code_sequences["https://example.com/fail"] = status_codes([(429, "120")])
Expand Down Expand Up @@ -566,12 +623,15 @@ async def send_method(request: httpx.Request) -> httpx.Response:
response = AsyncMock(spec=httpx.Response)
response.status_code = status_code
response.headers = httpx.Headers()
response.extensions = {}
response.aclose = AsyncMock()

responses.append(response)
return response

await transport._retry_operation_async(request=httpx.Request("GET", "https://example.com"), send_method=send_method)
await transport._retry_operation_async(
request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry
)

assert all(r.aclose.called for r in responses[:-1])

Expand Down
Loading