diff --git a/docs/faq.md b/docs/faq.md index 98c3bc1..2108b02 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -75,6 +75,38 @@ for attempt in range(5): raise ``` +## Retrying on response content + +Sometimes a server returns a valid response but the body or custom headers signals a failure - for example, a block page, a CAPTCHA redirect, or an authorization wall. This commonly occurs if access may be blocked at the content level rather than the HTTP status level. + +Use `validate_response` to inspect the response and raise an exception to trigger a retry: + +```python +import httpx +from httpx_retries import Retry, RetryTransport + +class ContentBlocked(ValueError): + pass + +def validate_response(response: httpx.Response) -> None: + # safely inspect status and headers if needed + response.raise_for_status() + + # NOTE: Do not call `.read()` here with `Client.stream`, + # it will buffer the entire body, which defeats the purpose of streaming. + response.read() + if "content blocked" in response.text: + raise ContentBlocked(response.text) + +retry = Retry(validate_response=validate_response, retry_on_exceptions=[httpx.HTTPStatusError, ContentBlocked]) + +with httpx.Client(transport=RetryTransport(retry=retry)) as client: + response = client.get("https://example.com") +``` + +!!! warning "Do not call `response.read()` inside `validate_response` with `Client.stream`" + `validate_response` is called before the response is returned to the caller. Calling `response.read()` or `await response.aread()` inside it will buffer the entire body, which defeats the purpose of streaming. If you use streaming, validate only the status code and headers. + ## Limits / Cert / SSL / http2 parameters passed to the client are not being applied This is a limitation of the way transports are applied to clients in HTTPX. If you provide a custom transport, several parameters diff --git a/httpx_retries/retry.py b/httpx_retries/retry.py index 5415cae..8d5a75d 100644 --- a/httpx_retries/retry.py +++ b/httpx_retries/retry.py @@ -4,7 +4,7 @@ import random import sys import time -from collections.abc import Iterable, Mapping +from collections.abc import Awaitable, Callable, Iterable, Mapping from email.utils import parsedate_to_datetime from enum import Enum from http import HTTPStatus @@ -62,6 +62,10 @@ class Retry: repeatedly. Defaults to None (no cumulative cap). elapsed_sleep (float, optional): Cumulative sleep time already spent on this request. Preserved across `increment()` calls; users typically do not set this directly. + validate_response (callable, optional): An optional callback called with each response that would + otherwise be returned as a "good" (non-retryable-status) response. If the callback raises, the + request is retried. May be sync or async; an async callback cannot be used with a sync transport. + Signature: ``(response: httpx.Response) -> None``. """ RETRYABLE_METHODS: Final[frozenset[HTTPMethod]] = frozenset( @@ -101,6 +105,7 @@ def __init__( attempts_made: int = 0, total_timeout: float | None = None, elapsed_sleep: float = 0.0, + validate_response: Callable[[httpx.Response], None | Awaitable[None]] | None = None, ) -> None: """Initialize a new Retry instance.""" if total < 0: @@ -126,6 +131,7 @@ def __init__( self.attempts_made = attempts_made self.total_timeout = total_timeout self.elapsed_sleep = elapsed_sleep + self.validate_response = validate_response self.allowed_methods: frozenset[str] = frozenset( method.upper() for method in (allowed_methods or self.RETRYABLE_METHODS) @@ -292,6 +298,7 @@ def increment(self) -> "Retry": attempts_made=self.attempts_made + 1, total_timeout=self.total_timeout, elapsed_sleep=self.elapsed_sleep, + validate_response=self.validate_response, ) def __repr__(self) -> str: diff --git a/httpx_retries/transport.py b/httpx_retries/transport.py index 46e762e..d622750 100644 --- a/httpx_retries/transport.py +++ b/httpx_retries/transport.py @@ -1,3 +1,4 @@ +import inspect import logging from collections.abc import Callable, Coroutine from functools import partial @@ -126,6 +127,9 @@ def _retry_operation( request: httpx.Request, send_method: Callable[..., httpx.Response], ) -> httpx.Response: + if self.retry.validate_response is not None and inspect.iscoroutinefunction(self.retry.validate_response): + raise TypeError("validate_response must be a sync function when using a sync transport") + retry = self.retry response: httpx.Response | Exception | None = None @@ -146,7 +150,19 @@ def _retry_operation( response = e continue - if retry.is_exhausted() or not retry.is_retryable_status_code(response.status_code): + if retry.is_exhausted(): + return response + + if not retry.is_retryable_status_code(response.status_code): + if self.retry.validate_response is not None: + # normally set by httpx _after_ calling this function, but we want the request in the validator + response.request = request + try: + self.retry.validate_response(response) + except Exception as e: + if retry.is_exhausted() or not retry.is_retryable_exception(e): + raise + continue return response async def _retry_operation_async( @@ -176,5 +192,20 @@ async def _retry_operation_async( response = e continue - if retry.is_exhausted() or not retry.is_retryable_status_code(response.status_code): + if retry.is_exhausted(): + return response + + if not retry.is_retryable_status_code(response.status_code): + if self.retry.validate_response is not None: + # normally set by httpx _after_ calling this function, but we want the request in the validator + response.request = request + try: + if inspect.iscoroutinefunction(self.retry.validate_response): + await self.retry.validate_response(response) + else: + self.retry.validate_response(response) + except Exception as e: + if retry.is_exhausted() or not retry.is_retryable_exception(e): + raise + continue return response diff --git a/tests/test_transport.py b/tests/test_transport.py index 5b354f8..a385335 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -706,3 +706,166 @@ async def handle_request(request: Request) -> Response: # must not mutate it, or retry budgets leak across requests. assert transport.retry.attempts_made == 0 assert transport.retry.elapsed_sleep == 0.0 + + +def test_validate_response_retries_on_failure(mock_responses: MockResponse) -> None: + mock_sleep, status_code_sequences = mock_responses + call_count = 0 + + def validate(response: httpx.Response) -> None: + nonlocal call_count + call_count += 1 + if call_count < 3: + raise httpx.TimeoutException("not ready yet") + + retry = Retry(total=5, validate_response=validate) + transport = RetryTransport(retry=retry) + + with httpx.Client(transport=transport) as client: + response = client.get("https://example.com") + + assert response.status_code == 200 + assert call_count == 3 + assert mock_sleep.call_count == 2 + + +def test_validate_response_non_retryable_exception_raises(mock_responses: MockResponse) -> None: + mock_sleep, _ = mock_responses + + def validate(response: httpx.Response) -> None: + raise ValueError("bad response") + + retry = Retry(total=5, validate_response=validate) + transport = RetryTransport(retry=retry) + + with pytest.raises(ValueError, match="bad response"): + with httpx.Client(transport=transport) as client: + client.get("https://example.com") + + assert mock_sleep.call_count == 0 + + +def test_validate_response_exhausted_returns_response(mock_responses: MockResponse) -> None: + mock_sleep, _ = mock_responses + + def validate(response: httpx.Response) -> None: + raise httpx.TimeoutException("always bad") + + retry = Retry(total=3, validate_response=validate) + transport = RetryTransport(retry=retry) + + with httpx.Client(transport=transport) as client: + response = client.get("https://example.com") + + assert response.status_code == 200 + assert mock_sleep.call_count == 3 + + +def test_validate_response_async_callback_raises_for_sync_transport(mock_responses: MockResponse) -> None: + mock_sleep, _ = mock_responses + + async def validate(response: httpx.Response) -> None: # pragma: no cover + pass + + retry = Retry(total=3, validate_response=validate) + transport = RetryTransport(retry=retry) + + with pytest.raises(TypeError, match="validate_response must be a sync function"): + with httpx.Client(transport=transport) as client: + client.get("https://example.com") + + +@pytest.mark.asyncio +async def test_async_validate_response_retries_on_failure(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, _ = mock_async_responses + call_count = 0 + + async def validate(response: httpx.Response) -> None: + nonlocal call_count + call_count += 1 + if call_count < 3: + raise httpx.TimeoutException("not ready yet") + + retry = Retry(total=5, validate_response=validate) + transport = RetryTransport(retry=retry) + + async with httpx.AsyncClient(transport=transport) as client: + response = await client.get("https://example.com") + + assert response.status_code == 200 + assert call_count == 3 + assert mock_asleep.call_count == 2 + + +@pytest.mark.asyncio +async def test_async_validate_response_non_retryable_exception_raises(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, _ = mock_async_responses + + async def validate(response: httpx.Response) -> None: + raise ValueError("bad response") + + retry = Retry(total=5, validate_response=validate) + transport = RetryTransport(retry=retry) + + with pytest.raises(ValueError, match="bad response"): + async with httpx.AsyncClient(transport=transport) as client: + await client.get("https://example.com") + + assert mock_asleep.call_count == 0 + + +@pytest.mark.asyncio +async def test_async_validate_response_sync_callback(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, _ = mock_async_responses + call_count = 0 + + def validate(response: httpx.Response) -> None: + nonlocal call_count + call_count += 1 + if call_count < 2: + raise httpx.TimeoutException("not ready yet") + + retry = Retry(total=5, validate_response=validate) + transport = RetryTransport(retry=retry) + + async with httpx.AsyncClient(transport=transport) as client: + response = await client.get("https://example.com") + + assert response.status_code == 200 + assert call_count == 2 + assert mock_asleep.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_validate_response_exhausted_returns_response(mock_async_responses: AsyncMockResponse) -> None: + mock_asleep, _ = mock_async_responses + + async def validate(response: httpx.Response) -> None: + raise httpx.TimeoutException("always bad") + + retry = Retry(total=3, validate_response=validate) + transport = RetryTransport(retry=retry) + + async with httpx.AsyncClient(transport=transport) as client: + response = await client.get("https://example.com") + + assert response.status_code == 200 + assert mock_asleep.call_count == 3 + + +def test_validate_response_not_called_for_retryable_status(mock_responses: MockResponse) -> None: + mock_sleep, status_code_sequences = mock_responses + status_code_sequences["https://example.com/fail"] = status_codes([(503, None), (200, None)]) + validated = [] + + def validate(response: httpx.Response) -> None: + validated.append(response.status_code) + + retry = Retry(total=5, validate_response=validate) + transport = RetryTransport(retry=retry) + + with httpx.Client(transport=transport) as client: + response = client.get("https://example.com/fail") + + assert response.status_code == 200 + assert validated == [200]