From 7224caf3fc6e4a52e468eb607dda1e9e161a08cd Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 27 May 2026 17:55:20 -0700 Subject: [PATCH 1/2] Tool results: support multi-part tool content output The core goal here is to support image output from tools, so that we can support a read file tool that returns an image, or a screenshot tool, etc. We add a ContentOutput type that contains `list[TextPart|FilePart]` and is handled specially by the tool calling code. `ToolResultPart` gets a `result_kind` field that can signal the presence of `ContentOutput` while encoding/decoding. BACKWARDS COMPATIBILITY: This is a compatibility break, though, since it changes the serialized fields on ToolResultPart (eliminating `is_error` and adding `result_kind`). --- examples/read_file_tool.py | 86 ++++++++ src/ai/__init__.py | 4 + src/ai/agents/agent.py | 18 +- src/ai/agents/ui/ai_sdk/inbound_messages.py | 10 +- src/ai/agents/ui/ai_sdk/outbound_messages.py | 23 +- src/ai/agents/ui/ai_sdk/outbound_stream.py | 17 +- src/ai/providers/ai_gateway/protocol.py | 68 ++++-- src/ai/providers/anthropic/protocol.py | 31 ++- src/ai/providers/openai/protocol.py | 79 ++++++- src/ai/types/builders.py | 45 +++- src/ai/types/integrity.py | 2 +- src/ai/types/messages.py | 202 ++++++++++++------ tests/agents/mcp/test_client.py | 3 +- tests/agents/test_aggregate_marker.py | 3 +- tests/agents/test_generator_tools.py | 12 +- tests/agents/test_runtime.py | 3 +- tests/agents/test_tools.py | 9 +- .../agents/ui/ai_sdk/test_inbound_messages.py | 1 - tests/providers/ai_gateway/test_protocol.py | 127 ++++++++++- .../anthropic/test_multipart_tool_result.py | 112 ++++++++++ .../openai/test_multipart_tool_result.py | 115 ++++++++++ tests/test_middleware.py | 3 +- tests/types/test_builders.py | 19 ++ tests/types/test_media.py | 17 ++ tests/types/test_messages.py | 105 +++++++++ 25 files changed, 988 insertions(+), 126 deletions(-) create mode 100644 examples/read_file_tool.py create mode 100644 
tests/providers/anthropic/test_multipart_tool_result.py create mode 100644 tests/providers/openai/test_multipart_tool_result.py diff --git a/examples/read_file_tool.py b/examples/read_file_tool.py new file mode 100644 index 00000000..a875a9d4 --- /dev/null +++ b/examples/read_file_tool.py @@ -0,0 +1,86 @@ +"""Tool that returns a ContentOutput so the model can see image files directly. + +The ``read_file`` tool reads a path from disk and inspects the bytes: + +* If the file is an image, it returns a :class:`ContentOutput` carrying + a summary line and an image :class:`FilePart`. All three providers + turn that into a real image content block on the next model turn, so + the model actually *sees* the picture. +* Otherwise it returns the decoded text -- a plain ``str`` result the + provider sends to the model verbatim. + +A single tool covers both code-reading and image-reading duties in an +agentic loop. +""" + +import asyncio +import json +import pathlib + +import ai +from ai.types import media + +# Restrict the tool to a directory we trust the model to roam in. +# `.resolve()` collapses symlinks so a path inside ALLOWED_ROOT cannot +# escape via a symlink that points elsewhere. +ALLOWED_ROOT = pathlib.Path(__file__).parent.resolve() + + +def _resolve_within_allowed(path: str) -> pathlib.Path: + resolved = pathlib.Path(path).resolve() + if not resolved.is_relative_to(ALLOWED_ROOT): + raise ValueError( + f"Refusing to read {path!r}: outside allowed root {ALLOWED_ROOT}" + ) + return resolved + + +@ai.tool +async def read_file(path: str) -> str | ai.messages.ContentOutput: + """Read a file from disk. + + Image files come back as a ContentOutput so the model can view them. 
+ """ + data = _resolve_within_allowed(path).read_bytes() + image_type = media.detect_image_media_type(data) + if image_type is not None: + return ai.content_output( + f"Loaded {path} ({image_type}, {len(data)} bytes).", + ai.file_part(data, media_type=image_type), + ) + return data.decode("utf-8", errors="replace") + + +async def main() -> None: + model = ai.get_model("gateway:anthropic/claude-sonnet-4.6") + my_agent = ai.agent(tools=[read_file]) + + here = pathlib.Path(__file__).parent + image_path = here / "sample_image.jpg" + text_path = here / "agent_simple.py" + + messages = [ + ai.system_message( + "Use the read_file tool to inspect any files the user mentions." + ), + ai.user_message( + f"First read {image_path} and describe what you see in the " + f"picture. Then read {text_path} and summarize what the " + f"script does in one sentence." + ), + ] + + async with my_agent.run(model, messages) as stream: + async for event in stream: + if isinstance(event, ai.events.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.events.ToolEnd): + args = json.loads(event.tool_call.tool_args or "{}") + print(f"\n[read_file({args.get('path')!r})]") + elif isinstance(event, ai.events.StreamEnd): + print() + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 066d9aaa..6d6ab8f2 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -63,8 +63,10 @@ from .types import events, messages, tools from .types.builders import ( assistant_message, + content_output, file_part, system_message, + text_part, thinking, tool_message, tool_result_part, @@ -119,6 +121,7 @@ "agent", "assistant_message", "cancel_hook", + "content_output", "errors", # Submodules "events", @@ -137,6 +140,7 @@ "resolve_hook", "stream", "system_message", + "text_part", "thinking", "tool", "tool_message", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 90bfd449..3183d5b3 100644 --- 
a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -22,6 +22,7 @@ Any, ClassVar, Generic, + Literal, Protocol, Self, cast, @@ -58,6 +59,18 @@ def _unwrap_singleton_group(exc: BaseException) -> BaseException: return exc +def _result_kind(value: Any) -> Literal["json", "content"]: + """Tag a successful tool return value for ``ToolResultPart.result_kind``. + + A :class:`ContentOutput` becomes ``"content"`` (expanded into provider + multimodal blocks); everything else is ``"json"`` (the encoder sends a + ``str`` raw and JSON-encodes anything else). + """ + if isinstance(value, types.messages.ContentOutput): + return "content" + return "json" + + def _error_tool_result( exc: BaseException, *, @@ -75,7 +88,7 @@ def _error_tool_result( tool_call_id=tool_call_id, tool_name=tool_name, result=f"{type(unwrapped).__name__}: {unwrapped}", - is_error=True, + result_kind="error", ), exception=unwrapped, ) @@ -616,6 +629,7 @@ async def _real( tool_call_id=call.tool_call_id, tool_name=call.tool_name, result=result, + result_kind=_result_kind(result), ) part.set_model_input(model_input) return tool_result(part) @@ -1039,7 +1053,7 @@ def pending_tool_result( tool_call_id=tool_call_id, tool_name=tool_name, result=f"Pending on hook {hook.hook_id!r}", - is_error=True, + result_kind="error", is_hook_pending=True, ) msg = types.messages.Message(role="tool", parts=[part]) diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index c0c36631..0435b75f 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -85,8 +85,11 @@ def _build_result_part( output: Any, is_error: bool, ) -> messages_.ToolResultPart: + result: Any + result_kind: messages_.ResultKind if is_error: - result: Any = output + result = output + result_kind = "error" else: decoded = _decode_wire_output(output) result = ( @@ -94,11 +97,12 @@ def _build_result_part( if isinstance(decoded, MessageBundle) else 
_normalize_tool_result(decoded) ) + result_kind = "json" return messages_.ToolResultPart( tool_call_id=tool_call_id, tool_name=tool_name, result=result, - is_error=is_error, + result_kind=result_kind, ) @@ -190,7 +194,7 @@ def _patch_pending_hook_aborts( tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, result=f"Pending on hook '{hook.hook_id}'", - is_error=True, + result_kind="error", is_hook_pending=True, ) ) diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index 4520e20d..98366cd5 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from typing import Any, cast from ....types import media @@ -107,6 +108,21 @@ def dedupe_tool_parts( return result +def _output_view( + part: messages_.ToolResultPart, +) -> tuple[str, dict[str, Any]]: + """Map a tool result to ``(state, field_updates)`` for the UI wire.""" + result = part.result + if isinstance(result, messages_.ContentOutput): + return "output-available", { + "output": [item.model_dump(mode="json") for item in result.value] + } + if part.is_error: + text = result if isinstance(result, str) else json.dumps(result) + return "output-error", {"error_text": text} + return "output-available", {"output": result} + + def merge_tool_results( ui_parts: list[ui_messages.UIMessagePart], tool_parts: list[messages_.Part], @@ -121,15 +137,12 @@ def merge_tool_results( continue case messages_.ToolResultPart(): tool_call_id = part.tool_call_id - state = "output-error" if part.is_error else "output-available" + state, field_updates = _output_view(part) updates = { "state": state, "result_provider_metadata": part.provider_metadata, + **field_updates, } - if part.is_error: - updates["error_text"] = str(part.result) - else: - updates["output"] = part.result case messages_.BuiltinToolReturnPart(): tool_call_id = part.tool_call_id updates = { diff --git 
a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index 288bfbe5..215174e4 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -21,11 +21,12 @@ def _tool_error_text(part: messages_.ToolResultPart) -> str: """Best-effort error text extraction from a failed tool result.""" - if isinstance(part.result, str) and part.result: - return part.result - if isinstance(part.result, dict): + result = part.result + if isinstance(result, str) and result: + return result + if isinstance(result, dict): for key in ("error", "message", "detail"): - value = part.result.get(key) + value = result.get(key) if isinstance(value, str) and value: return value return "Tool execution failed" @@ -403,7 +404,13 @@ def on_tool_result( ) ) else: - wire_output = _to_wire_output(part.result) + result = part.result + if isinstance(result, messages_.ContentOutput): + wire_output: Any = [ + item.model_dump(mode="json") for item in result.value + ] + else: + wire_output = _to_wire_output(result) if wire_output is None: # Aggregator produced no anchor (e.g. sub-agent # tool that yielded nothing). Skip the final diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index a19e8a8d..a16f55cc 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -65,6 +65,57 @@ def _file_part_to_wire(part: types.messages.FilePart) -> dict[str, Any]: return {"type": "file", "data": b64, "mediaType": part.media_type} +# --------------------------------------------------------------------------- +# Tool result output -> v3 wire +# --------------------------------------------------------------------------- + + +def _file_part_to_v3_inline(part: types.messages.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to an inline v3 content element. + + Images become ``image-data``; everything else becomes ``file-data``. 
+ """ + b64 = types.media.data_to_base64(part.data) + if part.media_type.startswith("image/"): + return {"type": "image-data", "data": b64, "mediaType": part.media_type} + entry: dict[str, Any] = { + "type": "file-data", + "data": b64, + "mediaType": part.media_type, + } + if part.filename is not None: + entry["filename"] = part.filename + return entry + + +def _tool_result_output( + part: types.messages.ToolResultPart, +) -> dict[str, Any]: + """Convert a tool result to its v3 ``output`` wire form. + + The v3 protocol carries a tagged output union. A :class:`ContentOutput` + becomes ``content``; an error result becomes ``error-text`` (for a + ``str``) or ``error-json``; otherwise ``text`` (for a ``str``) or + ``json``. The text-vs-json call is made here, at the wire boundary. + """ + value = part.get_model_input() + if isinstance(value, types.messages.ContentOutput): + parts: list[dict[str, Any]] = [] + for item in value.value: + if isinstance(item, types.messages.FilePart): + parts.append(_file_part_to_v3_inline(item)) + else: + parts.append({"type": "text", "text": item.text}) + return {"type": "content", "value": parts} + if part.is_error: + if value is None or isinstance(value, str): + return {"type": "error-text", "value": value or ""} + return {"type": "error-json", "value": value} + if value is None or isinstance(value, str): + return {"type": "text", "value": value or ""} + return {"type": "json", "value": value} + + # --------------------------------------------------------------------------- # Streaming request building — Message list → v3 prompt # --------------------------------------------------------------------------- @@ -172,22 +223,7 @@ async def _messages_to_prompt( tool_results: list[dict[str, Any]] = [] for part in msg.parts: if isinstance(part, types.messages.ToolResultPart): - model_input = part.get_model_input() - output = ( - { - "type": "error-text", - "value": ( - str(model_input) - if model_input is not None - else "" - ), - } - if 
part.is_error - else { - "type": "json", - "value": model_input, - } - ) + output = _tool_result_output(part) tool_results.append( { "type": "tool-result", diff --git a/src/ai/providers/anthropic/protocol.py b/src/ai/providers/anthropic/protocol.py index 05351b21..2b20339d 100644 --- a/src/ai/providers/anthropic/protocol.py +++ b/src/ai/providers/anthropic/protocol.py @@ -179,6 +179,29 @@ def _file_part_to_anthropic( raise ValueError(f"Unsupported media type for Anthropic: {mt}") +def _tool_result_to_anthropic(value: Any) -> str | list[dict[str, Any]]: + """Convert a tool result's model-facing value to Anthropic content. + + A :class:`ContentOutput` expands into Anthropic content blocks + (image/document) so the model sees actual media. Everything else is + sent as a string (the Anthropic API accepts a string as tool_result + content): ``str`` raw, ``None`` as ``""``, anything else JSON-encoded. + """ + if isinstance(value, types.messages.ContentOutput): + blocks: list[dict[str, Any]] = [] + for item in value.value: + if isinstance(item, types.messages.FilePart): + blocks.append(_file_part_to_anthropic(item)) + else: + blocks.append({"type": "text", "text": item.text}) + return blocks + if value is None: + return "" + if isinstance(value, str): + return value + return json.dumps(value, separators=(",", ":"), default=str) + + async def _messages_to_anthropic( messages: list[types.messages.Message], ) -> tuple[str | None, list[dict[str, Any]]]: @@ -271,13 +294,13 @@ async def _messages_to_anthropic( tool_results: list[dict[str, Any]] = [] for part in msg.parts: if isinstance(part, types.messages.ToolResultPart): - model_input = part.get_model_input() + tool_content = _tool_result_to_anthropic( + part.get_model_input() + ) entry: dict[str, Any] = { "type": "tool_result", "tool_use_id": part.tool_call_id, - "content": str(model_input) - if model_input is not None - else "", + "content": tool_content, } if part.is_error: entry["is_error"] = True diff --git 
a/src/ai/providers/openai/protocol.py b/src/ai/providers/openai/protocol.py index 41be970e..d4cfa8ab 100644 --- a/src/ai/providers/openai/protocol.py +++ b/src/ai/providers/openai/protocol.py @@ -108,6 +108,38 @@ async def _file_part_to_openai( raise ValueError(f"Unsupported media type for OpenAI: {mt}") +def _tool_result_to_openai(value: Any) -> str | list[dict[str, Any]]: + """Convert a tool result's model-facing value to OpenAI content (chat). + + A :class:`ContentOutput` expands into a content array with ``text`` + and ``image_url`` parts. Everything else is sent as a string: ``str`` + raw, ``None`` as ``""``, anything else JSON-encoded. + """ + if isinstance(value, types.messages.ContentOutput): + parts: list[dict[str, Any]] = [] + for item in value.value: + if isinstance(item, types.messages.FilePart): + mt = item.media_type + if mt.startswith("image/"): + data_url = types.media.data_to_data_url(item.data, mt) + parts.append( + { + "type": "image_url", + "image_url": {"url": data_url}, + } + ) + else: + parts.append({"type": "text", "text": f"[file: {mt}]"}) + else: + parts.append({"type": "text", "text": item.text}) + return parts + if value is None: + return "" + if isinstance(value, str): + return value + return _json_dumps(value) + + async def _messages_to_openai( messages: list[types.messages.Message], ) -> list[dict[str, Any]]: @@ -165,14 +197,14 @@ async def _messages_to_openai( case "tool": for part in msg.parts: if isinstance(part, types.messages.ToolResultPart): - model_input = part.get_model_input() + tool_content = _tool_result_to_openai( + part.get_model_input() + ) result.append( { "role": "tool", "tool_call_id": part.tool_call_id, - "content": str(model_input) - if model_input is not None - else "", + "content": tool_content, } ) @@ -511,12 +543,39 @@ def _raw_item_from_metadata(part: Any) -> dict[str, Any] | None: return None -def _stringify_tool_result(result: Any) -> str: - if result is None: +def _tool_result_to_responses(value: Any) -> str 
| list[dict[str, Any]]: + """Convert a tool result's model-facing value to a Responses ``output``. + + Returns a plain string for ordinary values (``str`` raw, ``None`` as + ``""``, anything else JSON-encoded), or an array of ``input_text`` / + ``input_image`` / ``input_file`` parts for a :class:`ContentOutput` + (the Responses API accepts both shapes on ``function_call_output.output``). + """ + if isinstance(value, types.messages.ContentOutput): + parts: list[dict[str, Any]] = [] + for item in value.value: + if isinstance(item, types.messages.FilePart): + data_url = types.media.data_to_data_url( + item.data, item.media_type + ) + if item.media_type.startswith("image/"): + parts.append({"type": "input_image", "image_url": data_url}) + else: + entry: dict[str, Any] = { + "type": "input_file", + "file_data": data_url, + } + if item.filename is not None: + entry["filename"] = item.filename + parts.append(entry) + else: + parts.append({"type": "input_text", "text": item.text}) + return parts + if value is None: return "" - if isinstance(result, str): - return result - return _json_dumps(result) + if isinstance(value, str): + return value + return _json_dumps(value) async def _file_part_to_responses( @@ -650,7 +709,7 @@ async def _messages_to_responses( { "type": "function_call_output", "call_id": part.tool_call_id, - "output": _stringify_tool_result( + "output": _tool_result_to_responses( part.get_model_input() ), } diff --git a/src/ai/types/builders.py b/src/ai/types/builders.py index eeb83c10..99eb0896 100644 --- a/src/ai/types/builders.py +++ b/src/ai/types/builders.py @@ -13,11 +13,14 @@ from . 
import events as events_ from .messages import ( + ContentOutput, + ContentPart, FilePart, HookPart, Message, Part, ReasoningPart, + ResultKind, TextPart, ToolCallPart, ToolResultPart, @@ -89,6 +92,35 @@ def file_part( return FilePart.from_bytes(data, media_type=media_type, filename=filename) +def text_part( + text: str, + *, + provider_metadata: dict[str, Any] | None = None, +) -> TextPart: + """Create a :class:`TextPart`. + + Bare strings passed to the ``*_message`` builders are coerced into + text parts automatically; reach for this when you need to attach + ``provider_metadata`` or build a part list directly. + """ + return TextPart(text=text, provider_metadata=provider_metadata) + + +def content_output(*content: str | TextPart | FilePart) -> ContentOutput: + """Create a multipart :class:`ContentOutput` tool result. + + Bare strings become :class:`TextPart` objects, mirroring the + ``*_message`` builders, so a tool can return mixed text and files + without constructing the part list by hand. + + >>> ai.content_output("Here is the chart:", ai.file_part(png_bytes)) + """ + parts: list[ContentPart] = [] + for item in content: + parts.append(TextPart(text=item) if isinstance(item, str) else item) + return ContentOutput(value=parts) + + def thinking( text: str, *, @@ -208,11 +240,22 @@ def tool_result_part( ) -> ToolResultPart: """Create a :class:`ToolResultPart`. + ``result`` is stored as-is; ``result_kind`` is derived: ``"error"`` when + ``is_error`` is set, ``"content"`` for a :class:`ContentOutput`, else + ``"json"`` (a ``str`` is sent raw to the model, anything else is + JSON-encoded at the provider boundary). 
+ >>> ai.tool_result_part("tc-1", result={"temp": 72}, tool_name="weather") """ + if is_error: + result_kind: ResultKind = "error" + elif isinstance(result, ContentOutput): + result_kind = "content" + else: + result_kind = "json" return ToolResultPart( tool_call_id=tool_call_id, tool_name=tool_name, result=result, - is_error=is_error, + result_kind=result_kind, ) diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py index a5328cc4..dcc2d402 100644 --- a/src/ai/types/integrity.py +++ b/src/ai/types/integrity.py @@ -180,7 +180,7 @@ def _flush_pending() -> None: tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, result="Tool result not available", - is_error=True, + result_kind="error", ) for tc in pending.values() ) diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 217b4287..6c61b1c3 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -1,3 +1,4 @@ +import base64 import uuid from typing import Annotated, Any, Literal, Self, overload @@ -21,19 +22,127 @@ class TextPart(pydantic.BaseModel): kind: Literal["text"] = "text" +class FilePart(pydantic.BaseModel): + """File, image, or audio content part. + + Covers images (``image/*``), documents (``application/pdf``, ``text/*``), + and audio (``audio/*``). The ``media_type`` field tells provider + converters how to format this part for each API. + + ``data`` accepts: + + * **str** -- a URL (``http(s)://...`` or ``data:...``) *or* raw + base-64 text. + * **bytes** -- raw binary data (will be base-64 encoded when serialized + to JSON for providers that need it). + """ + + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) + data: str | bytes + media_type: str # IANA media type, e.g. 
"image/png", "audio/wav" + filename: str | None = None + kind: Literal["file"] = "file" + provider_metadata: dict[str, Any] | None = None + + @pydantic.field_serializer("data", when_used="json") + @classmethod + def _serialize_data(cls, v: str | bytes, _info: Any) -> str: + """Encode ``bytes`` as standard base-64 for JSON serialization. + + Pydantic's built-in ``ser_json_bytes`` uses URL-safe base-64 + (``-`` and ``_``) which LLM provider APIs reject. This + serializer uses standard base-64 (``+`` and ``/``) instead. + ``str`` values (URLs, existing base-64) pass through unchanged. + """ + if isinstance(v, bytes): + return base64.b64encode(v).decode("ascii") + return v + + @classmethod + def from_url(cls, url: str, *, media_type: str | None = None) -> Self: + """Create from a URL, inferring ``media_type`` from the URL if omitted. + + Inference handles ``data:`` URLs (the media type is embedded in the + prefix) and ``http(s)://`` URLs (via :func:`mimetypes.guess_type`). + Raises :class:`ValueError` if inference fails and no explicit + ``media_type`` is provided. + """ + if media_type is None: + media_type = media.infer_media_type(url) + return cls(data=url, media_type=media_type) + + @classmethod + def from_bytes( + cls, + data: bytes, + *, + media_type: str | None = None, + filename: str | None = None, + ) -> Self: + """Create from raw bytes, detecting ``media_type`` via magic bytes. + + Attempts image detection first, then audio. Raises + :class:`ValueError` if no ``media_type`` is provided and + detection fails. + """ + if media_type is None: + media_type = media.detect_image_media_type( + data + ) or media.detect_audio_media_type(data) + if media_type is None: + raise ValueError( + "Cannot detect media_type from bytes. " + "Provide media_type explicitly." 
+ ) + return cls(data=data, media_type=media_type, filename=filename) + + +# --------------------------------------------------------------------------- +# Multipart tool result -- a tool may return a mix of text and file/image +# parts so the model sees actual media. Stored on ``ToolResultPart.result`` +# with ``result_kind="content"``; providers expand it into their multimodal +# wire format. +# --------------------------------------------------------------------------- + + +ContentPart = Annotated[ + TextPart | FilePart, + pydantic.Field(discriminator="kind"), +] + + +class ContentOutput(pydantic.BaseModel): + """Multipart tool result -- mix of text and file/image parts.""" + + type: Literal["content"] = "content" + value: list[ContentPart] + + model_config = pydantic.ConfigDict(frozen=True) + + _MODEL_INPUT_UNSET: Any = object() +# Coarse tag for the shape of ``ToolResultPart.result``. ``"content"`` means +# a :class:`ContentOutput`; ``"error"`` flags an error result; ``"json"`` (the +# default) is any plain value. Providers decide text-vs-json at the wire +# boundary (a ``str`` is sent raw, everything else is JSON-encoded). +ResultKind = Literal["error", "json", "content"] + class ToolResultPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str - is_error: bool = False is_hook_pending: bool = False provider_metadata: dict[str, Any] | None = None - # The "real" result of the tool call + # The "real" result of the tool call. Stays ``Any``: a plain value + # (str, dict, BaseModel, ...), a :class:`ContentOutput` for multipart + # results, or an aggregator snapshot. ``result_kind`` tags its shape. result: Any = None + result_kind: ResultKind = "json" # Value the LLM sees on its next turn. 
For most tools this is # identical to ``result``; for aggregator-backed tools (sub-agents, @@ -50,6 +159,32 @@ class ToolResultPart(pydantic.BaseModel): kind: Literal["tool_result"] = "tool_result" model_config = pydantic.ConfigDict(frozen=True) + @pydantic.model_validator(mode="before") + @classmethod + def _restore_content(cls, data: Any) -> Any: + """Rebuild a typed :class:`ContentOutput` after a JSON round-trip. + + ``result`` is ``Any``, so pydantic restores a serialized + ``ContentOutput`` as a plain dict. When ``result_kind`` says the + result is content, coerce it back so providers (and the UI adapter) + can rely on ``isinstance(result, ContentOutput)``. + """ + if ( + isinstance(data, dict) + and data.get("result_kind") == "content" + and isinstance(data.get("result"), dict) + ): + data = { + **data, + "result": ContentOutput.model_validate(data["result"]), + } + return data + + @property + def is_error(self) -> bool: + """Whether this result represents an error to the model.""" + return self.result_kind == "error" + def get_model_input(self) -> Any: """Return the value the LLM should see, falling back to ``result``.""" if self._model_input is _MODEL_INPUT_UNSET: @@ -141,69 +276,6 @@ class HookPart[T](pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) -class FilePart(pydantic.BaseModel): - """File, image, or audio content part. - - Covers images (``image/*``), documents (``application/pdf``, ``text/*``), - and audio (``audio/*``). The ``media_type`` field tells provider - converters how to format this part for each API. - - ``data`` accepts: - - * **str** -- a URL (``http(s)://...`` or ``data:...``) *or* raw - base-64 text. - * **bytes** -- raw binary data (will be base-64 encoded when serialized - to JSON for providers that need it). - """ - - model_config = pydantic.ConfigDict(frozen=True) - - id: str = pydantic.Field(default_factory=lambda: generate_id("part")) - data: str | bytes - media_type: str # IANA media type, e.g. 
"image/png", "audio/wav" - filename: str | None = None - kind: Literal["file"] = "file" - provider_metadata: dict[str, Any] | None = None - - @classmethod - def from_url(cls, url: str, *, media_type: str | None = None) -> Self: - """Create from a URL, inferring ``media_type`` from the URL if omitted. - - Inference handles ``data:`` URLs (the media type is embedded in the - prefix) and ``http(s)://`` URLs (via :func:`mimetypes.guess_type`). - Raises :class:`ValueError` if inference fails and no explicit - ``media_type`` is provided. - """ - if media_type is None: - media_type = media.infer_media_type(url) - return cls(data=url, media_type=media_type) - - @classmethod - def from_bytes( - cls, - data: bytes, - *, - media_type: str | None = None, - filename: str | None = None, - ) -> Self: - """Create from raw bytes, detecting ``media_type`` via magic bytes. - - Attempts image detection first, then audio. Raises - :class:`ValueError` if no ``media_type`` is provided and - detection fails. - """ - if media_type is None: - media_type = media.detect_image_media_type( - data - ) or media.detect_audio_media_type(data) - if media_type is None: - raise ValueError( - "Cannot detect media_type from bytes. " - "Provide media_type explicitly." - ) - return cls(data=data, media_type=media_type, filename=filename) - - Part = Annotated[ TextPart | ToolCallPart diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index 13a73cc7..f6db7b0c 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -137,7 +137,8 @@ async def fake_fn(**kwargs: str) -> str: # Tool result is visible in messages. tool_results = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_results) >= 1 - assert tool_results[0].tool_results[0].result == "echoed: hello" + tr = tool_results[0].tool_results[0].result + assert tr == "echoed: hello" # LLM was called twice (tool call + final text). 
assert llm.call_count == 2 diff --git a/tests/agents/test_aggregate_marker.py b/tests/agents/test_aggregate_marker.py index e0fe90a2..22683111 100644 --- a/tests/agents/test_aggregate_marker.py +++ b/tests/agents/test_aggregate_marker.py @@ -153,4 +153,5 @@ async def test_alias_declared_tool_runs_end_to_end() -> None: tool_results = [ e for e in all_events if isinstance(e, agent_events_.ToolCallResult) ] - assert tool_results[0].results[0].result == "Answer for test" + tr = tool_results[0].results[0].result + assert tr == "Answer for test" diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 0f2bc9d6..64136208 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -72,7 +72,8 @@ async def test_generator_tool_streams_and_returns_result() -> None: e for e in all_events if isinstance(e, agent_events_.ToolCallResult) ] assert len(tool_results) >= 1 - assert tool_results[0].results[0].result == "Answer for test" + tr = tool_results[0].results[0].result + assert tr == "Answer for test" # --------------------------------------------------------------------------- @@ -179,13 +180,14 @@ async def test_yield_from_nested_agent() -> None: tool_results = [ e for e in all_events if isinstance(e, agent_events_.ToolCallResult) ] - # MessageAggregator stores the rich MessageBundle as `result` and the - # extracted assistant text as the model input (the value the parent - # LLM sees on its next turn). + # MessageAggregator stores the rich MessageBundle as ``result`` and the + # extracted assistant text as the model input the parent LLM sees on its + # next turn. sub_part = tool_results[0].results[0] assert isinstance(sub_part.result, MessageBundle) assert sub_part.result.messages[0].text == "Mars has two moons." - assert sub_part.get_model_input() == "Mars has two moons." + model_input = sub_part.get_model_input() + assert model_input == "Mars has two moons." 
# The outer LLM's second call (index 2) must NOT contain any inner # agent messages. It should only see: the original user message, diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 9a4096b9..ec71fe36 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -60,7 +60,8 @@ async def test_agent_tool_then_text() -> None: assert llm.call_count == 2 tool_results = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_results) >= 1 - assert tool_results[0].tool_results[0].result == 10 + tr = tool_results[0].tool_results[0].result + assert tr == 10 # -- Agent default loop: multiple tool calls in one message ---------------- diff --git a/tests/agents/test_tools.py b/tests/agents/test_tools.py index eee0b332..f28650f0 100644 --- a/tests/agents/test_tools.py +++ b/tests/agents/test_tools.py @@ -75,7 +75,8 @@ async def add(a: int, b: int) -> int: tool_args='{"a": 1, "b": 2}', ) result = await ai.agents.BoundToolCall(part=part, tool=add)() - assert result.results[0].result == 3 + out = result.results[0].result + assert out == 3 # -- ToolCall binds a ToolCallPart to a Tool and returns tool messages ---- @@ -101,7 +102,8 @@ async def double(x: int) -> int: assert len(result.results) == 1 assert result.results[0].tool_call_id == "tc-1" assert result.results[0].tool_name == "double" - assert result.results[0].result == 10 + out = result.results[0].result + assert out == 10 assert not result.results[0].is_error @@ -175,7 +177,8 @@ async def double(x: int) -> int: result = await tc(x=7) - assert result.results[0].result == 14 + out = result.results[0].result + assert out == 14 async def test_tool_call_override_validation_failure() -> None: diff --git a/tests/agents/ui/ai_sdk/test_inbound_messages.py b/tests/agents/ui/ai_sdk/test_inbound_messages.py index f2d2abbc..a5185309 100644 --- a/tests/agents/ui/ai_sdk/test_inbound_messages.py +++ b/tests/agents/ui/ai_sdk/test_inbound_messages.py @@ -209,7 +209,6 @@ def 
test_to_messages_passthrough_keeps_wire_shape() -> None: tool_msgs = [m for m in messages if m.role == "tool"] part = tool_msgs[0].tool_results[0] assert part.result == {"pong": True} - assert part.get_model_input() == {"pong": True} def test_to_messages_accepts_metadata_and_ui_only_parts() -> None: diff --git a/tests/providers/ai_gateway/test_protocol.py b/tests/providers/ai_gateway/test_protocol.py index 5c881bc2..64322c08 100644 --- a/tests/providers/ai_gateway/test_protocol.py +++ b/tests/providers/ai_gateway/test_protocol.py @@ -125,7 +125,7 @@ async def test_tool_error_result(self) -> None: tool_call_id="tc-1", tool_name="get_weather", result="Connection timeout", - is_error=True, + result_kind="error", ) ], ), @@ -389,3 +389,128 @@ def test_non_dict_returns_empty(self) -> None: usage = protocol._parse_usage("not a dict") assert usage.input_tokens == 0 assert usage.output_tokens == 0 + + +# --------------------------------------------------------------------------- +# Multi-part tool result helpers +# --------------------------------------------------------------------------- + + +class TestFilePartToV3Inline: + def test_image_data(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/png") + entry = protocol._file_part_to_v3_inline(fp) + assert entry == { + "type": "image-data", + "data": "b64data", + "mediaType": "image/png", + } + + def test_file_data_with_filename(self) -> None: + fp = messages.FilePart( + data="pdfdata", + media_type="application/pdf", + filename="doc.pdf", + ) + entry = protocol._file_part_to_v3_inline(fp) + assert entry["type"] == "file-data" + assert entry["mediaType"] == "application/pdf" + assert entry["filename"] == "doc.pdf" + + def test_bytes_become_base64(self) -> None: + fp = messages.FilePart(data=b"\x89PNG", media_type="image/png") + entry = protocol._file_part_to_v3_inline(fp) + assert entry["type"] == "image-data" + assert entry["data"] != "" + + +class TestToolResultOutput: + @staticmethod + def _part( 
+ result: object, *, result_kind: messages.ResultKind = "json" + ) -> messages.ToolResultPart: + return messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="t", + result=result, + result_kind=result_kind, + ) + + def test_text(self) -> None: + result = protocol._tool_result_output(self._part("hi")) + assert result == {"type": "text", "value": "hi"} + + def test_json(self) -> None: + result = protocol._tool_result_output(self._part({"key": "value"})) + assert result == {"type": "json", "value": {"key": "value"}} + + def test_error_text(self) -> None: + result = protocol._tool_result_output( + self._part("oops", result_kind="error") + ) + assert result == {"type": "error-text", "value": "oops"} + + def test_error_json(self) -> None: + result = protocol._tool_result_output( + self._part({"code": 500}, result_kind="error") + ) + assert result == {"type": "error-json", "value": {"code": 500}} + + def test_content_multipart(self) -> None: + fp = messages.FilePart(data="b64", media_type="image/jpeg") + result = protocol._tool_result_output( + self._part( + messages.ContentOutput( + value=[messages.TextPart(text="desc"), fp] + ), + result_kind="content", + ) + ) + assert result["type"] == "content" + assert result["value"][0] == {"type": "text", "text": "desc"} + assert result["value"][1]["type"] == "image-data" + + +class TestMessagesToPromptMultipart: + async def test_tool_result_with_file_part(self) -> None: + """ContentOutput with a FilePart uses the 'content' wire output.""" + fp = messages.FilePart(data="iVBOR", media_type="image/png") + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", + tool_name="read", + tool_args='{"path": "test.png"}', + ) + ], + ), + messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="read", + result=messages.ContentOutput( + value=[ + messages.TextPart(text="Image loaded"), + fp, + ] + ), + ) + ], + ), + ] + result = await 
protocol._messages_to_prompt(msgs) + tr = result[1]["content"][0] + assert tr["output"]["type"] == "content" + assert tr["output"]["value"][0] == { + "type": "text", + "text": "Image loaded", + } + assert tr["output"]["value"][1] == { + "type": "image-data", + "data": "iVBOR", + "mediaType": "image/png", + } diff --git a/tests/providers/anthropic/test_multipart_tool_result.py b/tests/providers/anthropic/test_multipart_tool_result.py new file mode 100644 index 00000000..e68fd455 --- /dev/null +++ b/tests/providers/anthropic/test_multipart_tool_result.py @@ -0,0 +1,112 @@ +"""Tests for multi-part tool results in the Anthropic protocol.""" + +from __future__ import annotations + +from ai.providers.anthropic import protocol +from ai.types import messages + + +class TestToolResultToAnthropic: + def test_str_value(self) -> None: + result = protocol._tool_result_to_anthropic("hello") + assert result == "hello" + + def test_none_value(self) -> None: + result = protocol._tool_result_to_anthropic(None) + assert result == "" + + def test_dict_value(self) -> None: + result = protocol._tool_result_to_anthropic({"key": "value"}) + assert result == '{"key":"value"}' + + def test_list_value(self) -> None: + result = protocol._tool_result_to_anthropic([1, 2, 3]) + assert result == "[1,2,3]" + + def test_error_str_value(self) -> None: + # An error result is just its (string) value; ``is_error`` rides the + # ``tool_result`` block's flag, set by the caller. 
+ result = protocol._tool_result_to_anthropic("boom") + assert result == "boom" + + def test_content_text_and_file(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/png") + result = protocol._tool_result_to_anthropic( + messages.ContentOutput( + value=[messages.TextPart(text="Image loaded"), fp] + ) + ) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] == {"type": "text", "text": "Image loaded"} + assert result[1]["type"] == "image" + assert result[1]["source"]["type"] == "base64" + assert result[1]["source"]["media_type"] == "image/png" + assert result[1]["source"]["data"] == "b64data" + + def test_content_file_only(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/jpeg") + result = protocol._tool_result_to_anthropic( + messages.ContentOutput(value=[fp]) + ) + assert isinstance(result, list) + assert result[0]["type"] == "image" + assert result[0]["source"]["media_type"] == "image/jpeg" + + def test_content_bytes_file(self) -> None: + fp = messages.FilePart(data=b"\x89PNG", media_type="image/png") + result = protocol._tool_result_to_anthropic( + messages.ContentOutput(value=[messages.TextPart(text="desc"), fp]) + ) + assert isinstance(result, list) + assert result[1]["type"] == "image" + assert result[1]["source"]["data"] != "" + + +class TestMessagesToAnthropicMultipart: + async def test_tool_result_with_file_part(self) -> None: + """FilePart in tool results produces image content blocks.""" + fp = messages.FilePart(data="iVBOR", media_type="image/png") + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Read image")], + ), + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", + tool_name="read", + tool_args='{"path": "test.png"}', + ) + ], + ), + messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="read", + result=messages.ContentOutput( + value=[ + 
messages.TextPart(text="Image loaded"), + fp, + ] + ), + result_kind="content", + ) + ], + ), + ] + _, result = await protocol._messages_to_anthropic(msgs) + tool_msg = result[-1] + assert tool_msg["role"] == "user" + tr = tool_msg["content"][0] + assert tr["type"] == "tool_result" + content = tr["content"] + assert isinstance(content, list) + assert content[0] == {"type": "text", "text": "Image loaded"} + assert content[1]["type"] == "image" + assert content[1]["source"]["type"] == "base64" + assert content[1]["source"]["media_type"] == "image/png" diff --git a/tests/providers/openai/test_multipart_tool_result.py b/tests/providers/openai/test_multipart_tool_result.py new file mode 100644 index 00000000..40f6687d --- /dev/null +++ b/tests/providers/openai/test_multipart_tool_result.py @@ -0,0 +1,115 @@ +"""Tests for multi-part tool results in the OpenAI protocol.""" + +from __future__ import annotations + +from ai.providers.openai import protocol +from ai.types import messages + + +class TestToolResultToOpenai: + def test_str_value(self) -> None: + result = protocol._tool_result_to_openai("hello") + assert result == "hello" + + def test_none_value(self) -> None: + result = protocol._tool_result_to_openai(None) + assert result == "" + + def test_dict_value(self) -> None: + result = protocol._tool_result_to_openai({"key": "value"}) + assert result == '{"key":"value"}' + + def test_list_value(self) -> None: + result = protocol._tool_result_to_openai([1, 2, 3]) + assert result == "[1,2,3]" + + def test_error_str_value(self) -> None: + result = protocol._tool_result_to_openai("boom") + assert result == "boom" + + def test_content_text_and_image(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/png") + result = protocol._tool_result_to_openai( + messages.ContentOutput( + value=[messages.TextPart(text="Image loaded"), fp] + ) + ) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] == {"type": "text", "text": "Image 
loaded"} + assert result[1]["type"] == "image_url" + assert result[1]["image_url"]["url"].startswith( + "data:image/png;base64," + ) + assert "b64data" in result[1]["image_url"]["url"] + + def test_content_image_only(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/jpeg") + result = protocol._tool_result_to_openai( + messages.ContentOutput(value=[fp]) + ) + assert isinstance(result, list) + assert result[0]["type"] == "image_url" + assert result[0]["image_url"]["url"].startswith( + "data:image/jpeg;base64," + ) + + def test_content_non_image_file(self) -> None: + fp = messages.FilePart(data="pdfdata", media_type="application/pdf") + result = protocol._tool_result_to_openai( + messages.ContentOutput(value=[messages.TextPart(text="desc"), fp]) + ) + assert isinstance(result, list) + assert result[1] == {"type": "text", "text": "[file: application/pdf]"} + + +class TestMessagesToOpenaiMultipart: + async def test_tool_result_with_file_part(self) -> None: + """ContentOutput with a FilePart produces image_url parts.""" + fp = messages.FilePart(data="iVBOR", media_type="image/png") + msgs = [ + messages.Message( + role="system", + parts=[messages.TextPart(text="System")], + ), + messages.Message( + role="user", + parts=[messages.TextPart(text="Read image")], + ), + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", + tool_name="read", + tool_args='{"path": "test.png"}', + ) + ], + ), + messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="read", + result=messages.ContentOutput( + value=[ + messages.TextPart(text="Image loaded"), + fp, + ] + ), + result_kind="content", + ) + ], + ), + ] + result = await protocol._messages_to_openai(msgs) + tool_msg = result[-1] + assert tool_msg["role"] == "tool" + content = tool_msg["content"] + assert isinstance(content, list) + assert content[0] == {"type": "text", "text": "Image loaded"} + assert content[1]["type"] 
== "image_url" + assert content[1]["image_url"]["url"].startswith( + "data:image/png;base64," + ) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index ba82e78a..a2576227 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -275,5 +275,6 @@ async def double(x: int) -> int: assert len(tool_result_msgs) >= 1 # The fixer middleware supplied x=99, so double should return 198. - assert tool_result_msgs[0].tool_results[0].result == 198 + tr = tool_result_msgs[0].tool_results[0].result + assert tr == 198 assert tool_result_msgs[0].tool_results[0].is_error is False diff --git a/tests/types/test_builders.py b/tests/types/test_builders.py index 2aae1275..30b76d1b 100644 --- a/tests/types/test_builders.py +++ b/tests/types/test_builders.py @@ -20,6 +20,25 @@ def test_user_message_mixed_content() -> None: assert isinstance(msg.parts[2], messages.TextPart) +def test_text_part() -> None: + tp = builders.text_part("hello", provider_metadata={"k": "v"}) + assert isinstance(tp, messages.TextPart) + assert tp.text == "hello" + assert tp.provider_metadata == {"k": "v"} + + +def test_content_output_coerces_strings() -> None: + fp = messages.FilePart( + data="https://example.com/img.png", media_type="image/png" + ) + out = builders.content_output("Here:", fp) + assert isinstance(out, messages.ContentOutput) + assert len(out.value) == 2 + assert isinstance(out.value[0], messages.TextPart) + assert out.value[0].text == "Here:" + assert isinstance(out.value[1], messages.FilePart) + + def test_file_part_from_url() -> None: fp = builders.file_part("https://example.com/image.png") assert isinstance(fp, messages.FilePart) diff --git a/tests/types/test_media.py b/tests/types/test_media.py index 39ef8f4f..56afab9d 100644 --- a/tests/types/test_media.py +++ b/tests/types/test_media.py @@ -73,3 +73,20 @@ def test_empty_or_short_media_returns_none() -> None: assert media.detect_audio_media_type(b"") is None assert media.detect_image_media_type(bytes([0x89])) is 
None assert media.detect_audio_media_type(bytes([0xFF])) is None + + +def test_data_to_base64_bytes_produces_standard() -> None: + """bytes input produces standard base-64 (+ and /).""" + data = b"\xff\xd8\xff\xe0" # JPEG header + result = media.data_to_base64(data) + decoded = base64.b64decode(result) + assert decoded == data + # Must be standard base-64, not URL-safe + assert "-" not in result + assert "_" not in result + + +def test_data_to_base64_str_passthrough() -> None: + """Standard base-64 string passes through unchanged.""" + standard = "/9j/4AAQSkZJRg==" + assert media.data_to_base64(standard) == standard diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index f684f884..03806b69 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + import pytest from ai.types import messages, usage @@ -112,3 +114,106 @@ def test_from_bytes_explicit_overrides() -> None: def test_from_bytes_unknown_raises() -> None: with pytest.raises(ValueError, match="Cannot detect media_type"): messages.FilePart.from_bytes(b"\x00\x01\x02\x03") + + +# --------------------------------------------------------------------------- +# ToolResultPart -- typed result coercion and round-trip +# --------------------------------------------------------------------------- + + +def test_tool_result_content_output_with_file_part_round_trip() -> None: + """FilePart inside ContentOutput survives JSON round-trip.""" + fp = messages.FilePart(data=b"fake-image-data", media_type="image/png") + trp = messages.ToolResultPart( + tool_call_id="tc1", + tool_name="read", + result=messages.ContentOutput( + value=[messages.TextPart(text="label"), fp] + ), + result_kind="content", + ) + j = trp.model_dump_json() + restored = messages.ToolResultPart.model_validate_json(j) + assert isinstance(restored.result, messages.ContentOutput) + assert len(restored.result.value) == 2 + text_part, file_part = 
restored.result.value + assert isinstance(text_part, messages.TextPart) + assert text_part.text == "label" + assert isinstance(file_part, messages.FilePart) + assert file_part.media_type == "image/png" + + +def test_tool_result_plain_values_stored_raw() -> None: + """Plain str / dict / list / None results are stored as-is and round-trip. + + ``result`` is ``Any`` -- there is no wrapper type, so a tool's return + value lands on the part unchanged and survives a JSON round-trip. + """ + cases: list[Any] = ["hello", None, [1, 2, 3], {"key": "val"}] + for raw in cases: + trp = messages.ToolResultPart( + tool_call_id="tc", tool_name="t", result=raw + ) + assert trp.result == raw + assert trp.result_kind == "json" + restored = messages.ToolResultPart.model_validate_json( + trp.model_dump_json() + ) + assert restored.result == raw + assert restored.result_kind == "json" + + +def test_tool_result_content_in_message_round_trip() -> None: + """ContentOutput with a FilePart survives Message round-trip.""" + fp = messages.FilePart(data=b"img-data", media_type="image/webp") + msg = messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc", + tool_name="read", + result=messages.ContentOutput( + value=[messages.TextPart(text="Read image"), fp] + ), + result_kind="content", + ) + ], + ) + j = msg.model_dump_json() + restored = messages.Message.model_validate_json(j) + part = restored.parts[0] + assert isinstance(part, messages.ToolResultPart) + assert isinstance(part.result, messages.ContentOutput) + fp2 = part.result.value[1] + assert isinstance(fp2, messages.FilePart) + assert fp2.media_type == "image/webp" + + +def test_tool_result_file_part_base64_valid_after_round_trip() -> None: + """After round-trip, data_to_base64 produces standard base-64.""" + import base64 + + from ai.types import media as media_ + + raw = b"\xff\xd8\xff\xe0\x00\x10JFIF" * 10 + fp = messages.FilePart(data=raw, media_type="image/jpeg") + trp = messages.ToolResultPart( + 
tool_call_id="tc", + tool_name="read", + result=messages.ContentOutput( + value=[messages.TextPart(text="label"), fp] + ), + result_kind="content", + ) + restored = messages.ToolResultPart.model_validate_json( + trp.model_dump_json() + ) + assert isinstance(restored.result, messages.ContentOutput) + fp2 = restored.result.value[1] + assert isinstance(fp2, messages.FilePart) + + b64 = media_.data_to_base64(fp2.data) + assert "_" not in b64 + assert "-" not in b64 + decoded = base64.b64decode(b64) + assert decoded == raw From 45b97948aa79553dc256067d6faea8fdf8d4ec57 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 28 May 2026 15:54:35 -0700 Subject: [PATCH 2/2] Reconstruct content tool results inbound via adapter metadata A content tool result's multipart payload already round-trips inside the UI tool part's `output`, but the signal to rehydrate it as a typed `ContentOutput` was lost on the way back in -- it landed as opaque JSON. Carry the signal through the existing `aiPython` adapter metadata channel (the same one used for source ids), keyed by tool_call_id: outbound records `toolResultKinds` for content results, inbound looks it up and rebuilds the ContentOutput so providers re-expand it into multimodal blocks. Only `content` is recorded -- `error` rides the UI `state` enum and `json` is the default. --- src/ai/agents/ui/ai_sdk/id_utils.py | 74 +++++++++++++++---- src/ai/agents/ui/ai_sdk/inbound_messages.py | 17 +++++ .../agents/ui/ai_sdk/test_inbound_messages.py | 55 +++++++++++++- 3 files changed, 130 insertions(+), 16 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/id_utils.py b/src/ai/agents/ui/ai_sdk/id_utils.py index 939d9bba..484d8a86 100644 --- a/src/ai/agents/ui/ai_sdk/id_utils.py +++ b/src/ai/agents/ui/ai_sdk/id_utils.py @@ -4,18 +4,25 @@ source message's ``id``, ``role``, ``turnId``, and ``partIds``. 
Outbound UI bubbles can collapse assistant/tool/internal messages into one UI message; inbound parsing uses this metadata to restore stable message and part ids. + +It also writes ``metadata["aiPython"]["toolResultKinds"]`` mapping a tool +call id to its ``result_kind`` for results the wire ``state`` can't convey +(``content``): the multipart payload already round-trips inside the UI tool +part's ``output``, but the signal to rehydrate it as a typed ``ContentOutput`` +would otherwise be lost. ``error`` rides the UI ``state`` enum and ``json`` +is the default, so only ``content`` is recorded. """ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Literal, cast +from typing import Literal, cast -if TYPE_CHECKING: - from ....types import messages as messages_ +from ....types import messages as messages_ ADAPTER_METADATA_KEY = "aiPython" SOURCE_MESSAGES_KEY = "sourceMessages" +TOOL_RESULT_KINDS_KEY = "toolResultKinds" MessageRole = Literal["user", "assistant", "system", "tool", "internal"] _VALID_ROLES = {"user", "assistant", "system", "tool", "internal"} @@ -77,23 +84,40 @@ def _restore_message_ids( return message.model_copy(update=updates) +def _tool_result_kinds( + source_messages: list[messages_.Message], +) -> dict[str, str]: + """Collect ``{tool_call_id: result_kind}`` for content tool results.""" + kinds: dict[str, str] = {} + for message in source_messages: + for part in message.parts: + if ( + isinstance(part, messages_.ToolResultPart) + and part.result_kind == "content" + ): + kinds[part.tool_call_id] = part.result_kind + return kinds + + def metadata_for( source_messages: list[messages_.Message], ) -> dict[str, object]: """Return adapter metadata for restoring collapsed source message ids.""" - return { - ADAPTER_METADATA_KEY: { - SOURCE_MESSAGES_KEY: [ - { - "id": message.id, - "role": message.role, - "turnId": message.turn_id, - "partIds": [part.id for part in message.parts], - } - for message in source_messages - ] - 
} + adapter: dict[str, object] = { + SOURCE_MESSAGES_KEY: [ + { + "id": message.id, + "role": message.role, + "turnId": message.turn_id, + "partIds": [part.id for part in message.parts], + } + for message in source_messages + ] } + tool_result_kinds = _tool_result_kinds(source_messages) + if tool_result_kinds: + adapter[TOOL_RESULT_KINDS_KEY] = tool_result_kinds + return {ADAPTER_METADATA_KEY: adapter} def source_messages_from(metadata: object) -> list[SourceMessage]: @@ -119,6 +143,26 @@ def source_messages_from(metadata: object) -> list[SourceMessage]: return result +def tool_result_kinds_from(metadata: object) -> dict[str, str]: + """Parse ``{tool_call_id: result_kind}``, ignoring malformed entries.""" + if not isinstance(metadata, dict): + return {} + metadata_dict = cast("dict[str, object]", metadata) + adapter_metadata = metadata_dict.get(ADAPTER_METADATA_KEY) + if not isinstance(adapter_metadata, dict): + return {} + adapter_metadata_dict = cast("dict[str, object]", adapter_metadata) + raw_kinds = adapter_metadata_dict.get(TOOL_RESULT_KINDS_KEY) + if not isinstance(raw_kinds, dict): + return {} + raw_kinds_dict = cast("dict[str, object]", raw_kinds) + return { + tool_call_id: kind + for tool_call_id, kind in raw_kinds_dict.items() + if isinstance(tool_call_id, str) and isinstance(kind, str) + } + + def restore_source_ids( messages: list[messages_.Message], source_messages: list[SourceMessage], diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index 0435b75f..373e8919 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -84,12 +84,24 @@ def _build_result_part( tool_name: str, output: Any, is_error: bool, + kind_hint: str | None = None, ) -> messages_.ToolResultPart: + """Reconstruct a tool result from its wire form. + + ``kind_hint`` comes from the adapter's ``toolResultKinds`` metadata + (see :mod:`id_utils`). 
When it marks the result as ``content``, the + ``output`` -- a list of dumped content parts -- is rehydrated into a + typed :class:`ContentOutput` so providers re-expand it into multimodal + blocks; otherwise behaviour matches a plain value round-trip. + """ result: Any result_kind: messages_.ResultKind if is_error: result = output result_kind = "error" + elif kind_hint == "content": + result = messages_.ContentOutput.model_validate({"value": output}) + result_kind = "content" else: decoded = _decode_wire_output(output) result = ( @@ -210,6 +222,7 @@ def _parse( for ui_msg in ui_messages: source_messages = id_utils.source_messages_from(ui_msg.metadata) + result_kinds = id_utils.tool_result_kinds_from(ui_msg.metadata) assistant_parts: list[messages_.Part] = [] tool_result_parts: list[messages_.ToolResultPart] = [] hook_parts: list[messages_.HookPart[Any]] = [] @@ -272,6 +285,9 @@ def _parse( tool_name=inv.tool_name, output=inv.result, is_error=is_error, + kind_hint=result_kinds.get( + inv.tool_invocation_id + ), ) ) @@ -335,6 +351,7 @@ def _parse( tool_name=tp.tool_name, output=_tool_result_output(tp), is_error=is_error, + kind_hint=result_kinds.get(tp.tool_call_id), ) ) if tp.result_provider_metadata is not None: diff --git a/tests/agents/ui/ai_sdk/test_inbound_messages.py b/tests/agents/ui/ai_sdk/test_inbound_messages.py index a5185309..a3342e05 100644 --- a/tests/agents/ui/ai_sdk/test_inbound_messages.py +++ b/tests/agents/ui/ai_sdk/test_inbound_messages.py @@ -5,7 +5,7 @@ import pytest from ai.agents.agent import MessageBundle -from ai.agents.ui.ai_sdk import to_messages +from ai.agents.ui.ai_sdk import to_messages, to_ui_messages from ai.agents.ui.ai_sdk.inbound_messages import _normalize_ui_messages from ai.agents.ui.ai_sdk.ui_messages import UIMessage, UIToolPart from ai.types import messages as messages_ @@ -189,6 +189,59 @@ def test_to_messages_decodes_subagent_tool_output() -> None: assert not result_part.has_model_input +def 
test_content_result_round_trips_via_metadata() -> None: + """A content tool result survives outbound -> inbound as ContentOutput. + + The multipart payload rides the UI tool part's ``output``; the + ``toolResultKinds`` adapter metadata carries the signal to rehydrate it. + """ + fp = messages_.FilePart(data=b"img-bytes", media_type="image/png") + turn = "turn-1" + internal = [ + messages_.Message( + id="a1", + turn_id=turn, + role="assistant", + parts=[ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="read", + tool_args="{}", + ) + ], + ), + messages_.Message( + id="t1", + turn_id=turn, + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="read", + result=messages_.ContentOutput( + value=[messages_.TextPart(text="desc"), fp] + ), + result_kind="content", + ) + ], + ), + ] + + ui = to_ui_messages(internal) + restored, _ = to_messages(ui) + + tool_msgs = [m for m in restored if m.role == "tool"] + assert len(tool_msgs) == 1 + part = tool_msgs[0].tool_results[0] + assert part.result_kind == "content" + assert isinstance(part.result, messages_.ContentOutput) + text_part, file_part = part.result.value + assert isinstance(text_part, messages_.TextPart) + assert text_part.text == "desc" + assert isinstance(file_part, messages_.FilePart) + assert file_part.media_type == "image/png" + + def test_to_messages_passthrough_keeps_wire_shape() -> None: """Non-UIMessage tool outputs stay in their wire form.""" ui = [