diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index a19e8a8d..628810a1 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -116,10 +116,22 @@ async def _messages_to_prompt( assistant_content: list[dict[str, Any]] = [] for part in msg.parts: match part: - case types.messages.ReasoningPart(text=text): - assistant_content.append( - {"type": "reasoning", "text": text} - ) + case types.messages.ReasoningPart( + text=text, provider_metadata=pm + ): + reasoning_entry: dict[str, Any] = { + "type": "reasoning", + "text": text, + } + # Replay the provider's reasoning metadata (e.g. + # the thinking-block signature) verbatim. Without + # it the provider drops the block and the model + # loses access to its prior reasoning. v3 mirrors + # inbound ``providerMetadata`` to outbound + # ``providerOptions``. + if pm: + reasoning_entry["providerOptions"] = pm + assistant_content.append(reasoning_entry) case types.messages.TextPart(text=text): assistant_content.append( {"type": "text", "text": text} @@ -382,6 +394,8 @@ def _parse_stream_part( return [types.events.TextEnd(block_id=data.get("id", "text"))] case "reasoning-start": + # Metadata on -start is gateway routing info (generationId), + # not the provider's reasoning metadata; don't replay it. return [ types.events.ReasoningStart( block_id=data.get("id", "reasoning") @@ -393,12 +407,16 @@ def _parse_stream_part( types.events.ReasoningDelta( block_id=data.get("id", "reasoning"), chunk=data.get("delta", ""), + provider_metadata=data.get("providerMetadata"), ) ] case "reasoning-end": return [ - types.events.ReasoningEnd(block_id=data.get("id", "reasoning")) + types.events.ReasoningEnd( + block_id=data.get("id", "reasoning"), + provider_metadata=data.get("providerMetadata"), + ) ] case "tool-input-start": diff --git a/src/ai/types/events.py b/src/ai/types/events.py index 14aa220c..15ee16b6 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -220,7 +220,14 @@ async def replay_message_events( yield ReasoningStart(block_id=part.id) if part.text: yield ReasoningDelta(block_id=part.id, chunk=part.text) - yield ReasoningEnd(block_id=part.id) + # Carry the signature (and any other reasoning metadata) on the + # end event, mirroring how the real adapters emit it -- otherwise + # a replayed-then-rebuilt turn loses its signature and can't be + # replayed to the provider. + yield ReasoningEnd( + block_id=part.id, + provider_metadata=part.provider_metadata, + ) elif isinstance(part, messages.ToolCallPart): yield ToolStart( tool_call_id=part.tool_call_id, diff --git a/tests/providers/ai_gateway/test_protocol.py b/tests/providers/ai_gateway/test_protocol.py index 5c881bc2..4ca0ccd4 100644 --- a/tests/providers/ai_gateway/test_protocol.py +++ b/tests/providers/ai_gateway/test_protocol.py @@ -13,10 +13,13 @@ from __future__ import annotations import json +from collections.abc import AsyncGenerator +from typing import Any from unittest.mock import AsyncMock, patch import pydantic +from ai import models from ai.providers.ai_gateway import protocol from ai.types import events as events_ from ai.types import messages @@ -67,6 +70,43 @@ async def test_assistant_with_reasoning_and_text(self) -> None: assert content[0] == {"type": "reasoning", "text": "Let me think..."} assert content[1] == {"type": "text", "text": "42"} + async def test_assistant_reasoning_replays_signature(self) -> None: + """A reasoning part's metadata (the thinking-block signature) must + be replayed verbatim as ``providerOptions`` so the upstream can + verify its own thinking.""" + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ReasoningPart( + text="Let me think...", + provider_metadata={ + "anthropic": {"signature": "ErMJabc123"} + }, + ), + ], + ) + ] + result = await protocol._messages_to_prompt(msgs) + assert result[0]["content"][0] == { + "type": "reasoning", + "text": "Let me think...", + "providerOptions": {"anthropic": {"signature": "ErMJabc123"}}, + } + + async def test_assistant_reasoning_without_signature_omits_options( + self, + ) -> None: + """No signature -> no ``providerOptions`` key (back-compat).""" + msgs = [ + messages.Message( + role="assistant", + parts=[messages.ReasoningPart(text="hmm")], + ) + ] + result = await protocol._messages_to_prompt(msgs) + assert result[0]["content"][0] == {"type": "reasoning", "text": "hmm"} + async def test_tool_call_with_result_produces_two_messages(self) -> None: """A completed tool call must produce an assistant message (with the tool-call) AND a tool message (with the result).""" @@ -323,6 +363,48 @@ def test_finish_v3_nested_usage(self) -> None: assert done.usage.cache_read_tokens == 50 assert done.usage.reasoning_tokens == 30 + def test_reasoning_delta_carries_provider_metadata(self) -> None: + """A reasoning-delta's ``providerMetadata`` (the thinking-block + signature) rides through verbatim on ``provider_metadata``.""" + events = protocol._parse_stream_part( + { + "type": "reasoning-delta", + "id": "0", + "delta": "", + "providerMetadata": {"anthropic": {"signature": "ErMJabc123"}}, + }, + set(), + ) + assert len(events) == 1 + delta = events[0] + assert isinstance(delta, events_.ReasoningDelta) + assert delta.provider_metadata == { + "anthropic": {"signature": "ErMJabc123"} + } + + def test_reasoning_delta_without_metadata(self) -> None: + """A plain reasoning-delta carries no provider_metadata.""" + events = protocol._parse_stream_part( + {"type": "reasoning-delta", "id": "0", "delta": "thinking"}, + set(), + ) + assert isinstance(events[0], events_.ReasoningDelta) + assert events[0].provider_metadata is None + + def test_reasoning_start_drops_routing_metadata(self) -> None: + """Metadata on -start is gateway routing info (generationId), not + provider reasoning metadata, and must not be replayed.""" + events = protocol._parse_stream_part( + { + "type": "reasoning-start", + "id": "0", + "providerMetadata": {"gateway": {"generationId": "gen_1"}}, + }, + set(), + ) + assert isinstance(events[0], events_.ReasoningStart) + assert events[0].provider_metadata is None + def test_file_part(self) -> None: """A ``file`` stream part (inline image from Gemini/GPT-5) must produce a FileEvent.""" @@ -389,3 +471,59 @@ def test_non_dict_returns_empty(self) -> None: usage = protocol._parse_usage("not a dict") assert usage.input_tokens == 0 assert usage.output_tokens == 0 + + +# --------------------------------------------------------------------------- +# Thinking-block round trip (signature survives in -> aggregate -> out) +# --------------------------------------------------------------------------- + + +class TestReasoningSignatureRoundTrip: + """The whole point of capturing the signature: it must survive being + parsed from the wire, aggregated into a Message, and re-serialized so + the upstream sees its own thinking on the next turn.""" + + async def test_signature_survives_round_trip(self) -> None: + # Wire parts as the gateway emits them: the signature rides on the + # final (empty) reasoning-delta, not the start or end. + wire_parts: list[dict[str, Any]] = [ + {"type": "reasoning-start", "id": "0"}, + {"type": "reasoning-delta", "id": "0", "delta": "thinking hard"}, + { + "type": "reasoning-delta", + "id": "0", + "delta": "", + "providerMetadata": {"anthropic": {"signature": "ErMJsig=="}}, + }, + {"type": "reasoning-end", "id": "0"}, + ] + + async def _gen() -> AsyncGenerator[events_.Event]: + for part in wire_parts: + for event in protocol._parse_stream_part(part, set()): + yield event + + stream = models.Stream(_gen()) + async for _ in stream: + pass + + # Aggregated message: one reasoning part carrying the signature. + reasoning = [ + p + for p in stream.message.parts + if isinstance(p, messages.ReasoningPart) + ] + assert len(reasoning) == 1 + assert reasoning[0].text == "thinking hard" + assert reasoning[0].provider_metadata == { + "anthropic": {"signature": "ErMJsig=="} + } + + # Round-trip back out: the metadata is replayed verbatim to the + # provider as providerOptions. + out = await protocol._messages_to_prompt( + [messages.Message(role="assistant", parts=stream.message.parts)] + ) + assert out[0]["content"][0]["providerOptions"] == { + "anthropic": {"signature": "ErMJsig=="} + } diff --git a/tests/types/test_events.py b/tests/types/test_events.py new file mode 100644 index 00000000..a54b8c9a --- /dev/null +++ b/tests/types/test_events.py @@ -0,0 +1,63 @@ +"""Tests for ``ai.types.events``.""" + +from __future__ import annotations + +from ai import models +from ai.types import events, messages + + +class TestReplayMessageEvents: + async def test_reasoning_signature_survives_replay(self) -> None: + """A signed reasoning part replayed through the Stream aggregator + must keep its provider_metadata -- otherwise a rebuilt turn can't + be replayed to the provider.""" + original = messages.Message( + role="assistant", + parts=[ + messages.ReasoningPart( + text="thinking hard", + provider_metadata={"anthropic": {"signature": "ErMJsig=="}}, + ), + messages.TextPart(text="the answer is 42"), + ], + ) + + async with models.Stream( + events.replay_message_events(original) + ) as stream: + async for _ in stream: + pass + + reasoning = [ + p + for p in stream.message.parts + if isinstance(p, messages.ReasoningPart) + ] + assert len(reasoning) == 1 + assert reasoning[0].text == "thinking hard" + assert reasoning[0].provider_metadata == { + "anthropic": {"signature": "ErMJsig=="} + } + + async def test_reasoning_signature_on_end_event(self) -> None: + """The signature rides on the ReasoningEnd event, mirroring how the + real streaming adapters emit it.""" + msg = messages.Message( + role="assistant", + parts=[ + messages.ReasoningPart( + text="hmm", + provider_metadata={"anthropic": {"signature": "sig"}}, + ) + ], + ) + + reasoning_ends = [ + e + async for e in events.replay_message_events(msg) + if isinstance(e, events.ReasoningEnd) + ] + assert len(reasoning_ends) == 1 + assert reasoning_ends[0].provider_metadata == { + "anthropic": {"signature": "sig"} + }