def drop_orphaned_messages_after_consumed_reasoning(
    items: list[TResponseInputItem],
) -> list[TResponseInputItem]:
    """Drop assistant messages whose preceding reasoning item was consumed by a tool call.

    Some providers reject an assistant message that follows a tool call which already
    consumed the turn's only reasoning item, failing with a 400 such as
    ``Item 'msg_...' of type 'message' was provided without its required 'reasoning' item``.
    A typical trigger is a handoff turn shaped like
    ``[reasoning, function_call, message, function_call_output]``.

    Rules applied while scanning ``items`` in order:

    * A ``reasoning`` item becomes the fresh, not-yet-consumed reasoning.
    * Any tool call whose type is a key of ``_TOOL_CALL_TO_OUTPUT_TYPE`` consumes the fresh
      reasoning, if there is one.
    * While a consumed reasoning is pending, *every* assistant message is dropped (not only
      the first one) until a call-output item closes the call sequence or a new reasoning
      item appears.
    * Messages with any other role (``user``, ``system``, ``developer``) are never dropped:
      the pairing requirement only concerns model-emitted messages. Such a message also
      closes the pending state, so assistant messages in later turns are not affected.
    * Non-dict items and every other item type pass through untouched.

    This complements :func:`drop_orphan_function_calls`, which removes function calls
    without outputs together with their preceding reasoning items. Callers run this
    function first so the reasoning/call context is still present when messages are
    examined.

    Args:
        items: Flat, ordered list of input items. The list itself is not mutated.

    Returns:
        A new list with the orphaned assistant messages removed; the remaining items keep
        their relative order and identity.
    """
    fresh_reasoning = False  # True while the latest reasoning item is not yet consumed
    consumed_by_call = False  # True once a tool call has consumed that reasoning item
    result: list[TResponseInputItem] = []

    for item in items:
        if not isinstance(item, dict):
            result.append(item)
            continue

        item_type = item.get("type")
        if item_type == "message":
            if item.get("role") not in (None, "assistant"):
                # Caller-authored message: never orphaned, and it marks a turn boundary,
                # so stop dropping assistant messages that come after it.
                consumed_by_call = False
            elif consumed_by_call:
                # Orphaned: its reasoning was consumed by the preceding call. Do not reset
                # the flag so further assistant messages in the same turn are dropped too.
                continue
        elif item_type == "reasoning":
            fresh_reasoning = True
            consumed_by_call = False
        elif item_type in _TOOL_CALL_TO_OUTPUT_TYPE:
            if fresh_reasoning:
                fresh_reasoning = False
                consumed_by_call = True  # reasoning is now consumed by this call
        elif item_type in _CALL_OUTPUT_TYPES:
            # A call output ends its call sequence. This assumes outputs are appended after
            # all model-emitted items of the turn, so any orphaned message has already been
            # seen; resetting here keeps the flag from leaking into the next turn.
            consumed_by_call = False

        result.append(item)

    return result
def _orphan_scenario_items(
    *assistant_texts: str,
    with_output: bool = True,
) -> list[TResponseInputItem]:
    """Build ``[reasoning, function_call, *assistant messages, function_call_output?]``.

    The assistant messages sit between the call and its output because the model emits
    them in the same turn and the tool output is appended afterwards.
    """
    items: list[TResponseInputItem] = [
        cast(TResponseInputItem, {"type": "reasoning", "id": "rs_1", "summary": []}),
        cast(
            TResponseInputItem,
            {
                "type": "function_call",
                "call_id": "fc_1",
                "name": "transfer_to_x",
                "arguments": "{}",
            },
        ),
    ]
    items.extend(
        cast(TResponseInputItem, {"type": "message", "role": "assistant", "content": text})
        for text in assistant_texts
    )
    if with_output:
        items.append(
            cast(
                TResponseInputItem,
                {"type": "function_call_output", "call_id": "fc_1", "output": "ok"},
            )
        )
    return items


def _assistant_messages_containing(items: Any, needle: str = "") -> list[Any]:
    """Return the dict items that are assistant messages whose content contains ``needle``."""
    return [
        item
        for item in items
        if isinstance(item, dict)
        and item.get("type") == "message"
        and item.get("role") == "assistant"
        and needle in str(item.get("content", ""))
    ]


def _capture_model_inputs() -> tuple[list[list[dict[str, Any]]], Any]:
    """Return ``(captured, filter)``; the filter records each list-typed model input."""
    captured: list[list[dict[str, Any]]] = []

    def capture(data):
        model_input = data.model_data.input
        if isinstance(model_input, list):
            captured.append([item for item in model_input if isinstance(item, dict)])
        return data.model_data

    return captured, capture


def _triage_and_delegate(model: Any) -> tuple[Any, Any]:
    """Create a ``triage`` agent that can hand off to a ``delegate`` agent."""
    delegate = Agent(name="delegate", model=model)
    triage = Agent(name="triage", model=model, handoffs=[delegate])
    return triage, delegate


def _handoff_turn(
    delegate: Any,
    summary_text: str,
    trailing_message: str | None = None,
) -> list[Any]:
    """Build one model turn: reasoning, handoff call, and an optional trailing message."""
    turn: list[Any] = [
        ResponseReasoningItem(
            id="rs_111",
            type="reasoning",
            summary=[Summary(text=summary_text, type="summary_text")],
        ),
        get_handoff_tool_call(delegate),
    ]
    if trailing_message is not None:
        # This message has no reasoning item of its own, i.e. it is the orphan.
        turn.append(get_text_message(trailing_message))
    return turn


def test_normalize_resumed_input_drops_orphaned_message_after_consumed_reasoning():
    """normalize_resumed_input must strip messages orphaned by a consumed reasoning item.

    The SDK appends tool outputs (function_call_output) after all model-emitted items, so the
    orphaned message appears between the function_call and its output in the flat list.
    """
    raw_input = _orphan_scenario_items("I'm handing off now.")

    normalized = normalize_resumed_input(raw_input)
    assert isinstance(normalized, list)
    message_items = _assistant_messages_containing(normalized)
    assert not message_items, (
        "Orphaned assistant message must be dropped by normalize_resumed_input"
    )


def test_normalize_resumed_input_drops_multiple_orphaned_messages_in_same_turn():
    """All orphaned messages before the call output must be dropped, not just the first one."""
    raw_input = _orphan_scenario_items("msg one", "msg two")

    normalized = normalize_resumed_input(raw_input)
    assert isinstance(normalized, list)
    assistant_messages = _assistant_messages_containing(normalized)
    assert not assistant_messages, (
        "All orphaned assistant messages before the call output must be dropped, not just the first"
    )


def test_normalize_resumed_input_drops_orphaned_message_when_no_call_output():
    """Orphaned message must be dropped even when the function_call has no matching output.

    drop_orphan_function_calls() would remove the [reasoning, function_call] pair before
    message-pruning runs if the order were reversed, leaving the orphaned message undetected.
    Running message-pruning first ensures the message is dropped while the reasoning context
    is still present, then drop_orphan_function_calls() cleans up the call pair.
    """
    # No function_call_output: orphaned call AND orphaned message.
    raw_input = _orphan_scenario_items("Transferring now.", with_output=False)

    normalized = normalize_resumed_input(raw_input)
    assert isinstance(normalized, list)
    assert normalized == [], (
        "Both the orphaned message and the call-without-output (+ its reasoning) must be dropped"
    )


@pytest.mark.asyncio
async def test_server_conversation_tracker_drops_orphaned_message_after_consumed_reasoning():
    """The second model call of a handoff run must not contain the orphaned message.

    NOTE(review): the name says "server conversation tracker", but this run passes only
    ``input`` and ``run_config``. Presumably it therefore exercises the same local-history
    path as ``test_handoff_drops_orphaned_message_after_consumed_reasoning``; confirm whether
    server-managed conversation tracking should be enabled here.
    """
    model = FakeModel()
    triage, delegate = _triage_and_delegate(model)
    model.add_multiple_turn_outputs(
        [
            _handoff_turn(delegate, "Deciding to hand off.", "Transferring now."),
            [get_text_message("done")],
        ]
    )
    captured, capture = _capture_model_inputs()

    run_result = await Runner.run(
        triage,
        input="hello",
        run_config=RunConfig(call_model_input_filter=capture),
    )
    assert run_result.final_output == "done"
    assert len(captured) >= 2

    orphaned = _assistant_messages_containing(captured[1], "Transferring")
    assert not orphaned, "Orphaned assistant message must be absent from the second model call."


@pytest.mark.asyncio
async def test_handoff_drops_orphaned_message_after_consumed_reasoning() -> None:
    """
    When a model turn during a handoff emits [reasoning, function_call, message], the reasoning
    item is consumed by the function_call. The trailing message has no paired reasoning and some
    providers (e.g. Azure OpenAI) reject it with HTTP 400. Verify it is dropped from input[].

    Also verifies that the delegate agent's subsequent response (which has no reasoning of its
    own) is NOT dropped.
    """
    model = FakeModel()
    triage, delegate = _triage_and_delegate(model)
    model.add_multiple_turn_outputs(
        [
            _handoff_turn(delegate, "Thinking about handoff.", "I'm transferring you now."),
            [get_text_message("done")],  # delegate responds without reasoning; must be kept
        ]
    )
    captured_inputs, capture_model_input = _capture_model_inputs()

    result = await Runner.run(
        triage,
        input="user_message",
        run_config=RunConfig(call_model_input_filter=capture_model_input),
    )

    # The delegate's "done" message must reach final_output; if the drop leaked into later
    # turns it would be missing and the runner would stall or return a wrong value.
    assert result.final_output == "done"
    assert len(captured_inputs) >= 2

    orphaned_messages = _assistant_messages_containing(captured_inputs[1], "transferring")
    assert not orphaned_messages, (
        "Message item emitted after a handoff function_call (which consumed the only reasoning "
        "item) must be dropped from input[] to prevent provider API rejection."
    )


@pytest.mark.asyncio
async def test_handoff_without_trailing_message_keeps_delegate_response() -> None:
    """
    When the handoff turn emits only [reasoning, function_call] with NO trailing message,
    consumed_by_call must not leak into the next turn and silently drop the delegate's reply.
    """
    model = FakeModel()
    triage, delegate = _triage_and_delegate(model)
    model.add_multiple_turn_outputs(
        [
            _handoff_turn(delegate, "Deciding to hand off."),  # no trailing message
            [get_text_message("delegate reply")],
        ]
    )

    result = await Runner.run(triage, input="user_message")

    assert result.final_output == "delegate reply", (
        "Delegate response must not be dropped when the handoff turn has no trailing message."
    )


@pytest.mark.asyncio
async def test_session_history_drops_orphaned_message_on_next_run() -> None:
    """
    save_result_to_session() persists raw run items including any orphaned trailing message.
    On the next Runner.run(..., session=session) the history is rebuilt via
    prepare_input_with_session(), which must apply drop_orphaned_messages_after_consumed_reasoning()
    so the re-sent history does not contain the orphaned message that would cause a provider 400.
    """
    model = FakeModel()
    triage, delegate = _triage_and_delegate(model)
    session = SimpleListSession()

    # First run: triage reasons, hands off, and emits an orphaned trailing message.
    model.add_multiple_turn_outputs(
        [
            _handoff_turn(delegate, "Thinking about handoff.", "I'm transferring you now."),
            [get_text_message("done")],
        ]
    )
    first_result = await Runner.run(triage, input="user_message", session=session)
    assert first_result.final_output == "done"

    # Second run: history is loaded from session. Capture what the model receives.
    model.set_next_output([get_text_message("second done")])
    captured_inputs, capture = _capture_model_inputs()

    second_result = await Runner.run(
        delegate,
        input="follow-up",
        session=session,
        run_config=RunConfig(call_model_input_filter=capture),
    )
    assert second_result.final_output == "second done"

    # The session-reconstructed history must not contain the orphaned trailing message.
    assert captured_inputs, "call_model_input_filter must have fired"
    orphaned = _assistant_messages_containing(captured_inputs[0], "transferring")
    assert not orphaned, (
        "Orphaned message saved to session must be filtered out when history is "
        "replayed on the next run, to prevent provider 400 errors."
    )