Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions examples/read_file_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tool that returns a ContentOutput so the model can see image files directly.

The ``read_file`` tool reads a path from disk and inspects the bytes:

* If the file is an image, it returns a :class:`ContentOutput` carrying
a summary line and an image :class:`FilePart`. All three providers
turn that into a real image content block on the next model turn, so
the model actually *sees* the picture.
* Otherwise it returns the decoded text -- a plain ``str`` result the
provider sends to the model verbatim.

A single tool covers both code-reading and image-reading duties in an
agentic loop.
"""

import asyncio
import json
import pathlib

import ai
from ai.types import media

# Restrict the tool to a directory we trust the model to roam in.
# `.resolve()` makes the root absolute and collapses any symlinks in it, so
# it can be compared against fully resolved request paths. Requested paths
# are resolved the same way before the containment check, which is what
# stops a symlink inside ALLOWED_ROOT from escaping to a target elsewhere.
ALLOWED_ROOT = pathlib.Path(__file__).parent.resolve()


def _resolve_within_allowed(path: str) -> pathlib.Path:
    """Resolve ``path`` and confirm it lies inside ``ALLOWED_ROOT``.

    The path is fully resolved first (made absolute, symlinks followed;
    relative paths are taken relative to the current working directory),
    and only then compared against the allowed root.

    Returns:
        The resolved path.

    Raises:
        ValueError: If the resolved path is not located under
            ``ALLOWED_ROOT``.
    """
    candidate = pathlib.Path(path).resolve()
    if candidate.is_relative_to(ALLOWED_ROOT):
        return candidate
    raise ValueError(
        f"Refusing to read {path!r}: outside allowed root {ALLOWED_ROOT}"
    )


@ai.tool
async def read_file(path: str) -> str | ai.messages.ContentOutput:
    """Read a file from disk.

    Image files come back as a ContentOutput so the model can view them.

    Args:
        path: Location of the file to read. It must resolve to somewhere
            inside ``ALLOWED_ROOT``.

    Returns:
        For an image, a ``ContentOutput`` made of a one-line summary and a
        file part holding the raw bytes. For anything else, the file's
        contents decoded as UTF-8, with undecodable bytes replaced.

    Raises:
        ValueError: If ``path`` resolves outside ``ALLOWED_ROOT``.
        OSError: If the file cannot be read.
    """
    resolved = _resolve_within_allowed(path)
    # The read is blocking; run it in a worker thread so a large or slow
    # file does not stall the event loop that drives the agent stream.
    data = await asyncio.to_thread(resolved.read_bytes)
    image_type = media.detect_image_media_type(data)
    if image_type is not None:
        return ai.content_output(
            f"Loaded {path} ({image_type}, {len(data)} bytes).",
            ai.file_part(data, media_type=image_type),
        )
    # Not recognised as an image: hand the contents back as text.
    return data.decode("utf-8", errors="replace")


async def main() -> None:
    """Run a demo conversation that exercises ``read_file``.

    Asks the model to read one image and one text file that sit next to
    this script, and streams the run to stdout: model text as it arrives,
    plus one bracketed line per finished tool call.
    """
    model = ai.get_model("gateway:anthropic/claude-sonnet-4.6")
    my_agent = ai.agent(tools=[read_file])

    # Both sample files live in this script's directory.
    here = pathlib.Path(__file__).parent
    image_path = here / "sample_image.jpg"
    text_path = here / "agent_simple.py"

    messages = [
        ai.system_message(
            "Use the read_file tool to inspect any files the user mentions."
        ),
        ai.user_message(
            f"First read {image_path} and describe what you see in the "
            f"picture. Then read {text_path} and summarize what the "
            f"script does in one sentence."
        ),
    ]

    async with my_agent.run(model, messages) as stream:
        async for event in stream:
            if isinstance(event, ai.events.TextDelta):
                # Print text chunks back to back so they read as one
                # continuous reply.
                print(event.chunk, end="", flush=True)
            elif isinstance(event, ai.events.ToolEnd):
                # tool_args is parsed as JSON text; a falsy value is
                # treated as an empty argument object.
                args = json.loads(event.tool_call.tool_args or "{}")
                print(f"\n[read_file({args.get('path')!r})]")
            elif isinstance(event, ai.events.StreamEnd):
                # Terminate the line left open by the streamed text.
                print()
    print()


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
4 changes: 4 additions & 0 deletions src/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@
from .types import events, messages, tools
from .types.builders import (
assistant_message,
content_output,
file_part,
system_message,
text_part,
thinking,
tool_message,
tool_result_part,
Expand Down Expand Up @@ -176,6 +178,7 @@
"agent",
"assistant_message",
"cancel_hook",
"content_output",
"errors",
"events",
"file_part",
Expand All @@ -192,6 +195,7 @@
"resolve_hook",
"stream",
"system_message",
"text_part",
"thinking",
"tool",
"tool_message",
Expand Down
18 changes: 16 additions & 2 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Any,
ClassVar,
Generic,
Literal,
Protocol,
Self,
cast,
Expand Down Expand Up @@ -58,6 +59,18 @@ def _unwrap_singleton_group(exc: BaseException) -> BaseException:
return exc


def _result_kind(value: Any) -> Literal["json", "content"]:
    """Pick the ``ToolResultPart.result_kind`` tag for a successful result.

    Returns ``"content"`` when ``value`` is a :class:`ContentOutput`
    (expanded into provider multimodal blocks) and ``"json"`` for any
    other value (the encoder sends a ``str`` raw and JSON-encodes
    anything else).
    """
    is_content = isinstance(value, types.messages.ContentOutput)
    return "content" if is_content else "json"


def _error_tool_result(
exc: BaseException,
*,
Expand All @@ -75,7 +88,7 @@ def _error_tool_result(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=f"{type(unwrapped).__name__}: {unwrapped}",
is_error=True,
result_kind="error",
),
exception=unwrapped,
)
Expand Down Expand Up @@ -616,6 +629,7 @@ async def _real(
tool_call_id=call.tool_call_id,
tool_name=call.tool_name,
result=result,
result_kind=_result_kind(result),
)
part.set_model_input(model_input)
return tool_result(part)
Expand Down Expand Up @@ -1041,7 +1055,7 @@ def pending_tool_result(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=f"Pending on hook {hook.hook_id!r}",
is_error=True,
result_kind="error",
is_hook_pending=True,
)
msg = types.messages.Message(role="tool", parts=[part])
Expand Down
74 changes: 59 additions & 15 deletions src/ai/agents/ui/ai_sdk/id_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,25 @@
source message's ``id``, ``role``, ``turnId``, and ``partIds``. Outbound UI
bubbles can collapse assistant/tool/internal messages into one UI message;
inbound parsing uses this metadata to restore stable message and part ids.

It also writes ``metadata["aiPython"]["toolResultKinds"]`` mapping a tool
call id to its ``result_kind`` for results the wire ``state`` can't convey
(``content``): the multipart payload already round-trips inside the UI tool
part's ``output``, but the signal to rehydrate it as a typed ``ContentOutput``
would otherwise be lost. ``error`` rides the UI ``state`` enum and ``json``
is the default, so only ``content`` is recorded.
"""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Literal, cast
from typing import Literal, cast

if TYPE_CHECKING:
from ....types import messages as messages_
from ....types import messages as messages_

ADAPTER_METADATA_KEY = "aiPython"
SOURCE_MESSAGES_KEY = "sourceMessages"
TOOL_RESULT_KINDS_KEY = "toolResultKinds"

MessageRole = Literal["user", "assistant", "system", "tool", "internal"]
_VALID_ROLES = {"user", "assistant", "system", "tool", "internal"}
Expand Down Expand Up @@ -77,23 +84,40 @@ def _restore_message_ids(
return message.model_copy(update=updates)


def _tool_result_kinds(
    source_messages: list[messages_.Message],
) -> dict[str, str]:
    """Map tool call ids to ``result_kind`` for content tool results.

    Every part of every message is scanned; only ``ToolResultPart``
    instances whose ``result_kind`` is ``"content"`` are recorded. If a
    tool call id occurs more than once, the last occurrence wins.
    """
    return {
        part.tool_call_id: part.result_kind
        for message in source_messages
        for part in message.parts
        if isinstance(part, messages_.ToolResultPart)
        and part.result_kind == "content"
    }


def metadata_for(
    source_messages: list[messages_.Message],
) -> dict[str, object]:
    """Return adapter metadata for restoring collapsed source message ids.

    The result has a single top-level ``ADAPTER_METADATA_KEY`` entry that
    holds:

    * ``SOURCE_MESSAGES_KEY``: one record per source message with its
      ``id``, ``role``, ``turnId`` and the ids of its parts, in order.
    * ``TOOL_RESULT_KINDS_KEY``: a ``{tool_call_id: result_kind}`` mapping
      for content tool results. The key is left out entirely when there
      are none.
    """
    adapter: dict[str, object] = {
        SOURCE_MESSAGES_KEY: [
            {
                "id": message.id,
                "role": message.role,
                "turnId": message.turn_id,
                "partIds": [part.id for part in message.parts],
            }
            for message in source_messages
        ]
    }
    # Add the kinds map only when it is non-empty, so messages without
    # content tool results produce the same metadata shape as before.
    tool_result_kinds = _tool_result_kinds(source_messages)
    if tool_result_kinds:
        adapter[TOOL_RESULT_KINDS_KEY] = tool_result_kinds
    return {ADAPTER_METADATA_KEY: adapter}


def source_messages_from(metadata: object) -> list[SourceMessage]:
Expand All @@ -119,6 +143,26 @@ def source_messages_from(metadata: object) -> list[SourceMessage]:
return result


def tool_result_kinds_from(metadata: object) -> dict[str, str]:
    """Extract the ``{tool_call_id: result_kind}`` map from UI metadata.

    The lookup walks ``metadata[ADAPTER_METADATA_KEY]`` and then
    ``[TOOL_RESULT_KINDS_KEY]``. If any level is missing or is not a
    ``dict``, an empty mapping is returned. Entries whose key or value is
    not a ``str`` are dropped rather than raising.
    """
    if not isinstance(metadata, dict):
        return {}
    adapter = cast("dict[str, object]", metadata).get(ADAPTER_METADATA_KEY)
    if not isinstance(adapter, dict):
        return {}
    kinds = cast("dict[str, object]", adapter).get(TOOL_RESULT_KINDS_KEY)
    if not isinstance(kinds, dict):
        return {}

    parsed: dict[str, str] = {}
    for call_id, kind in cast("dict[str, object]", kinds).items():
        # Skip malformed entries instead of failing the whole parse.
        if isinstance(call_id, str) and isinstance(kind, str):
            parsed[call_id] = kind
    return parsed


def restore_source_ids(
messages: list[messages_.Message],
source_messages: list[SourceMessage],
Expand Down
27 changes: 24 additions & 3 deletions src/ai/agents/ui/ai_sdk/inbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,37 @@ def _build_result_part(
tool_name: str,
output: Any,
is_error: bool,
kind_hint: str | None = None,
) -> messages_.ToolResultPart:
"""Reconstruct a tool result from its wire form.

``kind_hint`` comes from the adapter's ``toolResultKinds`` metadata
(see :mod:`id_utils`). When it marks the result as ``content``, the
``output`` -- a list of dumped content parts -- is rehydrated into a
typed :class:`ContentOutput` so providers re-expand it into multimodal
blocks; otherwise behaviour matches a plain value round-trip.
"""
result: Any
result_kind: messages_.ResultKind
if is_error:
result: Any = output
result = output
result_kind = "error"
elif kind_hint == "content":
result = messages_.ContentOutput.model_validate({"value": output})
result_kind = "content"
else:
decoded = _decode_wire_output(output)
result = (
decoded
if isinstance(decoded, MessageBundle)
else _normalize_tool_result(decoded)
)
result_kind = "json"
return messages_.ToolResultPart(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=result,
is_error=is_error,
result_kind=result_kind,
)


Expand Down Expand Up @@ -190,7 +206,7 @@ def _patch_pending_hook_aborts(
tool_call_id=tc.tool_call_id,
tool_name=tc.tool_name,
result=f"Pending on hook '{hook.hook_id}'",
is_error=True,
result_kind="error",
is_hook_pending=True,
)
)
Expand All @@ -206,6 +222,7 @@ def _parse(

for ui_msg in ui_messages:
source_messages = id_utils.source_messages_from(ui_msg.metadata)
result_kinds = id_utils.tool_result_kinds_from(ui_msg.metadata)
assistant_parts: list[messages_.Part] = []
tool_result_parts: list[messages_.ToolResultPart] = []
hook_parts: list[messages_.HookPart[Any]] = []
Expand Down Expand Up @@ -268,6 +285,9 @@ def _parse(
tool_name=inv.tool_name,
output=inv.result,
is_error=is_error,
kind_hint=result_kinds.get(
inv.tool_invocation_id
),
)
)

Expand Down Expand Up @@ -331,6 +351,7 @@ def _parse(
tool_name=tp.tool_name,
output=_tool_result_output(tp),
is_error=is_error,
kind_hint=result_kinds.get(tp.tool_call_id),
)
)
if tp.result_provider_metadata is not None:
Expand Down
23 changes: 18 additions & 5 deletions src/ai/agents/ui/ai_sdk/outbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
from typing import Any, cast

from ....types import media
Expand Down Expand Up @@ -107,6 +108,21 @@ def dedupe_tool_parts(
return result


def _output_view(
    part: messages_.ToolResultPart,
) -> tuple[str, dict[str, Any]]:
    """Map a tool result to ``(state, field_updates)`` for the UI wire.

    * A ``ContentOutput`` result becomes ``"output-available"`` with
      ``output`` set to the JSON-mode dump of each item in its ``value``.
    * An error result becomes ``"output-error"`` with ``error_text`` set
      to the result itself when it is a ``str``, otherwise to its JSON
      encoding (or ``str(result)`` if it cannot be JSON-encoded).
    * Anything else becomes ``"output-available"`` with ``output`` set to
      the result unchanged.
    """
    result = part.result
    # ContentOutput is checked first, so it takes this branch regardless
    # of the part's error flag.
    if isinstance(result, messages_.ContentOutput):
        return "output-available", {
            "output": [item.model_dump(mode="json") for item in result.value]
        }
    if part.is_error:
        if isinstance(result, str):
            text = result
        else:
            try:
                text = json.dumps(result)
            except (TypeError, ValueError):
                # Rendering an error must not raise on a payload that is
                # not JSON-serializable; fall back to its string form.
                text = str(result)
        return "output-error", {"error_text": text}
    return "output-available", {"output": result}


def merge_tool_results(
ui_parts: list[ui_messages.UIMessagePart],
tool_parts: list[messages_.Part],
Expand All @@ -121,15 +137,12 @@ def merge_tool_results(
continue
case messages_.ToolResultPart():
tool_call_id = part.tool_call_id
state = "output-error" if part.is_error else "output-available"
state, field_updates = _output_view(part)
updates = {
"state": state,
"result_provider_metadata": part.provider_metadata,
**field_updates,
}
if part.is_error:
updates["error_text"] = str(part.result)
else:
updates["output"] = part.result
case messages_.BuiltinToolReturnPart():
tool_call_id = part.tool_call_id
updates = {
Expand Down
Loading
Loading