forked from xai-org/xai-sdk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfunction_calling.py
More file actions
151 lines (117 loc) · 4.95 KB
/
function_calling.py
File metadata and controls
151 lines (117 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import asyncio
import json
from typing import Literal, Sequence
from absl import app, flags
from pydantic import BaseModel, Field
from xai_sdk import AsyncClient
from xai_sdk.chat import system, tool, tool_result, user
# --stream: selects the streaming variant of the demo in main().
STREAM = flags.DEFINE_bool("stream", False, "Whether streaming is enabled.")
async def function_calling(client: AsyncClient) -> None:
    """Multi-turn chat with function calling.

    Reads user input in a loop (type "exit" to quit), sends it to the model,
    services any tool calls the model issues, and prints the replies.

    Args:
        client: xAI SDK async client used to create the chat session.
    """

    def get_weather(city: str, units: Literal["C", "F"]) -> str:
        """Canned weather lookup standing in for a real backend."""
        temperature = 20 if units == "C" else 68
        return f"The weather in {city} is sunny at a temperature of {temperature} {units}."

    chat = client.chat.create(
        model="grok-3",
        messages=[system("You are a helpful assistant that can answer questions and help with tasks.")],
        tools=[
            tool(
                name="get_weather",
                description="Get the weather for a given city.",
                # JSON schema describing the arguments the model must supply.
                parameters={
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The name of the city to get the weather for.",
                        },
                        "units": {
                            "type": "string",
                            "description": "The units to use for the temperature.",
                            "enum": ["C", "F"],
                        },
                    },
                    "required": ["city", "units"],
                },
            ),
        ],
    )

    while True:
        user_input = input("You: ")
        if user_input.lower() == "exit":
            break
        chat.append(user(user_input))
        response = await chat.sample()
        chat.append(response)

        if response.content:
            print("Grok: ", end="", flush=True)
            print(response.content, end="", flush=True)

        # Keep servicing tool calls until the model replies without any.
        # The original used `if`, which handled only a single round and
        # silently dropped any follow-up tool calls.
        while response.tool_calls:
            for tool_call in response.tool_calls:
                tool_args = json.loads(tool_call.function.arguments)
                result = get_weather(tool_args["city"], tool_args["units"])
                chat.append(tool_result(result))

            response = await chat.sample()
            print()
            print("Grok: ", end="", flush=True)
            print(response.content, end="", flush=True)
            chat.append(response)

        print()
async def function_calling_streaming(client: AsyncClient) -> None:
    """Multi-turn chat with function calling and streaming.

    Same loop as `function_calling`, but replies are streamed to stdout
    chunk-by-chunk and tool arguments are validated with a Pydantic model.

    Args:
        client: xAI SDK async client used to create the chat session.
    """

    # Define the shape of the tool call arguments as a Pydantic model.
    class GetWeatherRequest(BaseModel):
        city: str = Field(description="The name of the city to get the weather for.")
        units: Literal["C", "F"] = Field(description="The units to use for the temperature.")

    def get_weather(request: GetWeatherRequest) -> str:
        """Canned weather lookup standing in for a real backend."""
        temperature = 20 if request.units == "C" else 68
        return f"The weather in {request.city} is sunny at a temperature of {temperature} {request.units}."

    chat = client.chat.create(
        model="grok-3",
        messages=[system("You are a helpful assistant that can answer questions and help with tasks.")],
        tools=[
            tool(
                name="get_weather",
                description="Get the weather for a given city.",
                # Generate the json schema from the Pydantic model.
                parameters=GetWeatherRequest.model_json_schema(),
            )
        ],
    )

    async def _stream_reply(print_prefix: bool):
        """Stream one model reply to stdout; append it to the chat and return it."""
        if print_prefix:
            print("Grok: ", end="", flush=True)
        last_response = None
        async for response, chunk in chat.stream():
            print(chunk.content, end="", flush=True)
            last_response = response
        assert last_response is not None
        chat.append(last_response)
        return last_response

    while True:
        user_input = input("You: ")
        if user_input.lower() == "exit":
            break
        chat.append(user(user_input))
        last_response = await _stream_reply(print_prefix=True)

        # Keep servicing tool calls until the model replies without any.
        # The original used `if`, which handled only a single round and
        # duplicated the streaming loop inline.
        while last_response.tool_calls:
            for tool_call in last_response.tool_calls:
                # Validate the tool call arguments as a Pydantic model and get proper type checking.
                request = GetWeatherRequest.model_validate_json(tool_call.function.arguments)
                result = get_weather(request)
                chat.append(tool_result(result))
            last_response = await _stream_reply(print_prefix=False)
        print()
async def main(argv: Sequence[str]) -> None:
    """Run the demo variant selected by the --stream flag."""
    if len(argv) > 1:
        raise app.UsageError("Unexpected command line arguments.")
    client = AsyncClient()
    # Pick the coroutine once, then await it — same effect as an if/else.
    demo = function_calling_streaming if STREAM.value else function_calling
    await demo(client)
if __name__ == "__main__":
    # Bridge absl's synchronous entry point to the async main().
    def _absl_main(argv: Sequence[str]) -> None:
        asyncio.run(main(argv))

    app.run(_absl_main)