diff --git a/httpx_retries/retry.py b/httpx_retries/retry.py index 8d5a75d..ed470dc 100644 --- a/httpx_retries/retry.py +++ b/httpx_retries/retry.py @@ -30,6 +30,13 @@ class HTTPMethod(str, Enum): CONNECT = "CONNECT" +class _UnsetType: + __slots__ = () + + +_UNSET: Final[_UnsetType] = _UnsetType() + + class Retry: """ A class to encapsulate retry logic and configuration. @@ -283,23 +290,47 @@ async def asleep(self, response: httpx.Response | Exception) -> None: await asyncio.sleep(time_to_sleep) self.elapsed_sleep += time_to_sleep + def copy_with( + self, + total: int | _UnsetType = _UNSET, + allowed_methods: Iterable[HTTPMethod | str] | None | _UnsetType = _UNSET, + status_forcelist: Iterable[HTTPStatus | int] | None | _UnsetType = _UNSET, + retry_on_exceptions: Iterable[type[Exception]] | None | _UnsetType = _UNSET, + backoff_factor: float | _UnsetType = _UNSET, + respect_retry_after_header: bool | _UnsetType = _UNSET, + max_backoff_wait: float | _UnsetType = _UNSET, + backoff_jitter: float | _UnsetType = _UNSET, + attempts_made: int | _UnsetType = _UNSET, + total_timeout: float | None | _UnsetType = _UNSET, + elapsed_sleep: float | _UnsetType = _UNSET, + validate_response: Callable[[httpx.Response], None | Awaitable[None]] | None | _UnsetType = _UNSET, + ) -> "Retry": + """Return a new Retry with selected fields overridden.""" + return self.__class__( + total=self.total if isinstance(total, _UnsetType) else total, + allowed_methods=self.allowed_methods if isinstance(allowed_methods, _UnsetType) else allowed_methods, + status_forcelist=self.status_forcelist if isinstance(status_forcelist, _UnsetType) else status_forcelist, + retry_on_exceptions=self.retryable_exceptions + if isinstance(retry_on_exceptions, _UnsetType) + else retry_on_exceptions, + backoff_factor=self.backoff_factor if isinstance(backoff_factor, _UnsetType) else backoff_factor, + respect_retry_after_header=self.respect_retry_after_header + if isinstance(respect_retry_after_header, _UnsetType) + else respect_retry_after_header, + max_backoff_wait=self.max_backoff_wait if isinstance(max_backoff_wait, _UnsetType) else max_backoff_wait, + backoff_jitter=self.backoff_jitter if isinstance(backoff_jitter, _UnsetType) else backoff_jitter, + attempts_made=self.attempts_made if isinstance(attempts_made, _UnsetType) else attempts_made, + total_timeout=self.total_timeout if isinstance(total_timeout, _UnsetType) else total_timeout, + elapsed_sleep=self.elapsed_sleep if isinstance(elapsed_sleep, _UnsetType) else elapsed_sleep, + validate_response=self.validate_response + if isinstance(validate_response, _UnsetType) + else validate_response, + ) + def increment(self) -> "Retry": """Return a new Retry instance with the attempt count incremented.""" logger.debug("increment retry=%s new_attempts_made=%s", self, self.attempts_made + 1) - return self.__class__( - total=self.total, - max_backoff_wait=self.max_backoff_wait, - backoff_factor=self.backoff_factor, - respect_retry_after_header=self.respect_retry_after_header, - allowed_methods=self.allowed_methods, - status_forcelist=self.status_forcelist, - retry_on_exceptions=self.retryable_exceptions, - backoff_jitter=self.backoff_jitter, - attempts_made=self.attempts_made + 1, - total_timeout=self.total_timeout, - elapsed_sleep=self.elapsed_sleep, - validate_response=self.validate_response, - ) + return self.copy_with(attempts_made=self.attempts_made + 1) def __repr__(self) -> str: return f"" diff --git a/httpx_retries/transport.py b/httpx_retries/transport.py index d622750..1990d6d 100644 --- a/httpx_retries/transport.py +++ b/httpx_retries/transport.py @@ -88,9 +88,14 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: logger.debug("handle_request started request=%s", request) - if self.retry.is_retryable_method(request.method): + retry: Retry = request.extensions.setdefault("retry", self.retry) + + if retry.is_retryable_method(request.method): + if retry.validate_response is not None and inspect.iscoroutinefunction(retry.validate_response): + raise TypeError("validate_response must be a sync function when using a sync transport") + send_method = partial(self._sync_transport.handle_request) - response = self._retry_operation(request, send_method) + response = self._retry_operation(request, send_method, retry) else: response = self._sync_transport.handle_request(request) @@ -112,9 +117,11 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: logger.debug("handle_async_request started request=%s", request) - if self.retry.is_retryable_method(request.method): + retry: Retry = request.extensions.setdefault("retry", self.retry) + + if retry.is_retryable_method(request.method): send_method = partial(self._async_transport.handle_async_request) - response = await self._retry_operation_async(request, send_method) + response = await self._retry_operation_async(request, send_method, retry) else: response = await self._async_transport.handle_async_request(request) @@ -126,11 +133,8 @@ def _retry_operation( self, request: httpx.Request, send_method: Callable[..., httpx.Response], + retry: Retry, ) -> httpx.Response: - if self.retry.validate_response is not None and inspect.iscoroutinefunction(self.retry.validate_response): - raise TypeError("validate_response must be a sync function when using a sync transport") - - retry = self.retry response: httpx.Response | Exception | None = None while True: @@ -151,26 +155,28 @@ def _retry_operation( continue if retry.is_exhausted(): + response.extensions["retry"] = retry return response if not retry.is_retryable_status_code(response.status_code): - if self.retry.validate_response is not None: + if retry.validate_response is not None: # normally set by httpx _after_ calling this function, but we want the request in the validator response.request = request try: - self.retry.validate_response(response) + retry.validate_response(response) except Exception as e: if retry.is_exhausted() or not retry.is_retryable_exception(e): raise continue + response.extensions["retry"] = retry return response async def _retry_operation_async( self, request: httpx.Request, send_method: Callable[..., Coroutine[Any, Any, httpx.Response]], + retry: Retry, ) -> httpx.Response: - retry = self.retry response: httpx.Response | Exception | None = None while True: @@ -193,19 +199,21 @@ async def _retry_operation_async( continue if retry.is_exhausted(): + response.extensions["retry"] = retry return response if not retry.is_retryable_status_code(response.status_code): - if self.retry.validate_response is not None: + if retry.validate_response is not None: # normally set by httpx _after_ calling this function, but we want the request in the validator response.request = request try: - if inspect.iscoroutinefunction(self.retry.validate_response): - await self.retry.validate_response(response) + if inspect.iscoroutinefunction(retry.validate_response): + await retry.validate_response(response) else: - self.retry.validate_response(response) + retry.validate_response(response) except Exception as e: if retry.is_exhausted() or not retry.is_retryable_exception(e): raise continue + response.extensions["retry"] = retry return response diff --git a/tests/test_retry.py b/tests/test_retry.py index aa7a55c..4fbdd7b 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,6 +1,8 @@ import datetime +import inspect import logging from http import HTTPStatus +from typing import Any from unittest.mock import AsyncMock, MagicMock import httpx @@ -550,3 +552,68 @@ def test_is_exhausted_below_total_timeout() -> None: retry = Retry(total=10, total_timeout=5) retry.elapsed_sleep = 4.9 assert retry.is_exhausted() is False + + +def test_copy_with_overrides_fields() -> None: + retry = Retry(total=10, backoff_factor=0.5, total_timeout=30.0) + copy = retry.copy_with(total=3, backoff_factor=1.0) + assert copy.total == 3 + assert copy.backoff_factor == 1.0 + assert copy.total_timeout == 30.0 # unchanged + + +def test_copy_with_no_args_equals_original() -> None: + retry = Retry(total=5, backoff_factor=0.2, max_backoff_wait=60.0) + copy = retry.copy_with() + assert copy.total == retry.total + assert copy.backoff_factor == retry.backoff_factor + assert copy.max_backoff_wait == retry.max_backoff_wait + assert copy.allowed_methods == retry.allowed_methods + assert copy.status_forcelist == retry.status_forcelist + assert copy.retryable_exceptions == retry.retryable_exceptions + + +def test_copy_with_can_set_total_timeout_to_none() -> None: + retry = Retry(total=5, total_timeout=10.0) + copy = retry.copy_with(total_timeout=None) + assert copy.total_timeout is None + + +def test_copy_with_preserves_subclass() -> None: + class CustomRetry(Retry): + pass + + retry = CustomRetry(total=5) + copy = retry.copy_with(total=3) + assert type(copy) is CustomRetry + + +def test_copy_with_and_init_have_same_parameters() -> None: + init_params = set(inspect.signature(Retry.__init__).parameters) - {"self"} + copy_with_params = set(inspect.signature(Retry.copy_with).parameters) - {"self"} + assert init_params == copy_with_params + + +def test_copy_with_roundtrips_all_fields() -> None: + init_params = set(inspect.signature(Retry.__init__).parameters) - {"self"} + kwargs: dict[str, Any] = { + "total": 3, + "allowed_methods": {"GET", "POST"}, + "status_forcelist": {500, 503}, + "retry_on_exceptions": [httpx.TimeoutException], + "backoff_factor": 0.5, + "respect_retry_after_header": False, + "max_backoff_wait": 60.0, + "backoff_jitter": 0.5, + "attempts_made": 2, + "total_timeout": 30.0, + "elapsed_sleep": 1.5, + "validate_response": lambda _: None, + } + assert set(kwargs) == init_params + + original = Retry(**kwargs) + copy = original.copy_with(**kwargs) + + for attr, value in vars(original).items(): + assert getattr(copy, attr) == value, f"{attr} mismatch" diff --git a/tests/test_transport.py b/tests/test_transport.py index a385335..847d0a4 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -260,12 +260,15 @@ def send_method(request: httpx.Request) -> httpx.Response: response = Mock(spec=httpx.Response) response.status_code = status_code response.headers = httpx.Headers() + response.extensions = {} response.close = Mock() responses.append(response) return response - transport._retry_operation(request=httpx.Request("GET", "https://example.com"), send_method=send_method) + transport._retry_operation( + request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry + ) assert all(r.close.called for r in responses[:-1]) @@ -522,6 +525,60 @@ async def test_async_from_base_transport() -> None: assert response.status_code == 200 +def test_retry_extension_overrides_transport(mock_responses: MockResponse) -> None: + mock_sleep, status_code_sequences = mock_responses + status_code_sequences["https://example.com/fail"] = status_codes([(429, None)]) + transport = RetryTransport(retry=Retry(total=10)) + + request = httpx.Request("GET", "https://example.com/fail", extensions={"retry": Retry(total=2)}) + with httpx.Client(transport=transport) as client: + response = client.send(request) + + assert response.status_code == 429 + assert mock_sleep.call_count == 2 + assert response.extensions["retry"].attempts_made == 2 + + +def test_retry_extension_set_from_transport_when_absent(mock_responses: MockResponse) -> None: + mock_sleep, _ = mock_responses + transport = RetryTransport(retry=Retry(total=3)) + + request = httpx.Request("GET", "https://example.com") + with httpx.Client(transport=transport) as client: + response = client.send(request) + + assert request.extensions["retry"] is transport.retry + assert response.extensions["retry"] is transport.retry + + +@pytest.mark.asyncio +async def test_async_retry_extension_overrides_transport(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, status_code_sequences = mock_async_responses + status_code_sequences["https://example.com/fail"] = astatus_codes([(429, None)]) + transport = RetryTransport(retry=Retry(total=10)) + + request = httpx.Request("GET", "https://example.com/fail", extensions={"retry": Retry(total=2)}) + async with httpx.AsyncClient(transport=transport) as client: + response = await client.send(request) + + assert response.status_code == 429 + assert mock_asleep.call_count == 2 + assert response.extensions["retry"].attempts_made == 2 + + +@pytest.mark.asyncio +async def test_async_retry_extension_set_from_transport_when_absent(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, _ = mock_async_responses + transport = RetryTransport(retry=Retry(total=3)) + + request = httpx.Request("GET", "https://example.com") + async with httpx.AsyncClient(transport=transport) as client: + response = await client.send(request) + + assert request.extensions["retry"] is transport.retry + assert response.extensions["retry"] is transport.retry + + def test_retry_after_capped_by_total_timeout(mock_responses: MockResponse) -> None: mock_sleep, status_code_sequences = mock_responses status_code_sequences["https://example.com/fail"] = status_codes([(429, "120")]) @@ -566,12 +623,15 @@ async def send_method(request: httpx.Request) -> httpx.Response: response = AsyncMock(spec=httpx.Response) response.status_code = status_code response.headers = httpx.Headers() + response.extensions = {} response.aclose = AsyncMock() responses.append(response) return response - await transport._retry_operation_async(request=httpx.Request("GET", "https://example.com"), send_method=send_method) + await transport._retry_operation_async( + request=httpx.Request("GET", "https://example.com"), send_method=send_method, retry=transport.retry + ) assert all(r.aclose.called for r in responses[:-1])