From 38adf0514ae9310a0b9c7efda25a743fafbecddb Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Tue, 10 Feb 2026 12:14:16 -0800 Subject: [PATCH 1/2] fix: read HookResult modify action from tool:post events + fix rate limiting test markers (#8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: read HookResult modify action from tool:post events (2 sites) Both _execute_tool_only and _execute_tool_with_result now detect when a hook modifies tool output via action='modify' on tool:post, and use the modified data instead of the original get_serialized_output(). This enables truncation and transformation hooks to work correctly. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> * fix: add missing @pytest.mark.asyncio to TestRateLimitDelay class The 4 async test methods in TestRateLimitDelay were missing the asyncio marker. With asyncio_mode = "strict" in pyproject.toml, explicit markers are required. Added class-level @pytest.mark.asyncio decorator. --------- Co-authored-by: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> --- amplifier_module_loop_streaming/__init__.py | 43 ++++- tests/test_hook_modify.py | 200 ++++++++++++++++++++ tests/test_rate_limiting.py | 38 ++-- 3 files changed, 259 insertions(+), 22 deletions(-) create mode 100644 tests/test_hook_modify.py diff --git a/amplifier_module_loop_streaming/__init__.py b/amplifier_module_loop_streaming/__init__.py index b87af2a..6927b68 100644 --- a/amplifier_module_loop_streaming/__init__.py +++ b/amplifier_module_loop_streaming/__init__.py @@ -7,6 +7,7 @@ __amplifier_module_type__ = "orchestrator" import asyncio +import json import logging import re import time @@ -987,8 +988,24 @@ async def _execute_tool_only( f"Stored ephemeral injection from tool:post ({tool_call.name}) for next iteration" ) - # Return result content (JSON-serialized for dict/list outputs) - content = result.get_serialized_output() + # Check if a hook modified the tool result. + # hooks.emit() chains modify actions: when a hook + # returns action="modify", the data dict is replaced. + # We detect this by checking if the returned "result" + # is a different object than what we originally sent. + modified_result = None + if post_result and post_result.data is not None: + returned_result = post_result.data.get("result") + if returned_result is not None and returned_result is not result_data: + modified_result = returned_result + + if modified_result is not None: + if isinstance(modified_result, (dict, list)): + content = json.dumps(modified_result) + else: + content = str(modified_result) + else: + content = result.get_serialized_output() return (tool_call.id, tool_call.name, content) except Exception as e: @@ -1106,13 +1123,31 @@ async def _execute_tool_with_result( f"Stored ephemeral injection from tool:post ({tool_call.name}) for next iteration" ) - # Add result with tool_call_id (JSON-serialized for dict/list outputs) + # Check if a hook modified the tool result. + # hooks.emit() chains modify actions: when a hook + # returns action="modify", the data dict is replaced. + # We detect this by checking if the returned "result" + # is a different object than what we originally sent. + modified_result = None + if post_result and post_result.data is not None: + returned_result = post_result.data.get("result") + if returned_result is not None and returned_result is not result_data: + modified_result = returned_result + + if modified_result is not None: + if isinstance(modified_result, (dict, list)): + tool_content = json.dumps(modified_result) + else: + tool_content = str(modified_result) + else: + tool_content = result.get_serialized_output() + await context.add_message( { "role": "tool", "name": tool_call.name, "tool_call_id": tool_call.id, - "content": result.get_serialized_output(), + "content": tool_content, } ) response_added = True diff --git a/tests/test_hook_modify.py b/tests/test_hook_modify.py new file mode 100644 index 0000000..d901ca6 --- /dev/null +++ b/tests/test_hook_modify.py @@ -0,0 +1,200 @@ +"""Tests for hook modify action on tool:post events in streaming orchestrator. + +Verifies that when a hook returns HookResult(action="modify", data={"result": ...}) +on a tool:post event, both _execute_tool_only and _execute_tool_with_result +use the modified data instead of the original result.get_serialized_output(). +""" + +import json + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from amplifier_core.hooks import HookRegistry +from amplifier_core.models import HookResult + + +def _make_tool_result(output, success=True): + """Create a mock tool result with get_serialized_output() and model_dump().""" + result = MagicMock() + result.success = success + result.output = output + result.error = None + + def get_serialized_output(): + if isinstance(output, (dict, list)): + return json.dumps(output) + return str(output) + + result.get_serialized_output = get_serialized_output + + def model_dump(): + return {"success": success, "output": output, "error": None} + + result.model_dump = model_dump + return result + + +@pytest.mark.asyncio +async def test_execute_tool_only_modify_replaces_result(): + """_execute_tool_only should use modified data when hook returns modify.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + # Tool with original output + original_output = {"original": True, "big_data": "x" * 1000} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + # Hook that modifies tool output + modified_content = {"modified": True, "truncated": True} + hooks = HookRegistry() + + async def modify_hook(event: str, data: dict) -> HookResult: + if event == "tool:post": + return HookResult(action="modify", data={"result": modified_content}) + return HookResult() + + hooks.register("tool:post", modify_hook, priority=50, name="test_modify") + + # Mock tool call + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {"key": "value"} + + # Call _execute_tool_only directly + tool_call_id, tool_name, content = await orchestrator._execute_tool_only( + tool_call, tools, hooks, "group_1" + ) + + assert tool_call_id == "tc_1" + assert tool_name == "test_tool" + # Content should be the MODIFIED data + assert content == json.dumps(modified_content), ( + f"Expected modified content {json.dumps(modified_content)}, got {content}" + ) + assert "big_data" not in content + + +@pytest.mark.asyncio +async def test_execute_tool_only_no_modify_uses_original(): + """_execute_tool_only should use original when no modify hook.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + original_output = {"original": True} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + hooks = HookRegistry() + + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {} + + tool_call_id, tool_name, content = await orchestrator._execute_tool_only( + tool_call, tools, hooks, "group_1" + ) + + assert content == json.dumps(original_output) + + +@pytest.mark.asyncio +async def test_execute_tool_with_result_modify_replaces_context(): + """_execute_tool_with_result should use modified data in context message.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + # Tool with original output + original_output = {"original": True, "big_data": "x" * 1000} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + # Hook that modifies tool output + modified_content = {"modified": True, "truncated": True} + hooks = HookRegistry() + + async def modify_hook(event: str, data: dict) -> HookResult: + if event == "tool:post": + return HookResult(action="modify", data={"result": modified_content}) + return HookResult() + + hooks.register("tool:post", modify_hook, priority=50, name="test_modify") + + # Mock tool call + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {"key": "value"} + + # Mock context to capture add_message calls + context = AsyncMock() + messages_added = [] + + async def capture_add_message(msg): + messages_added.append(msg) + + context.add_message = AsyncMock(side_effect=capture_add_message) + + # Call _execute_tool_with_result directly + await orchestrator._execute_tool_with_result(tool_call, tools, context, hooks) + + # Find the tool result message + tool_msgs = [msg for msg in messages_added if msg.get("role") == "tool"] + assert len(tool_msgs) == 1, f"Expected 1 tool message, got {len(tool_msgs)}" + + tool_content = tool_msgs[0]["content"] + assert tool_content == json.dumps(modified_content), ( + f"Expected modified content {json.dumps(modified_content)}, got {tool_content}" + ) + assert "big_data" not in tool_content + + +@pytest.mark.asyncio +async def test_execute_tool_with_result_no_modify_uses_original(): + """_execute_tool_with_result should use original when no modify hook.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + original_output = {"original": True} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + hooks = HookRegistry() + + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {} + + context = AsyncMock() + messages_added = [] + + async def capture_add_message(msg): + messages_added.append(msg) + + context.add_message = AsyncMock(side_effect=capture_add_message) + + await orchestrator._execute_tool_with_result(tool_call, tools, context, hooks) + + tool_msgs = [msg for msg in messages_added if msg.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["content"] == json.dumps(original_output) diff --git a/tests/test_rate_limiting.py b/tests/test_rate_limiting.py index 4446327..6b9b2e5 100644 --- a/tests/test_rate_limiting.py +++ b/tests/test_rate_limiting.py @@ -1,8 +1,9 @@ """Tests for orchestrator rate limiting feature.""" -import asyncio import time -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock + +import pytest class TestRateLimitingConfig: @@ -11,73 +12,74 @@ class TestRateLimitingConfig: def test_default_disabled(self): """Rate limiting should be disabled by default.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({}) assert orchestrator.min_delay_between_calls_ms == 0 def test_config_enabled(self): """Rate limiting can be enabled via config.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 500}) assert orchestrator.min_delay_between_calls_ms == 500 def test_tracking_initialized(self): """Last provider call timestamp should start as None.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({}) assert orchestrator._last_provider_call_end is None +@pytest.mark.asyncio class TestRateLimitDelay: """Test the _apply_rate_limit_delay method.""" async def test_no_delay_when_disabled(self): """No delay should be applied when rate limiting is disabled.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 0}) hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 1) elapsed = time.monotonic() - start - + assert elapsed < 0.01 # Should be nearly instant hooks.emit.assert_not_called() async def test_no_delay_on_first_call(self): """No delay on first call (no previous timestamp).""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 1000}) hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 1) elapsed = time.monotonic() - start - + assert elapsed < 0.01 # Should be nearly instant hooks.emit.assert_not_called() async def test_delay_applied_when_needed(self): """Delay should be applied when elapsed < configured.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 100}) orchestrator._last_provider_call_end = time.monotonic() # Just now hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 2) elapsed = (time.monotonic() - start) * 1000 # Convert to ms - + # Should have delayed close to 100ms assert elapsed >= 90 # Allow some tolerance assert elapsed < 150 hooks.emit.assert_called_once() - + # Verify event payload call_args = hooks.emit.call_args assert call_args[0][0] == "orchestrator:rate_limit_delay" @@ -87,14 +89,14 @@ async def test_delay_applied_when_needed(self): async def test_no_delay_if_enough_time_elapsed(self): """No delay if enough time has already passed.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 50}) orchestrator._last_provider_call_end = time.monotonic() - 0.1 # 100ms ago hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 2) elapsed = time.monotonic() - start - + assert elapsed < 0.01 # Should be nearly instant hooks.emit.assert_not_called() From 5518f648284824ab8875836e43b211fa98fc6b4f Mon Sep 17 00:00:00 2001 From: Sam Schillace Date: Thu, 12 Feb 2026 13:32:55 -0800 Subject: [PATCH 2/2] feat: add exponential backoff retry for transient provider errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wrap all 3 provider call sites in the streaming orchestrator with configurable exponential backoff retry logic. Retries only on LLMError with retryable=True (RateLimitError, ProviderUnavailableError, LLMTimeoutError). Honors retry_after from provider responses and emits provider:retry events for observability. Config: retry_max_attempts (default 3), retry_base_delay_seconds (1.0), retry_max_delay_seconds (30.0). 18 new tests covering all retry behaviors, zero regressions. 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com> --- amplifier_module_loop_streaming/__init__.py | 190 ++++++---- tests/test_retry.py | 385 ++++++++++++++++++++ 2 files changed, 511 insertions(+), 64 deletions(-) create mode 100644 tests/test_retry.py diff --git a/amplifier_module_loop_streaming/__init__.py b/amplifier_module_loop_streaming/__init__.py index 6927b68..1f9029d 100644 --- a/amplifier_module_loop_streaming/__init__.py +++ b/amplifier_module_loop_streaming/__init__.py @@ -75,6 +75,15 @@ def __init__(self, config: dict[str, Any]): # Store ephemeral injections from tool:post hooks for next iteration self._pending_ephemeral_injections: list[dict[str, Any]] = [] + # Retry configuration for transient provider errors + self.retry_max_attempts = int(config.get("retry_max_attempts", 3)) + self.retry_base_delay_seconds = float( + config.get("retry_base_delay_seconds", 1.0) + ) + self.retry_max_delay_seconds = float( + config.get("retry_max_delay_seconds", 30.0) + ) + async def _apply_rate_limit_delay( self, hooks: HookRegistry, iteration: int ) -> None: @@ -106,6 +115,105 @@ async def _apply_rate_limit_delay( ) await asyncio.sleep(remaining_ms / 1000) + async def _call_provider_with_retry( + self, + call_fn, + hooks: HookRegistry, + provider_name: str | None, + ) -> Any: + """Call a provider function with exponential backoff retry on retryable errors. + + Wraps both async (provider.complete()) and sync (provider.stream()) calls. + Retries only when the error has retryable=True (e.g., RateLimitError, + ProviderUnavailableError, LLMTimeoutError). Honors retry_after from the + error when available, otherwise uses exponential backoff. + + Args: + call_fn: Callable that makes the provider call. May return a coroutine + (for complete()) or a value (for stream()). + hooks: Hook registry for event emission. + provider_name: Name of the provider for event data. + + Returns: + The result of call_fn(). + + Raises: + LLMError: If the error is not retryable or all retries are exhausted. + Exception: If a non-LLM error occurs (never retried). + """ + max_retries = self.retry_max_attempts + + for attempt in range(max_retries + 1): + try: + result = call_fn() + if asyncio.iscoroutine(result): + result = await result + return result + except LLMError as e: + is_last_attempt = attempt >= max_retries + if e.retryable and not is_last_attempt: + # Calculate delay: prefer server-provided retry_after + retry_after = getattr(e, "retry_after", None) + if retry_after is not None: + delay = float(retry_after) + else: + delay = min( + self.retry_base_delay_seconds * (2**attempt), + self.retry_max_delay_seconds, + ) + + await hooks.emit( + "provider:retry", + { + "provider": provider_name, + "error": { + "type": type(e).__name__, + "msg": str(e), + }, + "attempt": attempt + 1, + "max_retries": max_retries, + "delay_seconds": delay, + "retryable": e.retryable, + "status_code": e.status_code, + }, + ) + logger.warning( + "Retryable provider error (attempt %d/%d): %s. " + "Retrying in %.1fs...", + attempt + 1, + max_retries, + e, + delay, + ) + await asyncio.sleep(delay) + continue + + # Not retryable or final attempt — emit error and raise + await hooks.emit( + PROVIDER_ERROR, + { + "provider": provider_name, + "error": {"type": type(e).__name__, "msg": str(e)}, + "retryable": e.retryable, + "status_code": e.status_code, + }, + ) + raise + except Exception as e: + # Non-LLM errors are never retried + await hooks.emit( + PROVIDER_ERROR, + { + "provider": provider_name, + "error": {"type": type(e).__name__, "msg": str(e)}, + }, + ) + raise + + # Unreachable, but satisfies type checkers + msg = "Retry loop exited unexpectedly" + raise RuntimeError(msg) + async def execute( self, prompt: str, @@ -390,28 +498,11 @@ async def _execute_stream( kwargs = {} if self.extended_thinking: kwargs["extended_thinking"] = True - try: - response = await provider.complete(chat_request, **kwargs) - except LLMError as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - "retryable": e.retryable, - "status_code": e.status_code, - }, - ) - raise - except Exception as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - }, - ) - raise + response = await self._call_provider_with_retry( + lambda: provider.complete(chat_request, **kwargs), + hooks, + provider_name, + ) # Update rate limit timestamp after non-streaming response self._last_provider_call_end = time.monotonic() @@ -706,7 +797,11 @@ async def _execute_stream( if self.extended_thinking: kwargs["extended_thinking"] = True - response = await provider.complete(max_iter_chat_request, **kwargs) + response = await self._call_provider_with_retry( + lambda: provider.complete(max_iter_chat_request, **kwargs), + hooks, + provider_name, + ) content = ( response.content if hasattr(response, "content") else str(response) ) @@ -719,25 +814,9 @@ async def _execute_stream( # Add to context await context.add_message({"role": "assistant", "content": content}) - except LLMError as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - "retryable": e.retryable, - "status_code": e.status_code, - }, - ) - logger.error(f"Error getting final response after max iterations: {e}") - except Exception as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - }, - ) + except (LLMError, Exception) as e: + # _call_provider_with_retry already emitted provider:error and + # exhausted retries — log and continue gracefully. logger.error(f"Error getting final response after max iterations: {e}") # Emit execution end @@ -771,28 +850,11 @@ async def _stream_from_provider( # Convert tools dict to list for provider tools_list = list(tools.values()) if tools else [] - try: - stream_iter = provider.stream(chat_request, tools=tools_list) - except LLMError as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - "retryable": e.retryable, - "status_code": e.status_code, - }, - ) - raise - except Exception as e: - await hooks.emit( - PROVIDER_ERROR, - { - "provider": provider_name, - "error": {"type": type(e).__name__, "msg": str(e)}, - }, - ) - raise + stream_iter = await self._call_provider_with_retry( + lambda: provider.stream(chat_request, tools=tools_list), + hooks, + provider_name, + ) async for chunk in stream_iter: # Check for immediate cancellation between chunks diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 0000000..25da62a --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,385 @@ +"""Tests for exponential backoff retry logic on retryable provider errors.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestRetryConfig: + """Test retry configuration defaults and overrides.""" + + def test_default_config(self): + """Retry should be enabled by default with sensible defaults.""" + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({}) + assert orchestrator.retry_max_attempts == 3 + assert orchestrator.retry_base_delay_seconds == 1.0 + assert orchestrator.retry_max_delay_seconds == 30.0 + + def test_custom_config(self): + """Retry parameters can be overridden via config.""" + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator( + { + "retry_max_attempts": 5, + "retry_base_delay_seconds": 0.5, + "retry_max_delay_seconds": 60.0, + } + ) + assert orchestrator.retry_max_attempts == 5 + assert orchestrator.retry_base_delay_seconds == 0.5 + assert orchestrator.retry_max_delay_seconds == 60.0 + + def test_retry_disabled(self): + """Setting retry_max_attempts to 0 effectively disables retry.""" + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"retry_max_attempts": 0}) + assert orchestrator.retry_max_attempts == 0 + + +@pytest.mark.asyncio +class TestCallProviderWithRetry: + """Test the _call_provider_with_retry method.""" + + def _make_orchestrator(self, **config_overrides): + from amplifier_module_loop_streaming import StreamingOrchestrator + + config = { + "retry_max_attempts": 3, + "retry_base_delay_seconds": 0.01, # Fast for tests + "retry_max_delay_seconds": 0.1, + } + config.update(config_overrides) + return StreamingOrchestrator(config) + + def _make_llm_error(self, retryable=True, retry_after=None, status_code=429): + """Create a mock LLMError with the right attributes.""" + from amplifier_core.llm_errors import LLMError + + error = LLMError("Rate limit exceeded") + error.retryable = retryable + error.status_code = status_code + if retry_after is not None: + error.retry_after = retry_after + return error + + async def test_success_on_first_attempt(self): + """Should return result immediately when call succeeds.""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + async def success_fn(): + return "response" + + result = await orchestrator._call_provider_with_retry( + success_fn, hooks, "test-provider" + ) + + assert result == "response" + hooks.emit.assert_not_called() + + async def test_retry_on_retryable_error_then_succeed(self): + """Should retry on retryable error and return result on success.""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + call_count = 0 + + async def fail_then_succeed(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise error + return "recovered" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await orchestrator._call_provider_with_retry( + fail_then_succeed, hooks, "test-provider" + ) + + assert result == "recovered" + assert call_count == 3 + # Should have slept twice (after attempt 1 and 2) + assert mock_sleep.call_count == 2 + + async def test_raises_after_max_retries_exhausted(self): + """Should raise after all retry attempts are exhausted.""" + orchestrator = self._make_orchestrator(retry_max_attempts=2) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + + async def always_fail(): + raise error + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + always_fail, hooks, "test-provider" + ) + + async def test_no_retry_on_non_retryable_error(self): + """Should not retry when error has retryable=False.""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + error = self._make_llm_error(retryable=False, status_code=400) + call_count = 0 + + async def fail_non_retryable(): + nonlocal call_count + call_count += 1 + raise error + + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + fail_non_retryable, hooks, "test-provider" + ) + + assert call_count == 1 # No retry + + async def test_no_retry_on_non_llm_exception(self): + """Should not retry non-LLM exceptions (e.g., ValueError).""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + call_count = 0 + + async def fail_with_value_error(): + nonlocal call_count + call_count += 1 + raise ValueError("bad input") + + with pytest.raises(ValueError, match="bad input"): + await orchestrator._call_provider_with_retry( + fail_with_value_error, hooks, "test-provider" + ) + + assert call_count == 1 # No retry + + async def test_exponential_backoff_delays(self): + """Should use exponential backoff: base * 2^attempt.""" + orchestrator = self._make_orchestrator( + retry_max_attempts=3, + retry_base_delay_seconds=1.0, + retry_max_delay_seconds=30.0, + ) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + + async def always_fail(): + raise error + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + always_fail, hooks, "test-provider" + ) + + # 4 attempts total (initial + 3 retries), 3 sleeps + delays = [call.args[0] for call in mock_sleep.call_args_list] + assert delays == [1.0, 2.0, 4.0] # 1*2^0, 1*2^1, 1*2^2 + + async def test_delay_capped_at_max(self): + """Delay should not exceed retry_max_delay_seconds.""" + orchestrator = self._make_orchestrator( + retry_max_attempts=5, + retry_base_delay_seconds=10.0, + retry_max_delay_seconds=25.0, + ) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + + async def always_fail(): + raise error + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + always_fail, hooks, "test-provider" + ) + + delays = [call.args[0] for call in mock_sleep.call_args_list] + # 10*2^0=10, 10*2^1=20, 10*2^2=40→25(capped), 10*2^3=80→25, 10*2^4=160→25 + assert delays == [10.0, 20.0, 25.0, 25.0, 25.0] + + async def test_honors_retry_after_from_error(self): + """Should use retry_after from the error when present.""" + orchestrator = self._make_orchestrator( + retry_max_attempts=2, + retry_base_delay_seconds=1.0, + ) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True, retry_after=5.0) + + async def always_fail(): + raise error + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + always_fail, hooks, "test-provider" + ) + + delays = [call.args[0] for call in mock_sleep.call_args_list] + # Should use retry_after (5.0) instead of exponential backoff + assert all(d == 5.0 for d in delays) + + async def test_emits_retry_event_on_each_retry(self): + """Should emit provider:retry event for each retry attempt.""" + orchestrator = self._make_orchestrator(retry_max_attempts=2) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True, status_code=429) + call_count = 0 + + async def fail_twice_then_succeed(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise error + return "ok" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await orchestrator._call_provider_with_retry( + fail_twice_then_succeed, hooks, "test-provider" + ) + + assert result == "ok" + + # Should have emitted 2 retry events (no error event since it succeeded) + retry_calls = [ + call + for call in hooks.emit.call_args_list + if call.args[0] == "provider:retry" + ] + assert len(retry_calls) == 2 + + # Verify retry event payload + payload = retry_calls[0].args[1] + assert payload["provider"] == "test-provider" + assert payload["attempt"] == 1 + assert payload["max_retries"] == 2 + assert payload["retryable"] is True + assert payload["status_code"] == 429 + assert "delay_seconds" in payload + assert "error" in payload + + async def test_emits_provider_error_on_final_failure(self): + """Should emit PROVIDER_ERROR when all retries are exhausted.""" + orchestrator = self._make_orchestrator(retry_max_attempts=1) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + + async def always_fail(): + raise error + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + always_fail, hooks, "test-provider" + ) + + # Should have emitted: 1 retry event + 1 provider:error event + event_names = [call.args[0] for call in hooks.emit.call_args_list] + assert "provider:retry" in event_names + assert "provider:error" in event_names + + async def test_emits_provider_error_on_non_retryable(self): + """Should emit PROVIDER_ERROR immediately for non-retryable errors.""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + error = self._make_llm_error(retryable=False, status_code=400) + + async def fail_non_retryable(): + raise error + + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + fail_non_retryable, hooks, "test-provider" + ) + + # Only provider:error, no retry events + event_names = [call.args[0] for call in hooks.emit.call_args_list] + assert event_names == ["provider:error"] + + async def test_handles_sync_callable(self): + """Should handle sync callables (e.g., provider.stream()).""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + def sync_fn(): + return "stream_iterator" + + result = await orchestrator._call_provider_with_retry( + sync_fn, hooks, "test-provider" + ) + + assert result == "stream_iterator" + + async def test_retries_sync_callable_on_error(self): + """Should retry sync callables that raise retryable errors.""" + orchestrator = self._make_orchestrator(retry_max_attempts=2) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + call_count = 0 + + def sync_fail_then_succeed(): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise error + return "stream_iterator" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await orchestrator._call_provider_with_retry( + sync_fail_then_succeed, hooks, "test-provider" + ) + + assert result == "stream_iterator" + assert call_count == 2 + + async def test_zero_retries_means_no_retry(self): + """With retry_max_attempts=0, errors should propagate immediately.""" + orchestrator = self._make_orchestrator(retry_max_attempts=0) + hooks = AsyncMock() + + error = self._make_llm_error(retryable=True) + call_count = 0 + + async def fail(): + nonlocal call_count + call_count += 1 + raise error + + with pytest.raises(type(error)): + await orchestrator._call_provider_with_retry( + fail, hooks, "test-provider" + ) + + assert call_count == 1 # Single attempt, no retries + + async def test_none_provider_name_accepted(self): + """Should accept None as provider_name.""" + orchestrator = self._make_orchestrator() + hooks = AsyncMock() + + async def success_fn(): + return "response" + + result = await orchestrator._call_provider_with_retry( + success_fn, hooks, None + ) + + assert result == "response"