Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/agents/agent_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,60 @@ def name(self) -> str:
return _type_to_str(self.output_type)


class _ResponseFormatOutputSchema(AgentOutputSchemaBase):
"""An ``AgentOutputSchemaBase`` that wraps a raw ``response_format`` dict.

This is used internally to allow ``CallModelInputFilter`` callbacks to
dynamically override the response format before a model call.
"""

def __init__(self, response_format: dict[str, Any]):
self._response_format = response_format

def is_plain_text(self) -> bool:
if not isinstance(self._response_format, dict):
return True
return self._response_format.get("type") == "text"

def is_strict_json_schema(self) -> bool:
if not isinstance(self._response_format, dict):
return False
js = self._response_format.get("json_schema")
if isinstance(js, dict):
return bool(js.get("strict", False))
return False

def json_schema(self) -> dict[str, Any]:
if not isinstance(self._response_format, dict):
return {}
js = self._response_format.get("json_schema")
if isinstance(js, dict):
schema = js.get("schema")
if isinstance(schema, dict):
return schema
return {}
Comment on lines +197 to +202
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve non-nested response_format schemas

When a filter supplies a raw response format that is not the Chat Completions nested shape, such as Responses-style {"type": "json_schema", "schema": ...} or JSON mode {"type": "json_object"}, this wrapper only looks under json_schema.schema and falls back to {}. The downstream converters then send an unconstrained json_schema request instead of the requested format, so dynamic structured-output filtering silently drops the schema for those supported response_format shapes.

Useful? React with 👍 / 👎.


def name(self) -> str:
if not isinstance(self._response_format, dict):
return "final_output"
js = self._response_format.get("json_schema")
if isinstance(js, dict):
name = js.get("name")
if isinstance(name, str):
return name
return "final_output"

def validate_json(self, json_str: str) -> Any:
import json

try:
return json.loads(json_str)
except json.JSONDecodeError as e:
raise ModelBehaviorError(
f"Failed to parse JSON output: {e}"
) from e


def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
# If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
origin = get_origin(t)
Expand Down
1 change: 1 addition & 0 deletions src/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ModelInputData:

input: list[TResponseInputItem]
instructions: str | None
response_format: Any | None = None


@dataclass
Expand Down
14 changes: 9 additions & 5 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
get_tool_trace_name_for_tool,
)
from ..agent import Agent
from ..agent_output import AgentOutputSchemaBase
from ..agent_output import AgentOutputSchemaBase, _ResponseFormatOutputSchema
from ..exceptions import (
AgentsException,
InputGuardrailTripwireTriggered,
Expand Down Expand Up @@ -1376,6 +1376,8 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
)
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)
if filtered.response_format is not None:
output_schema = _ResponseFormatOutputSchema(filtered.response_format)
hosted_mcp_tool_metadata = collect_mcp_list_tools_metadata(streamed_result._model_input_items)
if isinstance(filtered.input, list):
hosted_mcp_tool_metadata.update(collect_mcp_list_tools_metadata(filtered.input))
Expand Down Expand Up @@ -1760,7 +1762,7 @@ async def run_single_turn(
else:
input = _prepare_turn_input_items(original_input, generated_items, reasoning_item_id_policy)

new_response = await get_new_response(
new_response, output_schema = await get_new_response(
bindings,
system_prompt,
input,
Expand Down Expand Up @@ -1811,8 +1813,8 @@ async def get_new_response(
session: Session | None = None,
session_items_to_rewind: list[TResponseInputItem] | None = None,
prompt_cache_key_resolver: PromptCacheKeyResolver | None = None,
) -> ModelResponse:
"""Call the model and return the raw response, handling retries and hooks."""
) -> tuple[ModelResponse, AgentOutputSchemaBase | None]:
"""Call the model and return the raw response and (possibly updated) output schema."""
public_agent = bindings.public_agent
execution_agent = bindings.execution_agent
filtered = await maybe_filter_model_input(
Expand All @@ -1824,6 +1826,8 @@ async def get_new_response(
)
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)
if filtered.response_format is not None:
output_schema = _ResponseFormatOutputSchema(filtered.response_format)

model = get_model(execution_agent, run_config)
model_settings = get_model_settings(execution_agent, run_config)
Expand Down Expand Up @@ -1917,4 +1921,4 @@ async def rewind_model_request() -> None:
hooks.on_llm_end(context_wrapper, public_agent, new_response),
)

return new_response
return new_response, output_schema