diff --git a/amplifier_module_loop_streaming/__init__.py b/amplifier_module_loop_streaming/__init__.py index b87af2a..bdc77c1 100644 --- a/amplifier_module_loop_streaming/__init__.py +++ b/amplifier_module_loop_streaming/__init__.py @@ -7,6 +7,7 @@ __amplifier_module_type__ = "orchestrator" import asyncio +import json import logging import re import time @@ -266,25 +267,52 @@ async def _execute_stream( "Appended ephemeral injection to last tool result message" ) else: - # Fall back to new message if last message isn't a tool result + # Fall back: merge if same role, else new message + if last_msg.get("role") == result.context_injection_role: + original_content = last_msg.get("content", "") + message_dicts[-1] = { + **last_msg, + "content": f"{original_content}\n\n{result.context_injection}", + } + logger.debug( + "Merged ephemeral injection into last message (same role: %s)", + result.context_injection_role, + ) + else: + message_dicts.append( + { + "role": result.context_injection_role, + "content": result.context_injection, + } + ) + logger.debug( + f"Last message role is '{last_msg.get('role')}', not 'tool' - " + "created new message for injection" + ) + else: + # Structural prevention: merge into last message if same role + # to avoid consecutive messages with the same role (confuses models) + if ( + len(message_dicts) > 0 + and message_dicts[-1].get("role") == result.context_injection_role + ): + last_msg = message_dicts[-1] + original_content = last_msg.get("content", "") + message_dicts[-1] = { + **last_msg, + "content": f"{original_content}\n\n{result.context_injection}", + } + logger.debug( + "Merged ephemeral injection into last message (same role: %s)", + result.context_injection_role, + ) + else: message_dicts.append( { "role": result.context_injection_role, "content": result.context_injection, } ) - logger.debug( - f"Last message role is '{last_msg.get('role')}', not 'tool' - " - "created new message for injection" - ) - else: - # Default behavior: append as new message - message_dicts.append( - { - "role": result.context_injection_role, - "content": result.context_injection, - } - ) # Apply pending ephemeral injections from tool:post hooks if self._pending_ephemeral_injections: @@ -304,22 +332,50 @@ async def _execute_stream( "Applied pending ephemeral injection to last tool result" ) else: - message_dicts.append( - { - "role": injection["role"], - "content": injection["content"], + # Merge if same role as last message + if last_msg.get("role") == injection["role"]: + original_content = last_msg.get("content", "") + message_dicts[-1] = { + **last_msg, + "content": f"{original_content}\n\n{injection['content']}", } + logger.debug( + "Merged pending injection into last message (same role: %s)", + injection["role"], + ) + else: + message_dicts.append( + { + "role": injection["role"], + "content": injection["content"], + } + ) + logger.debug( + "Last message not a tool result, created new message for injection" + ) + else: + # Structural prevention: merge if same role as last message + if ( + len(message_dicts) > 0 + and message_dicts[-1].get("role") == injection["role"] + ): + last_msg = message_dicts[-1] + original_content = last_msg.get("content", "") + message_dicts[-1] = { + **last_msg, + "content": f"{original_content}\n\n{injection['content']}", + } + logger.debug( + "Merged pending ephemeral injection into last message (same role: %s)", + injection["role"], + ) + else: + message_dicts.append( + {"role": injection["role"], "content": injection["content"]} ) logger.debug( - "Last message not a tool result, created new message for injection" + "Applied pending ephemeral injection as new message" ) - else: - message_dicts.append( - {"role": injection["role"], "content": injection["content"]} - ) - logger.debug( - "Applied pending ephemeral injection as new message" - ) # Clear pending injections after applying self._pending_ephemeral_injections = [] @@ -668,16 +724,27 @@ async def _execute_stream( # Get one final response with the reminder (via _execute_stream helper) message_dicts = await context.get_messages_for_request(provider=provider) message_dicts = list(message_dicts) - message_dicts.append( - { - "role": "user", - "content": """ + # Merge loop-limit reminder into last message if same role, + # to avoid consecutive user messages that confuse role attribution + _loop_limit_content = """ You have reached the maximum number of iterations for this turn. Please provide a response to the user now, summarizing your progress and noting what remains to be done. You can continue in the next turn if needed. DO NOT mention this iteration limit or reminder to the user explicitly. Simply wrap up naturally. -""", +""" + if ( + len(message_dicts) > 0 + and message_dicts[-1].get("role") == "user" + ): + last_msg = message_dicts[-1] + original_content = last_msg.get("content", "") + message_dicts[-1] = { + **last_msg, + "content": f"{original_content}\n\n{_loop_limit_content}", } - ) + else: + message_dicts.append( + {"role": "user", "content": _loop_limit_content} + ) try: # Convert dicts to ChatRequest @@ -987,8 +1054,24 @@ async def _execute_tool_only( f"Stored ephemeral injection from tool:post ({tool_call.name}) for next iteration" ) - # Return result content (JSON-serialized for dict/list outputs) - content = result.get_serialized_output() + # Check if a hook modified the tool result. + # hooks.emit() chains modify actions: when a hook + # returns action="modify", the data dict is replaced. + # We detect this by checking if the returned "result" + # is a different object than what we originally sent. + modified_result = None + if post_result and post_result.data is not None: + returned_result = post_result.data.get("result") + if returned_result is not None and returned_result is not result_data: + modified_result = returned_result + + if modified_result is not None: + if isinstance(modified_result, (dict, list)): + content = json.dumps(modified_result) + else: + content = str(modified_result) + else: + content = result.get_serialized_output() return (tool_call.id, tool_call.name, content) except Exception as e: @@ -1106,13 +1189,31 @@ async def _execute_tool_with_result( f"Stored ephemeral injection from tool:post ({tool_call.name}) for next iteration" ) - # Add result with tool_call_id (JSON-serialized for dict/list outputs) + # Check if a hook modified the tool result. + # hooks.emit() chains modify actions: when a hook + # returns action="modify", the data dict is replaced. + # We detect this by checking if the returned "result" + # is a different object than what we originally sent. + modified_result = None + if post_result and post_result.data is not None: + returned_result = post_result.data.get("result") + if returned_result is not None and returned_result is not result_data: + modified_result = returned_result + + if modified_result is not None: + if isinstance(modified_result, (dict, list)): + tool_content = json.dumps(modified_result) + else: + tool_content = str(modified_result) + else: + tool_content = result.get_serialized_output() + await context.add_message( { "role": "tool", "name": tool_call.name, "tool_call_id": tool_call.id, - "content": result.get_serialized_output(), + "content": tool_content, } ) response_added = True diff --git a/tests/test_hook_modify.py b/tests/test_hook_modify.py new file mode 100644 index 0000000..d901ca6 --- /dev/null +++ b/tests/test_hook_modify.py @@ -0,0 +1,200 @@ +"""Tests for hook modify action on tool:post events in streaming orchestrator. + +Verifies that when a hook returns HookResult(action="modify", data={"result": ...}) +on a tool:post event, both _execute_tool_only and _execute_tool_with_result +use the modified data instead of the original result.get_serialized_output(). +""" + +import json + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from amplifier_core.hooks import HookRegistry +from amplifier_core.models import HookResult + + +def _make_tool_result(output, success=True): + """Create a mock tool result with get_serialized_output() and model_dump().""" + result = MagicMock() + result.success = success + result.output = output + result.error = None + + def get_serialized_output(): + if isinstance(output, (dict, list)): + return json.dumps(output) + return str(output) + + result.get_serialized_output = get_serialized_output + + def model_dump(): + return {"success": success, "output": output, "error": None} + + result.model_dump = model_dump + return result + + +@pytest.mark.asyncio +async def test_execute_tool_only_modify_replaces_result(): + """_execute_tool_only should use modified data when hook returns modify.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + # Tool with original output + original_output = {"original": True, "big_data": "x" * 1000} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + # Hook that modifies tool output + modified_content = {"modified": True, "truncated": True} + hooks = HookRegistry() + + async def modify_hook(event: str, data: dict) -> HookResult: + if event == "tool:post": + return HookResult(action="modify", data={"result": modified_content}) + return HookResult() + + hooks.register("tool:post", modify_hook, priority=50, name="test_modify") + + # Mock tool call + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {"key": "value"} + + # Call _execute_tool_only directly + tool_call_id, tool_name, content = await orchestrator._execute_tool_only( + tool_call, tools, hooks, "group_1" + ) + + assert tool_call_id == "tc_1" + assert tool_name == "test_tool" + # Content should be the MODIFIED data + assert content == json.dumps(modified_content), ( + f"Expected modified content {json.dumps(modified_content)}, got {content}" + ) + assert "big_data" not in content + + +@pytest.mark.asyncio +async def test_execute_tool_only_no_modify_uses_original(): + """_execute_tool_only should use original when no modify hook.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + original_output = {"original": True} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + hooks = HookRegistry() + + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {} + + tool_call_id, tool_name, content = await orchestrator._execute_tool_only( + tool_call, tools, hooks, "group_1" + ) + + assert content == json.dumps(original_output) + + +@pytest.mark.asyncio +async def test_execute_tool_with_result_modify_replaces_context(): + """_execute_tool_with_result should use modified data in context message.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + # Tool with original output + original_output = {"original": True, "big_data": "x" * 1000} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + # Hook that modifies tool output + modified_content = {"modified": True, "truncated": True} + hooks = HookRegistry() + + async def modify_hook(event: str, data: dict) -> HookResult: + if event == "tool:post": + return HookResult(action="modify", data={"result": modified_content}) + return HookResult() + + hooks.register("tool:post", modify_hook, priority=50, name="test_modify") + + # Mock tool call + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {"key": "value"} + + # Mock context to capture add_message calls + context = AsyncMock() + messages_added = [] + + async def capture_add_message(msg): + messages_added.append(msg) + + context.add_message = AsyncMock(side_effect=capture_add_message) + + # Call _execute_tool_with_result directly + await orchestrator._execute_tool_with_result(tool_call, tools, context, hooks) + + # Find the tool result message + tool_msgs = [msg for msg in messages_added if msg.get("role") == "tool"] + assert len(tool_msgs) == 1, f"Expected 1 tool message, got {len(tool_msgs)}" + + tool_content = tool_msgs[0]["content"] + assert tool_content == json.dumps(modified_content), ( + f"Expected modified content {json.dumps(modified_content)}, got {tool_content}" + ) + assert "big_data" not in tool_content + + +@pytest.mark.asyncio +async def test_execute_tool_with_result_no_modify_uses_original(): + """_execute_tool_with_result should use original when no modify hook.""" + with patch.dict("sys.modules", {"amplifier_core.llm_errors": MagicMock()}): + from amplifier_module_loop_streaming import StreamingOrchestrator + + orchestrator = StreamingOrchestrator({"max_iterations": 5}) + + original_output = {"original": True} + mock_tool = MagicMock() + mock_tool.name = "test_tool" + mock_tool.execute = AsyncMock(return_value=_make_tool_result(original_output)) + tools = {"test_tool": mock_tool} + + hooks = HookRegistry() + + tool_call = MagicMock() + tool_call.id = "tc_1" + tool_call.name = "test_tool" + tool_call.arguments = {} + + context = AsyncMock() + messages_added = [] + + async def capture_add_message(msg): + messages_added.append(msg) + + context.add_message = AsyncMock(side_effect=capture_add_message) + + await orchestrator._execute_tool_with_result(tool_call, tools, context, hooks) + + tool_msgs = [msg for msg in messages_added if msg.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["content"] == json.dumps(original_output) diff --git a/tests/test_rate_limiting.py b/tests/test_rate_limiting.py index 4446327..6b9b2e5 100644 --- a/tests/test_rate_limiting.py +++ b/tests/test_rate_limiting.py @@ -1,8 +1,9 @@ """Tests for orchestrator rate limiting feature.""" -import asyncio import time -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock + +import pytest class TestRateLimitingConfig: @@ -11,73 +12,74 @@ class TestRateLimitingConfig: def test_default_disabled(self): """Rate limiting should be disabled by default.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({}) assert orchestrator.min_delay_between_calls_ms == 0 def test_config_enabled(self): """Rate limiting can be enabled via config.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 500}) assert orchestrator.min_delay_between_calls_ms == 500 def test_tracking_initialized(self): """Last provider call timestamp should start as None.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({}) assert orchestrator._last_provider_call_end is None +@pytest.mark.asyncio class TestRateLimitDelay: """Test the _apply_rate_limit_delay method.""" async def test_no_delay_when_disabled(self): """No delay should be applied when rate limiting is disabled.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 0}) hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 1) elapsed = time.monotonic() - start - + assert elapsed < 0.01 # Should be nearly instant hooks.emit.assert_not_called() async def test_no_delay_on_first_call(self): """No delay on first call (no previous timestamp).""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 1000}) hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 1) elapsed = time.monotonic() - start - + assert elapsed < 0.01 # Should be nearly instant hooks.emit.assert_not_called() async def test_delay_applied_when_needed(self): """Delay should be applied when elapsed < configured.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 100}) orchestrator._last_provider_call_end = time.monotonic() # Just now hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 2) elapsed = (time.monotonic() - start) * 1000 # Convert to ms - + # Should have delayed close to 100ms assert elapsed >= 90 # Allow some tolerance assert elapsed < 150 hooks.emit.assert_called_once() - + # Verify event payload call_args = hooks.emit.call_args assert call_args[0][0] == "orchestrator:rate_limit_delay" @@ -87,14 +89,14 @@ async def test_delay_applied_when_needed(self): async def test_no_delay_if_enough_time_elapsed(self): """No delay if enough time has already passed.""" from amplifier_module_loop_streaming import StreamingOrchestrator - + orchestrator = StreamingOrchestrator({"min_delay_between_calls_ms": 50}) orchestrator._last_provider_call_end = time.monotonic() - 0.1 # 100ms ago hooks = AsyncMock() - + start = time.monotonic() await orchestrator._apply_rate_limit_delay(hooks, 2) elapsed = time.monotonic() - start - + assert elapsed < 0.01 # Should be nearly instant hooks.emit.assert_not_called()