Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/ai/providers/ai_gateway/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,22 @@ async def _messages_to_prompt(
assistant_content: list[dict[str, Any]] = []
for part in msg.parts:
match part:
case types.messages.ReasoningPart(text=text):
assistant_content.append(
{"type": "reasoning", "text": text}
)
case types.messages.ReasoningPart(
text=text, provider_metadata=pm
):
reasoning_entry: dict[str, Any] = {
"type": "reasoning",
"text": text,
}
# Replay the provider's reasoning metadata (e.g.
# the thinking-block signature) verbatim. Without
# it the provider drops the block and the model
# loses access to its prior reasoning. v3 mirrors
# inbound ``providerMetadata`` to outbound
# ``providerOptions``.
if pm:
reasoning_entry["providerOptions"] = pm
assistant_content.append(reasoning_entry)
case types.messages.TextPart(text=text):
assistant_content.append(
{"type": "text", "text": text}
Expand Down Expand Up @@ -382,6 +394,8 @@ def _parse_stream_part(
return [types.events.TextEnd(block_id=data.get("id", "text"))]

case "reasoning-start":
# Metadata on -start is gateway routing info (generationId),
# not the provider's reasoning metadata; don't replay it.
return [
types.events.ReasoningStart(
block_id=data.get("id", "reasoning")
Expand All @@ -393,12 +407,16 @@ def _parse_stream_part(
types.events.ReasoningDelta(
block_id=data.get("id", "reasoning"),
chunk=data.get("delta", ""),
provider_metadata=data.get("providerMetadata"),
)
]

case "reasoning-end":
return [
types.events.ReasoningEnd(block_id=data.get("id", "reasoning"))
types.events.ReasoningEnd(
block_id=data.get("id", "reasoning"),
provider_metadata=data.get("providerMetadata"),
)
]

case "tool-input-start":
Expand Down
9 changes: 8 additions & 1 deletion src/ai/types/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,14 @@ async def replay_message_events(
yield ReasoningStart(block_id=part.id)
if part.text:
yield ReasoningDelta(block_id=part.id, chunk=part.text)
yield ReasoningEnd(block_id=part.id)
# Carry the signature (and any other reasoning metadata) on the
# end event, mirroring how the real adapters emit it -- otherwise
# a replayed-then-rebuilt turn loses its signature and can't be
# replayed to the provider.
yield ReasoningEnd(
block_id=part.id,
provider_metadata=part.provider_metadata,
)
elif isinstance(part, messages.ToolCallPart):
yield ToolStart(
tool_call_id=part.tool_call_id,
Expand Down
138 changes: 138 additions & 0 deletions tests/providers/ai_gateway/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
from __future__ import annotations

import json
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, patch

import pydantic

from ai import models
from ai.providers.ai_gateway import protocol
from ai.types import events as events_
from ai.types import messages
Expand Down Expand Up @@ -67,6 +70,43 @@ async def test_assistant_with_reasoning_and_text(self) -> None:
assert content[0] == {"type": "reasoning", "text": "Let me think..."}
assert content[1] == {"type": "text", "text": "42"}

async def test_assistant_reasoning_replays_signature(self) -> None:
"""A reasoning part's metadata (the thinking-block signature) must
be replayed verbatim as ``providerOptions`` so the upstream can
verify its own thinking."""
msgs = [
messages.Message(
role="assistant",
parts=[
messages.ReasoningPart(
text="Let me think...",
provider_metadata={
"anthropic": {"signature": "ErMJabc123"}
},
),
],
)
]
result = await protocol._messages_to_prompt(msgs)
assert result[0]["content"][0] == {
"type": "reasoning",
"text": "Let me think...",
"providerOptions": {"anthropic": {"signature": "ErMJabc123"}},
}

async def test_assistant_reasoning_without_signature_omits_options(
self,
) -> None:
"""No signature -> no ``providerOptions`` key (back-compat)."""
msgs = [
messages.Message(
role="assistant",
parts=[messages.ReasoningPart(text="hmm")],
)
]
result = await protocol._messages_to_prompt(msgs)
assert result[0]["content"][0] == {"type": "reasoning", "text": "hmm"}

async def test_tool_call_with_result_produces_two_messages(self) -> None:
"""A completed tool call must produce an assistant message
(with the tool-call) AND a tool message (with the result)."""
Expand Down Expand Up @@ -323,6 +363,48 @@ def test_finish_v3_nested_usage(self) -> None:
assert done.usage.cache_read_tokens == 50
assert done.usage.reasoning_tokens == 30

def test_reasoning_delta_carries_provider_metadata(self) -> None:
"""A reasoning-delta's ``providerMetadata`` (the thinking-block
signature) rides through verbatim on ``provider_metadata``."""
events = protocol._parse_stream_part(
{
"type": "reasoning-delta",
"id": "0",
"delta": "",
"providerMetadata": {"anthropic": {"signature": "ErMJabc123"}},
},
set(),
)
assert len(events) == 1
delta = events[0]
assert isinstance(delta, events_.ReasoningDelta)
assert delta.provider_metadata == {
"anthropic": {"signature": "ErMJabc123"}
}

def test_reasoning_delta_without_metadata(self) -> None:
"""A plain reasoning-delta carries no provider_metadata."""
events = protocol._parse_stream_part(
{"type": "reasoning-delta", "id": "0", "delta": "thinking"},
set(),
)
assert isinstance(events[0], events_.ReasoningDelta)
assert events[0].provider_metadata is None

def test_reasoning_start_drops_routing_metadata(self) -> None:
"""Metadata on -start is gateway routing info (generationId), not
provider reasoning metadata, and must not be replayed."""
events = protocol._parse_stream_part(
{
"type": "reasoning-start",
"id": "0",
"providerMetadata": {"gateway": {"generationId": "gen_1"}},
},
set(),
)
assert isinstance(events[0], events_.ReasoningStart)
assert events[0].provider_metadata is None

def test_file_part(self) -> None:
"""A ``file`` stream part (inline image from Gemini/GPT-5)
must produce a FileEvent."""
Expand Down Expand Up @@ -389,3 +471,59 @@ def test_non_dict_returns_empty(self) -> None:
usage = protocol._parse_usage("not a dict")
assert usage.input_tokens == 0
assert usage.output_tokens == 0


# ---------------------------------------------------------------------------
# Thinking-block round trip (signature survives in -> aggregate -> out)
# ---------------------------------------------------------------------------


class TestReasoningSignatureRoundTrip:
"""The whole point of capturing the signature: it must survive being
parsed from the wire, aggregated into a Message, and re-serialized so
the upstream sees its own thinking on the next turn."""

async def test_signature_survives_round_trip(self) -> None:
# Wire parts as the gateway emits them: the signature rides on the
# final (empty) reasoning-delta, not the start or end.
wire_parts: list[dict[str, Any]] = [
{"type": "reasoning-start", "id": "0"},
{"type": "reasoning-delta", "id": "0", "delta": "thinking hard"},
{
"type": "reasoning-delta",
"id": "0",
"delta": "",
"providerMetadata": {"anthropic": {"signature": "ErMJsig=="}},
},
{"type": "reasoning-end", "id": "0"},
]

async def _gen() -> AsyncGenerator[events_.Event]:
for part in wire_parts:
for event in protocol._parse_stream_part(part, set()):
yield event

stream = models.Stream(_gen())
async for _ in stream:
pass

# Aggregated message: one reasoning part carrying the signature.
reasoning = [
p
for p in stream.message.parts
if isinstance(p, messages.ReasoningPart)
]
assert len(reasoning) == 1
assert reasoning[0].text == "thinking hard"
assert reasoning[0].provider_metadata == {
"anthropic": {"signature": "ErMJsig=="}
}

# Round-trip back out: the metadata is replayed verbatim to the
# provider as providerOptions.
out = await protocol._messages_to_prompt(
[messages.Message(role="assistant", parts=stream.message.parts)]
)
assert out[0]["content"][0]["providerOptions"] == {
"anthropic": {"signature": "ErMJsig=="}
}
63 changes: 63 additions & 0 deletions tests/types/test_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for ``ai.types.events``."""

from __future__ import annotations

from ai import models
from ai.types import events, messages


class TestReplayMessageEvents:
async def test_reasoning_signature_survives_replay(self) -> None:
"""A signed reasoning part replayed through the Stream aggregator
must keep its provider_metadata -- otherwise a rebuilt turn can't
be replayed to the provider."""
original = messages.Message(
role="assistant",
parts=[
messages.ReasoningPart(
text="thinking hard",
provider_metadata={"anthropic": {"signature": "ErMJsig=="}},
),
messages.TextPart(text="the answer is 42"),
],
)

async with models.Stream(
events.replay_message_events(original)
) as stream:
async for _ in stream:
pass

reasoning = [
p
for p in stream.message.parts
if isinstance(p, messages.ReasoningPart)
]
assert len(reasoning) == 1
assert reasoning[0].text == "thinking hard"
assert reasoning[0].provider_metadata == {
"anthropic": {"signature": "ErMJsig=="}
}

async def test_reasoning_signature_on_end_event(self) -> None:
"""The signature rides on the ReasoningEnd event, mirroring how the
real streaming adapters emit it."""
msg = messages.Message(
role="assistant",
parts=[
messages.ReasoningPart(
text="hmm",
provider_metadata={"anthropic": {"signature": "sig"}},
)
],
)

reasoning_ends = [
e
async for e in events.replay_message_events(msg)
if isinstance(e, events.ReasoningEnd)
]
assert len(reasoning_ends) == 1
assert reasoning_ends[0].provider_metadata == {
"anthropic": {"signature": "sig"}
}
Loading