From fd0232587126bc279c039e01cc55e96d0fd26dae Mon Sep 17 00:00:00 2001 From: Victor Gavro Date: Tue, 12 May 2026 16:34:47 +0300 Subject: [PATCH 1/5] feat: add request.extensions["retry"] for retry configuration per request, also set this object on request for introspection purposes --- httpx_retries/transport.py | 16 ++++++----- tests/test_transport.py | 54 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/httpx_retries/transport.py b/httpx_retries/transport.py index 46e762e..ea9b831 100644 --- a/httpx_retries/transport.py +++ b/httpx_retries/transport.py @@ -87,9 +87,11 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: logger.debug("handle_request started request=%s", request) - if self.retry.is_retryable_method(request.method): + retry: Retry = request.extensions.setdefault("retry", self.retry) + + if retry.is_retryable_method(request.method): send_method = partial(self._sync_transport.handle_request) - response = self._retry_operation(request, send_method) + response = self._retry_operation(request, send_method, retry) else: response = self._sync_transport.handle_request(request) @@ -111,9 +113,11 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: logger.debug("handle_async_request started request=%s", request) - if self.retry.is_retryable_method(request.method): + retry: Retry = request.extensions.setdefault("retry", self.retry) + + if retry.is_retryable_method(request.method): send_method = partial(self._async_transport.handle_async_request) - response = await self._retry_operation_async(request, send_method) + response = await self._retry_operation_async(request, send_method, retry) else: response = await self._async_transport.handle_async_request(request) @@ -125,8 +129,8 @@ def _retry_operation( self, request: httpx.Request, send_method: Callable[..., httpx.Response], + retry: Retry, ) -> httpx.Response: - retry = self.retry response: httpx.Response | Exception | None = None while True: @@ -153,8 +157,8 @@ async def _retry_operation_async( self, request: httpx.Request, send_method: Callable[..., Coroutine[Any, Any, httpx.Response]], + retry: Retry, ) -> httpx.Response: - retry = self.retry response: httpx.Response | Exception | None = None while True: diff --git a/tests/test_transport.py b/tests/test_transport.py index 5b354f8..2c7bc81 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -265,7 +265,7 @@ def send_method(request: httpx.Request) -> httpx.Response: responses.append(response) return response - transport._retry_operation(request=httpx.Request("GET", "https://example.com"), send_method=send_method) + transport._retry_operation(request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry) assert all(r.close.called for r in responses[:-1]) @@ -522,6 +522,56 @@ async def test_async_from_base_transport() -> None: assert response.status_code == 200 +def test_retry_extension_overrides_transport(mock_responses: MockResponse) -> None: + mock_sleep, status_code_sequences = mock_responses + status_code_sequences["https://example.com/fail"] = status_codes([(429, None)]) + transport = RetryTransport(retry=Retry(total=10)) + + request = httpx.Request("GET", "https://example.com/fail", extensions={"retry": Retry(total=2)}) + with httpx.Client(transport=transport) as client: + response = client.send(request) + + assert response.status_code == 429 + assert mock_sleep.call_count == 2 + + +def test_retry_extension_set_from_transport_when_absent(mock_responses: MockResponse) -> None: + mock_sleep, _ = mock_responses + transport = RetryTransport(retry=Retry(total=3)) + + request = httpx.Request("GET", "https://example.com") + with httpx.Client(transport=transport) as client: + client.send(request) + + assert request.extensions["retry"] is transport.retry + + +@pytest.mark.asyncio +async def test_async_retry_extension_overrides_transport(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, status_code_sequences = mock_async_responses + status_code_sequences["https://example.com/fail"] = astatus_codes([(429, None)]) + transport = RetryTransport(retry=Retry(total=10)) + + request = httpx.Request("GET", "https://example.com/fail", extensions={"retry": Retry(total=2)}) + async with httpx.AsyncClient(transport=transport) as client: + response = await client.send(request) + + assert response.status_code == 429 + assert mock_asleep.call_count == 2 + + +@pytest.mark.asyncio +async def test_async_retry_extension_set_from_transport_when_absent(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, _ = mock_async_responses + transport = RetryTransport(retry=Retry(total=3)) + + request = httpx.Request("GET", "https://example.com") + async with httpx.AsyncClient(transport=transport) as client: + await client.send(request) + + assert request.extensions["retry"] is transport.retry + + def test_retry_after_capped_by_total_timeout(mock_responses: MockResponse) -> None: mock_sleep, status_code_sequences = mock_responses status_code_sequences["https://example.com/fail"] = status_codes([(429, "120")]) @@ -571,7 +621,7 @@ async def send_method(request: httpx.Request) -> httpx.Response: responses.append(response) return response - await transport._retry_operation_async(request=httpx.Request("GET", "https://example.com"), send_method=send_method) + await transport._retry_operation_async(request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry) assert all(r.aclose.called for r in responses[:-1]) From f5896cb6e94ae29a62702805a0fda78faafe0e7f Mon Sep 17 00:00:00 2001 From: Victor Gavro Date: Tue, 12 May 2026 17:33:07 +0300 Subject: [PATCH 2/5] feat: add Retry.copy_with method for more easier usage with request.extensions["retry"] - allows more easily override Retry configuration parameters per request to pass it like Client.get(extensions={"retry": transport.retry.copy_with(...)}) (signature like httpx.URL.copy_with, typing approach with _UNSET from httpx._config) --- httpx_retries/retry.py | 54 +++++++++++++++++++++++++++++++---------- tests/test_retry.py | 34 ++++++++++++++++++++++++++ tests/test_transport.py | 8 ++++-- 3 files changed, 81 insertions(+), 15 deletions(-) diff --git a/httpx_retries/retry.py b/httpx_retries/retry.py index 5415cae..ab03fab 100644 --- a/httpx_retries/retry.py +++ b/httpx_retries/retry.py @@ -30,6 +30,13 @@ class HTTPMethod(str, Enum): CONNECT = "CONNECT" +class _UnsetType: + __slots__ = () + + +_UNSET: Final[_UnsetType] = _UnsetType() + + class Retry: """ A class to encapsulate retry logic and configuration. @@ -277,22 +284,43 @@ async def asleep(self, response: httpx.Response | Exception) -> None: await asyncio.sleep(time_to_sleep) self.elapsed_sleep += time_to_sleep + def copy_with( + self, + total: int | _UnsetType = _UNSET, + allowed_methods: Iterable[HTTPMethod | str] | None | _UnsetType = _UNSET, + status_forcelist: Iterable[HTTPStatus | int] | None | _UnsetType = _UNSET, + retry_on_exceptions: Iterable[type[Exception]] | None | _UnsetType = _UNSET, + backoff_factor: float | _UnsetType = _UNSET, + respect_retry_after_header: bool | _UnsetType = _UNSET, + max_backoff_wait: float | _UnsetType = _UNSET, + backoff_jitter: float | _UnsetType = _UNSET, + attempts_made: int | _UnsetType = _UNSET, + total_timeout: float | None | _UnsetType = _UNSET, + elapsed_sleep: float | _UnsetType = _UNSET, + ) -> "Retry": + """Return a new Retry with selected fields overridden.""" + return self.__class__( + total=self.total if isinstance(total, _UnsetType) else total, + allowed_methods=self.allowed_methods if isinstance(allowed_methods, _UnsetType) else allowed_methods, + status_forcelist=self.status_forcelist if isinstance(status_forcelist, _UnsetType) else status_forcelist, + retry_on_exceptions=self.retryable_exceptions + if isinstance(retry_on_exceptions, _UnsetType) + else retry_on_exceptions, + backoff_factor=self.backoff_factor if isinstance(backoff_factor, _UnsetType) else backoff_factor, + respect_retry_after_header=self.respect_retry_after_header + if isinstance(respect_retry_after_header, _UnsetType) + else respect_retry_after_header, + max_backoff_wait=self.max_backoff_wait if isinstance(max_backoff_wait, _UnsetType) else max_backoff_wait, + backoff_jitter=self.backoff_jitter if isinstance(backoff_jitter, _UnsetType) else backoff_jitter, + attempts_made=self.attempts_made if isinstance(attempts_made, _UnsetType) else attempts_made, + total_timeout=self.total_timeout if isinstance(total_timeout, _UnsetType) else total_timeout, + elapsed_sleep=self.elapsed_sleep if isinstance(elapsed_sleep, _UnsetType) else elapsed_sleep, + ) + def increment(self) -> "Retry": """Return a new Retry instance with the attempt count incremented.""" logger.debug("increment retry=%s new_attempts_made=%s", self, self.attempts_made + 1) - return self.__class__( - total=self.total, - max_backoff_wait=self.max_backoff_wait, - backoff_factor=self.backoff_factor, - respect_retry_after_header=self.respect_retry_after_header, - allowed_methods=self.allowed_methods, - status_forcelist=self.status_forcelist, - retry_on_exceptions=self.retryable_exceptions, - backoff_jitter=self.backoff_jitter, - attempts_made=self.attempts_made + 1, - total_timeout=self.total_timeout, - elapsed_sleep=self.elapsed_sleep, - ) + return self.copy_with(attempts_made=self.attempts_made + 1) def __repr__(self) -> str: return f"" diff --git a/tests/test_retry.py b/tests/test_retry.py index aa7a55c..9677741 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -550,3 +550,37 @@ def test_is_exhausted_below_total_timeout() -> None: retry = Retry(total=10, total_timeout=5) retry.elapsed_sleep = 4.9 assert retry.is_exhausted() is False + + +def test_copy_with_overrides_fields() -> None: + retry = Retry(total=10, backoff_factor=0.5, total_timeout=30.0) + copy = retry.copy_with(total=3, backoff_factor=1.0) + assert copy.total == 3 + assert copy.backoff_factor == 1.0 + assert copy.total_timeout == 30.0 # unchanged + + +def test_copy_with_no_args_equals_original() -> None: + retry = Retry(total=5, backoff_factor=0.2, max_backoff_wait=60.0) + copy = retry.copy_with() + assert copy.total == retry.total + assert copy.backoff_factor == retry.backoff_factor + assert copy.max_backoff_wait == retry.max_backoff_wait + assert copy.allowed_methods == retry.allowed_methods + assert copy.status_forcelist == retry.status_forcelist + assert copy.retryable_exceptions == retry.retryable_exceptions + + +def test_copy_with_can_set_total_timeout_to_none() -> None: + retry = Retry(total=5, total_timeout=10.0) + copy = retry.copy_with(total_timeout=None) + assert copy.total_timeout is None + + +def test_copy_with_preserves_subclass() -> None: + class CustomRetry(Retry): + pass + + retry = CustomRetry(total=5) + copy = retry.copy_with(total=3) + assert type(copy) is CustomRetry diff --git a/tests/test_transport.py b/tests/test_transport.py index 2c7bc81..1519162 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -265,7 +265,9 @@ def send_method(request: httpx.Request) -> httpx.Response: responses.append(response) return response - transport._retry_operation(request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry) + transport._retry_operation( + request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry + ) assert all(r.close.called for r in responses[:-1]) @@ -621,7 +623,9 @@ async def send_method(request: httpx.Request) -> httpx.Response: responses.append(response) return response - await transport._retry_operation_async(request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry) + await transport._retry_operation_async( + request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry + ) assert all(r.aclose.called for r in responses[:-1]) From c4433b26469a3b30f3d9b02ed0e78d8b343d11f8 Mon Sep 17 00:00:00 2001 From: Victor Gavro Date: Wed, 13 May 2026 02:29:01 +0300 Subject: [PATCH 3/5] fix: update request.extensions["retry"] on increment --- httpx_retries/transport.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/httpx_retries/transport.py b/httpx_retries/transport.py index ea9b831..d9992ee 100644 --- a/httpx_retries/transport.py +++ b/httpx_retries/transport.py @@ -139,7 +139,7 @@ def _retry_operation( response.close() logger.debug("_retry_operation retrying request=%s response=%s retry=%s", request, response, retry) - retry = retry.increment() + retry = request.extensions["retry"] = retry.increment() retry.sleep(response) try: response = send_method(request) @@ -169,7 +169,7 @@ async def _retry_operation_async( logger.debug( "_retry_operation_async retrying request=%s response=%s retry=%s", request, response, retry ) - retry = retry.increment() + retry = request.extensions["retry"] = retry.increment() await retry.asleep(response) try: response = await send_method(request) From 29e6860f349afe756d8b39db6a9a730b91203897 Mon Sep 17 00:00:00 2001 From: Victor Gavro Date: Tue, 19 May 2026 18:47:17 +0300 Subject: [PATCH 4/5] feat: do not mutate request.extensions["retry"] on retry, set response.extensions["retry"] for depleted retry instead --- httpx_retries/transport.py | 6 ++++-- tests/test_transport.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/httpx_retries/transport.py b/httpx_retries/transport.py index d9992ee..47bb151 100644 --- a/httpx_retries/transport.py +++ b/httpx_retries/transport.py @@ -139,7 +139,7 @@ def _retry_operation( response.close() logger.debug("_retry_operation retrying request=%s response=%s retry=%s", request, response, retry) - retry = request.extensions["retry"] = retry.increment() + retry = retry.increment() retry.sleep(response) try: response = send_method(request) @@ -151,6 +151,7 @@ def _retry_operation( continue if retry.is_exhausted() or not retry.is_retryable_status_code(response.status_code): + response.extensions["retry"] = retry return response async def _retry_operation_async( @@ -169,7 +170,7 @@ async def _retry_operation_async( logger.debug( "_retry_operation_async retrying request=%s response=%s retry=%s", request, response, retry ) - retry = request.extensions["retry"] = retry.increment() + retry = retry.increment() await retry.asleep(response) try: response = await send_method(request) @@ -181,4 +182,5 @@ async def _retry_operation_async( continue if retry.is_exhausted() or not retry.is_retryable_status_code(response.status_code): + response.extensions["retry"] = retry return response diff --git a/tests/test_transport.py b/tests/test_transport.py index 1519162..5a11004 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -260,6 +260,7 @@ def send_method(request: httpx.Request) -> httpx.Response: response = Mock(spec=httpx.Response) response.status_code = status_code response.headers = httpx.Headers() + response.extensions = {} response.close = Mock() responses.append(response) @@ -535,6 +536,7 @@ def test_retry_extension_overrides_transport(mock_responses: MockResponse) -> No assert response.status_code == 429 assert mock_sleep.call_count == 2 + assert response.extensions["retry"].attempts_made == 2 def test_retry_extension_set_from_transport_when_absent(mock_responses: MockResponse) -> None: @@ -543,9 +545,10 @@ def test_retry_extension_set_from_transport_when_absent(mock_responses: MockResp request = httpx.Request("GET", "https://example.com") with httpx.Client(transport=transport) as client: - client.send(request) + response = client.send(request) assert request.extensions["retry"] is transport.retry + assert response.extensions["retry"] is transport.retry @pytest.mark.asyncio @@ -560,6 +563,7 @@ async def test_async_retry_extension_overrides_transport(mock_async_responses: A assert response.status_code == 429 assert mock_asleep.call_count == 2 + assert response.extensions["retry"].attempts_made == 2 @pytest.mark.asyncio @@ -569,9 +573,10 @@ async def test_async_retry_extension_set_from_transport_when_absent(mock_async_r request = httpx.Request("GET", "https://example.com") async with httpx.AsyncClient(transport=transport) as client: - await client.send(request) + response = await client.send(request) assert request.extensions["retry"] is transport.retry + assert response.extensions["retry"] is transport.retry def test_retry_after_capped_by_total_timeout(mock_responses: MockResponse) -> None: @@ -618,6 +623,7 @@ async def send_method(request: httpx.Request) -> httpx.Response: response = AsyncMock(spec=httpx.Response) response.status_code = status_code response.headers = httpx.Headers() + response.extensions = {} response.aclose = AsyncMock() responses.append(response) From d58a8b359bebf4b1fc1e9a3cddc8568d4971a945 Mon Sep 17 00:00:00 2001 From: Victor Gavro Date: Tue, 19 May 2026 21:08:13 +0300 Subject: [PATCH 5/5] added copy_with tests to match __init__ signature --- tests/test_retry.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_retry.py b/tests/test_retry.py index 9677741..0cd3cf5 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,6 +1,8 @@ import datetime +import inspect import logging from http import HTTPStatus +from typing import Any from unittest.mock import AsyncMock, MagicMock import httpx @@ -584,3 +586,33 @@ class CustomRetry(Retry): retry = CustomRetry(total=5) copy = retry.copy_with(total=3) assert type(copy) is CustomRetry + + +def test_copy_with_and_init_have_same_parameters() -> None: + init_params = set(inspect.signature(Retry.__init__).parameters) - {"self"} + copy_with_params = set(inspect.signature(Retry.copy_with).parameters) - {"self"} + assert init_params == copy_with_params + + +def test_copy_with_roundtrips_all_fields() -> None: + init_params = set(inspect.signature(Retry.__init__).parameters) - {"self"} + kwargs: dict[str, Any] = { + "total": 3, + "allowed_methods": {"GET", "POST"}, + "status_forcelist": {500, 503}, + "retry_on_exceptions": [httpx.TimeoutException], + "backoff_factor": 0.5, + "respect_retry_after_header": False, + "max_backoff_wait": 60.0, + "backoff_jitter": 0.5, + "attempts_made": 2, + "total_timeout": 30.0, + "elapsed_sleep": 1.5, + } + assert set(kwargs) == init_params + + original = Retry(**kwargs) + copy = original.copy_with(**kwargs) + + for attr, value in vars(original).items(): + assert getattr(copy, attr) == value, f"{attr} mismatch"