-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathbeeai-reasoning.patch
More file actions
201 lines (184 loc) · 8.36 KB
/
beeai-reasoning.patch
File metadata and controls
201 lines (184 loc) · 8.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
diff --git a/python/beeai_framework/adapters/litellm/chat.py b/python/beeai_framework/adapters/litellm/chat.py
index b5c5a9b4..9ad9f544 100644
--- a/python/beeai_framework/adapters/litellm/chat.py
+++ b/python/beeai_framework/adapters/litellm/chat.py
@@ -37,7 +37,10 @@ from beeai_framework.backend.chat import (
)
from beeai_framework.backend.errors import ChatModelError
from beeai_framework.backend.message import (
+ AnyMessage,
AssistantMessage,
+ AssistantMessageContent,
+ MessageReasoningContent,
MessageTextContent,
MessageToolCallContent,
ToolMessage,
@@ -195,6 +198,7 @@ class LiteLLMChatModel(ChatModel, ABC):
"role": "assistant",
"content": msg_text_content or None,
"tool_calls": msg_tool_calls or None,
+ "thinking_blocks": message.meta.get("thinking_blocks"),
}
if self.model_supports_tool_calling
else {
@@ -307,27 +311,41 @@ class LiteLLMChatModel(ChatModel, ABC):
total_cost_usd=prompt_tokens_cost_usd + completion_tokens_cost_usd,
)
- return ChatModelOutput(
- output=(
- [
- AssistantMessage(
- [
- MessageToolCallContent(
- id=call.id or "",
- tool_name=call.function.name or "",
- args=call.function.arguments,
- )
- for call in update.tool_calls
- ],
- id=chunk.id,
+ reasoning_content = getattr(update, "reasoning_content", None) if update else None
+ # Anthropic requires `thinking_blocks` (with cryptographic signatures) to be sent back
+ # in conversation history; without them LiteLLM silently disables thinking on follow-up turns
+ meta = None
+ if (thinking_blocks := (getattr(update, "thinking_blocks", None) if update else None)) and (
+ # Streaming deltas carry partial blocks without signatures - filter those out
+ signed_thinking_blocks := [
+ b
+ for b in thinking_blocks
+ if (b.get("signature") if isinstance(b, dict) else getattr(b, "signature", None))
+ ]
+ ):
+ meta = {"thinking_blocks": signed_thinking_blocks}
+
+ if update:
+ parts: list[AssistantMessageContent] = []
+ if reasoning_content:
+ parts.append(MessageReasoningContent(text=reasoning_content))
+ if update.tool_calls:
+ parts.extend(
+ MessageToolCallContent(
+ id=call.id or "",
+ tool_name=call.function.name or "",
+ args=call.function.arguments,
)
- if update.tool_calls
- # pyrefly: ignore [bad-argument-type]
- else AssistantMessage(update.content or update.reasoning_content or "", id=chunk.id)
- ]
- if (update and update.model_dump(exclude_none=True))
- else []
- ),
+ for call in update.tool_calls
+ )
+ if update.content:
+ parts.append(MessageTextContent(text=update.content))
+ output: list[AnyMessage] = [AssistantMessage(parts, id=chunk.id, meta=meta)] if parts or meta else []
+ else:
+ output: list[AnyMessage] = []
+
+ return ChatModelOutput(
+ output=output,
# Will be set later
output_structured=None,
finish_reason=finish_reason,
diff --git a/python/beeai_framework/backend/__init__.py b/python/beeai_framework/backend/__init__.py
index fe8e5002..f3d7fe86 100644
--- a/python/beeai_framework/backend/__init__.py
+++ b/python/beeai_framework/backend/__init__.py
@@ -23,6 +23,7 @@ from beeai_framework.backend.message import (
Message,
MessageFileContent,
MessageImageContent,
+ MessageReasoningContent,
MessageTextContent,
MessageToolCallContent,
MessageToolResultContent,
@@ -65,6 +66,7 @@ __all__ = [
"MessageError",
"MessageFileContent",
"MessageImageContent",
+ "MessageReasoningContent",
"MessageTextContent",
"MessageToolCallContent",
"MessageToolResultContent",
diff --git a/python/beeai_framework/backend/chat.py b/python/beeai_framework/backend/chat.py
index 972b5246..a1ed0ca3 100644
--- a/python/beeai_framework/backend/chat.py
+++ b/python/beeai_framework/backend/chat.py
@@ -229,6 +229,11 @@ class ChatModelOptions(RunnableOptions, total=False):
Generated chunks will be streamed without validation of the produced tool calls.
"""
+ reasoning_effort: str | None
+ """
+ Controls the amount of reasoning effort for models that support it (e.g., "low", "medium", "high").
+ """
+
fallback_tool: AnyTool | None
"""
Tool to invoke when the model makes a malformed tool call (for example, when it forgets the name of a tool).
diff --git a/python/beeai_framework/backend/message.py b/python/beeai_framework/backend/message.py
index 3877d35b..befd7e9e 100644
--- a/python/beeai_framework/backend/message.py
+++ b/python/beeai_framework/backend/message.py
@@ -86,6 +86,11 @@ class MessageToolResultContent(BaseModel):
tool_call_id: str
+class MessageReasoningContent(BaseModel):
+ type: Literal["reasoning"] = "reasoning"
+ text: str
+
+
class MessageToolCallContent(BaseModel):
type: Literal["tool-call"] = "tool-call"
id: str
@@ -157,7 +162,7 @@ class Message(ABC, Generic[T]):
return type(self)([c.model_copy() for c in self.content], self.meta.copy())
-AssistantMessageContent = MessageTextContent | MessageToolCallContent
+AssistantMessageContent = MessageTextContent | MessageToolCallContent | MessageReasoningContent
class AssistantMessage(Message[AssistantMessageContent]):
@@ -175,8 +180,10 @@ class AssistantMessage(Message[AssistantMessageContent]):
(
MessageTextContent(text=c)
if isinstance(c, str)
- # pyrefly: ignore [bad-argument-type]
- else to_any_model([MessageToolCallContent, MessageTextContent], cast(AssistantMessageContent, c))
+ else to_any_model(
+ [MessageToolCallContent, MessageReasoningContent, MessageTextContent],
+ cast(AssistantMessageContent, c), # pyrefly: ignore [bad-argument-type]
+ )
)
for c in cast_list(content)
]
@@ -189,12 +196,19 @@ class AssistantMessage(Message[AssistantMessageContent]):
id=id,
)
+ @property
+ def reasoning(self) -> str:
+ return "".join([x.text for x in self.get_reasoning_messages()])
+
def get_tool_calls(self) -> list[MessageToolCallContent]:
return [cont for cont in self.content if isinstance(cont, MessageToolCallContent)]
def get_text_messages(self) -> list[MessageTextContent]:
return [cont for cont in self.content if isinstance(cont, MessageTextContent)]
+ def get_reasoning_messages(self) -> list[MessageReasoningContent]:
+ return [cont for cont in self.content if isinstance(cont, MessageReasoningContent)]
+
class ToolMessage(Message[MessageToolResultContent]):
role = Role.TOOL
diff --git a/python/beeai_framework/backend/types.py b/python/beeai_framework/backend/types.py
index b44a0a19..222d60a7 100644
--- a/python/beeai_framework/backend/types.py
+++ b/python/beeai_framework/backend/types.py
@@ -33,6 +33,7 @@ class ChatModelParameters(BaseModel):
seed: int | None = None
stop_sequences: list[str] | None = None
stream: bool | None = None
+ reasoning_effort: str | None = None
class ChatModelStructureInput(ChatModelParameters, Generic[T]):
@@ -218,6 +219,9 @@ class ChatModelOutput(RunnableOutput):
def get_text_content(self) -> str:
return "".join([x.text for x in list(filter(lambda x: isinstance(x, AssistantMessage), self.output))])
+ def get_reasoning_content(self) -> str:
+ return "".join([x.reasoning for x in self.output if isinstance(x, AssistantMessage)])
+
ChatModelCache = BaseCache[list[ChatModelOutput]]