From 4c9d6a285cebe889accf99515d7b3a69b7561224 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 11 May 2026 16:35:25 -0400 Subject: [PATCH] handle gemini thought signatures --- .../langchain_openai/chat_models/base.py | 44 +++++++ .../tests/unit_tests/chat_models/test_base.py | 122 ++++++++++++++++++ 2 files changed, 166 insertions(+) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 722721560e324..a48f8b3193b37 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -194,6 +194,23 @@ def _get_default_model_profile(model_name: str) -> ModelProfile: "tool_search", ) +# Google's OpenAI-compatible endpoint returns Gemini thought signatures on +# tool calls via `extra_content.google.thought_signature`. The signature must +# be echoed back on the corresponding tool call in subsequent turns. +_GEMINI_THOUGHT_SIGNATURES_MAP_KEY = "__gemini_function_call_thought_signatures__" + + +def _extract_gemini_thought_signature(raw_tool_call: Mapping[str, Any]) -> str | None: + """Pull a Gemini thought signature off a raw OpenAI-format tool call.""" + extra_content = raw_tool_call.get("extra_content") + if not isinstance(extra_content, Mapping): + return None + google = extra_content.get("google") + if not isinstance(google, Mapping): + return None + signature = google.get("thought_signature") + return signature if isinstance(signature, str) else None + def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: """Convert a dictionary to a LangChain message. @@ -218,6 +235,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: additional_kwargs["function_call"] = dict(function_call) tool_calls = [] invalid_tool_calls = [] + thought_signatures: dict[str, str] = {} if raw_tool_calls := _dict.get("tool_calls"): for raw_tool_call in raw_tool_calls: try: @@ -226,6 +244,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: invalid_tool_calls.append( make_invalid_tool_call(raw_tool_call, str(e)) ) + if (signature := _extract_gemini_thought_signature(raw_tool_call)) and ( + tool_call_id := raw_tool_call.get("id") + ): + thought_signatures[tool_call_id] = signature + if thought_signatures: + additional_kwargs[_GEMINI_THOUGHT_SIGNATURES_MAP_KEY] = thought_signatures if audio := _dict.get("audio"): additional_kwargs["audio"] = audio return AIMessage( @@ -381,6 +405,18 @@ def _convert_message_to_dict( message_dict["function_call"] = message.additional_kwargs["function_call"] else: pass + if "tool_calls" in message_dict and ( + thought_signatures := message.additional_kwargs.get( + _GEMINI_THOUGHT_SIGNATURES_MAP_KEY + ) + ): + for tool_call in message_dict["tool_calls"]: + if ( + signature := thought_signatures.get(tool_call.get("id")) + ) is not None: + tool_call["extra_content"] = { + "google": {"thought_signature": signature} + } # If tool calls present, content null value should be None not empty string. if "function_call" in message_dict or "tool_calls" in message_dict: message_dict["content"] = message_dict["content"] or None @@ -452,6 +488,14 @@ def _convert_delta_to_message_chunk( ] except KeyError: pass + thought_signatures: dict[str, str] = {} + for rtc in raw_tool_calls: + if (signature := _extract_gemini_thought_signature(rtc)) and ( + tool_call_id := rtc.get("id") + ): + thought_signatures[tool_call_id] = signature + if thought_signatures: + additional_kwargs[_GEMINI_THOUGHT_SIGNATURES_MAP_KEY] = thought_signatures if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content, id=id_) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 08b0850fd5fc1..9b98c4211028b 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -66,6 +66,7 @@ _convert_to_v03_ai_message, ) from langchain_openai.chat_models.base import ( + _GEMINI_THOUGHT_SIGNATURES_MAP_KEY, OpenAIRefusalError, _construct_lc_result_from_responses_api, _construct_responses_api_input, @@ -787,6 +788,127 @@ def test_function_calls_with_tool_calls(mock_client: MagicMock) -> None: assert "tool_calls" not in tool_call_message_payload +@pytest.mark.parametrize("streaming", [False, True]) +def test_gemini_thought_signature_roundtrip(streaming: bool) -> None: + """Gemini thought signatures should survive a tool-calling round trip. + + Google's OpenAI-compatible endpoint returns `extra_content.google.thought_signature` + on tool calls. The signature must be echoed back on the same tool call in + subsequent turns or the model errors out, so we stash a `{id: signature}` map on + `additional_kwargs[_GEMINI_THOUGHT_SIGNATURES_MAP_KEY]` and re-inject it on send. + """ + tool_call = { + "id": "tc_1", + "type": "function", + "function": {"name": "greet", "arguments": "{}"}, + "extra_content": {"google": {"thought_signature": "SIG_A"}}, + } + non_streaming_response = { + "id": "resp_1", + "object": "chat.completion", + "model": "gemini-3-flash-preview", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": None, + "tool_calls": [tool_call], + }, + } + ], + } + streaming_chunks = [ + { + "id": "resp_1", + "object": "chat.completion.chunk", + "model": "gemini-3-flash-preview", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": None, + "tool_calls": [{"index": 0, **tool_call}], + }, + "finish_reason": None, + } + ], + }, + { + "id": "resp_1", + "object": "chat.completion.chunk", + "model": "gemini-3-flash-preview", + "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}], + }, + ] + + captured: dict[str, Any] = {} + + def configure_mock(client: MagicMock) -> None: + if streaming: + + def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager: + captured["kwargs"] = kwargs + return MockSyncContextManager(streaming_chunks) + + client.create = mock_create + else: + + def mock_with_raw_create(*args: Any, **kwargs: Any) -> MagicMock: + captured["kwargs"] = kwargs + resp = MagicMock() + resp.parse.return_value = non_streaming_response + resp.headers = {} + return resp + + client.with_raw_response.create = mock_with_raw_create + + llm = ChatOpenAI(model="gemini-3-flash-preview") + mock_client = MagicMock() + configure_mock(mock_client) + + # Turn 1: parse signature off the response. + with patch.object(llm, "client", mock_client): + if streaming: + chunks = list(llm.stream("Use the greet tool")) + assistant_message: AIMessage = chunks[0] + for chunk in chunks[1:]: + assistant_message = assistant_message + chunk # type: ignore[assignment] + else: + assistant_message = cast(AIMessage, llm.invoke("Use the greet tool")) + + assert assistant_message.additional_kwargs.get( + _GEMINI_THOUGHT_SIGNATURES_MAP_KEY + ) == {"tc_1": "SIG_A"} + assert assistant_message.tool_calls + assert assistant_message.tool_calls[0]["id"] == "tc_1" + + # Turn 2: confirm the signature is re-attached on the outgoing payload. + captured.clear() + configure_mock(mock_client) + history: list[BaseMessage] = [ + HumanMessage("Use the greet tool"), + assistant_message, + ToolMessage(content="hello", tool_call_id="tc_1"), + ] + with patch.object(llm, "client", mock_client): + if streaming: + list(llm.stream(history)) + else: + llm.invoke(history) + + sent_messages = captured["kwargs"]["messages"] + sent_assistant = sent_messages[1] + assert sent_assistant["role"] == "assistant" + sent_tool_calls = sent_assistant["tool_calls"] + assert sent_tool_calls[0]["id"] == "tc_1" + assert sent_tool_calls[0]["extra_content"] == { + "google": {"thought_signature": "SIG_A"} + } + + def test_custom_token_counting() -> None: def token_encoder(text: str) -> list[int]: return [1, 2, 3]