From 796b285d8436ff0a54d0f4960f514c29f186b2dd Mon Sep 17 00:00:00 2001 From: Oxygen <1391083091@qq.com> Date: Sat, 6 Jun 2026 00:46:39 +0800 Subject: [PATCH] feat: add response_format to CallModelInputFilter callback Allows modifying response_format dynamically in the model input filter callback, enabling per-call structured output control. Closes #3563 Co-Authored-By: Claude Opus 4.8 --- src/agents/agent_output.py | 54 +++++++++++++++++++++++++++++ src/agents/run_config.py | 1 + src/agents/run_internal/run_loop.py | 14 +++++--- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 182355b0f8..4ef15e5bfd 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -168,6 +168,60 @@ def name(self) -> str: return _type_to_str(self.output_type) +class _ResponseFormatOutputSchema(AgentOutputSchemaBase): + """An ``AgentOutputSchemaBase`` that wraps a raw ``response_format`` dict. + + This is used internally to allow ``CallModelInputFilter`` callbacks to + dynamically override the response format before a model call. + """ + + def __init__(self, response_format: dict[str, Any]): + self._response_format = response_format + + def is_plain_text(self) -> bool: + if not isinstance(self._response_format, dict): + return True + return self._response_format.get("type") == "text" + + def is_strict_json_schema(self) -> bool: + if not isinstance(self._response_format, dict): + return False + js = self._response_format.get("json_schema") + if isinstance(js, dict): + return bool(js.get("strict", False)) + return False + + def json_schema(self) -> dict[str, Any]: + if not isinstance(self._response_format, dict): + return {} + js = self._response_format.get("json_schema") + if isinstance(js, dict): + schema = js.get("schema") + if isinstance(schema, dict): + return schema + return {} + + def name(self) -> str: + if not isinstance(self._response_format, dict): + return "final_output" + js = self._response_format.get("json_schema") + if isinstance(js, dict): + name = js.get("name") + if isinstance(name, str): + return name + return "final_output" + + def validate_json(self, json_str: str) -> Any: + import json + + try: + return json.loads(json_str) + except json.JSONDecodeError as e: + raise ModelBehaviorError( + f"Failed to parse JSON output: {e}" + ) from e + + def _is_subclass_of_base_model_or_dict(t: Any) -> bool: # If it's a generic alias, 'origin' will be the actual type, e.g. 'list' origin = get_origin(t) diff --git a/src/agents/run_config.py b/src/agents/run_config.py index fcc9b01315..44b911119d 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -50,6 +50,7 @@ class ModelInputData: input: list[TResponseInputItem] instructions: str | None + response_format: Any | None = None @dataclass diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..4730478a33 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -29,7 +29,7 @@ get_tool_trace_name_for_tool, ) from ..agent import Agent -from ..agent_output import AgentOutputSchemaBase +from ..agent_output import AgentOutputSchemaBase, _ResponseFormatOutputSchema from ..exceptions import ( AgentsException, InputGuardrailTripwireTriggered, @@ -1376,6 +1376,8 @@ def _tool_search_fingerprint(raw_item: Any) -> str: ) if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) + if filtered.response_format is not None: + output_schema = _ResponseFormatOutputSchema(filtered.response_format) hosted_mcp_tool_metadata = collect_mcp_list_tools_metadata(streamed_result._model_input_items) if isinstance(filtered.input, list): hosted_mcp_tool_metadata.update(collect_mcp_list_tools_metadata(filtered.input)) @@ -1760,7 +1762,7 @@ async def run_single_turn( else: input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy) - new_response = await get_new_response( + new_response, output_schema = await get_new_response( bindings, system_prompt, input, @@ -1811,8 +1813,8 @@ async def get_new_response( session: Session | None = None, session_items_to_rewind: list[TResponseInputItem] | None = None, prompt_cache_key_resolver: PromptCacheKeyResolver | None = None, -) -> ModelResponse: - """Call the model and return the raw response, handling retries and hooks.""" +) -> tuple[ModelResponse, AgentOutputSchemaBase | None]: + """Call the model and return the raw response and (possibly updated) output schema.""" public_agent = bindings.public_agent execution_agent = bindings.execution_agent filtered = await maybe_filter_model_input( @@ -1824,6 +1826,8 @@ async def get_new_response( ) if isinstance(filtered.input, list): filtered.input = deduplicate_input_items_preferring_latest(filtered.input) + if filtered.response_format is not None: + output_schema = _ResponseFormatOutputSchema(filtered.response_format) model = get_model(execution_agent, run_config) model_settings = get_model_settings(execution_agent, run_config) @@ -1917,4 +1921,4 @@ async def rewind_model_request() -> None: hooks.on_llm_end(context_wrapper, public_agent, new_response), ) - return new_response + return new_response, output_schema