Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions basalt/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
PromptModelParameters,
PromptParams,
PromptResponse,
PromptTools,
PublishPromptResponse,
)

Expand All @@ -33,6 +34,7 @@
"PromptModelParameters",
"PromptParams",
"PromptResponse",
"PromptTools",
"DescribePromptResponse",
"PromptListResponse",
"PublishPromptResponse",
Expand Down
50 changes: 48 additions & 2 deletions basalt/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,36 @@
)


@dataclass(slots=True, frozen=True)
class PromptTools:
"""Tools configuration for a prompt.

Immutable and uses slots to reduce per-instance memory overhead.
"""
tools: list[dict[str, Any]]
tool_choice: dict[str, Any] | None = None

@classmethod
def from_dict(cls, data: Mapping[str, Any] | None) -> PromptTools | None:
"""Create instance from API response mapping.

Robust against missing keys or wrong types.
"""
if data is None:
return None

tools_raw = data.get("tools")
tools = list(tools_raw) if isinstance(tools_raw, list) else []

tool_choice_raw = data.get("toolChoice")
tool_choice = dict(tool_choice_raw) if isinstance(tool_choice_raw, Mapping) else None

return cls(
tools=tools,
tool_choice=tool_choice,
)


@dataclass(slots=True, frozen=True)
class PromptModelParameters:
"""Model parameters for a prompt.
Expand All @@ -35,6 +65,7 @@ class PromptModelParameters:
frequency_penalty: float | None = None
presence_penalty: float | None = None
json_object: dict | None = None
tools: PromptTools | None = None

@classmethod
def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters:
Expand Down Expand Up @@ -74,6 +105,10 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters:
json_object = data.get("jsonObject")
json_object = dict(json_object) if isinstance(json_object, Mapping) else None

# Parse tools from parameters.tools
tools_data = data.get("tools")
tools = PromptTools.from_dict(tools_data if isinstance(tools_data, Mapping) else None)

return cls(
temperature=temperature,
top_k=top_k,
Expand All @@ -83,6 +118,7 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters:
max_length=max_length,
response_format=response_format,
json_object=json_object,
tools=tools,
)


Expand Down Expand Up @@ -123,8 +159,6 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModel:
version=str(version),
parameters=parameters,
)


@dataclass
class PromptParams:
"""Parameters for creating a new prompt instance."""
Expand Down Expand Up @@ -170,6 +204,16 @@ class Prompt:
variables: dict[str, Any] | None = None
tag: str | None = None

@property
def tools(self) -> PromptTools | None:
"""
Convenience property to access tools from model.parameters.tools.

Returns:
The PromptTools object if configured, otherwise None.
"""
return self.model.parameters.tools

def compile_variables(self, variables: dict[str, Any]) -> Prompt:
"""
Compile the prompt variables and render the text and system_text templates.
Expand Down Expand Up @@ -267,6 +311,7 @@ class PromptContextManager(_PromptContextMixin):
raw_system_text: str | None
variables: dict[str, Any] | None
tag: str | None
tools: PromptTools | None

def __init__(
self,
Expand Down Expand Up @@ -355,6 +400,7 @@ class AsyncPromptContextManager(_PromptContextMixin):
raw_system_text: str | None
variables: dict[str, Any] | None
tag: str | None
tools: PromptTools | None

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ dev = [
]


# Pytest configuration
# A recommended configuration for Ruff, the modern linter/formatter
[tool.ruff]
line-length = 160
Expand Down
122 changes: 122 additions & 0 deletions tests/prompts/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,3 +813,125 @@ def test_publish_prompt_sync_parameter_combinations(common_client, slug, new_tag
assert body["tag"] == tag
else:
assert "tag" not in body


def test_get_sync_with_tools(common_client):
"""Test that tools field is properly parsed and included in the Prompt."""
client: PromptsClient = common_client["client"]

with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch:
mock_fetch.return_value = make_response({"warning": "", "prompt": {
"text": "Hello {{name}}",
"slug": "test-slug",
"version": "1.0.0",
"tag": "prod",
"systemText": "You are a helpful assistant",
"model": {
"provider": "openai",
"model": "gpt-4",
"version": "1.0",
"parameters": {
"temperature": 0.7,
"maxLength": 100,
"responseFormat": "text",
"tools": {
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"}
}
}
}
}
],
"toolChoice": {"type": "auto"}
}
}
}
}})

prompt = client.get_sync("test-slug")

assert prompt.tools is not None
assert len(prompt.tools.tools) == 1
assert prompt.tools.tools[0]["type"] == "function"
assert prompt.tools.tools[0]["function"]["name"] == "get_weather"
assert prompt.tools.tool_choice is not None
assert prompt.tools.tool_choice["type"] == "auto"


def test_get_sync_without_tools(common_client):
"""Test that prompts without tools field work correctly."""
client: PromptsClient = common_client["client"]

with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch:
mock_fetch.return_value = make_response({"warning": "", "prompt": {
"text": "Hello",
"slug": "test-slug",
"version": "1.0.0",
"tag": "prod",
"systemText": "You are a helpful assistant",
"model": {
"provider": "openai",
"model": "gpt-4",
"version": "1.0",
"parameters": {
"temperature": 0.7,
"maxLength": 100,
"responseFormat": "text",
}
}
}})

prompt = client.get_sync("test-slug")

assert prompt.tools is None


@pytest.mark.asyncio
async def test_get_async_with_tools(common_client):
"""Test that tools field is properly parsed in async get."""
client: PromptsClient = common_client["client"]

with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch:
mock_fetch.return_value = make_response({"warning": "", "prompt": {
"text": "Hello {{name}}",
"slug": "test-slug",
"version": "1.0.0",
"tag": "prod",
"systemText": "You are a helpful assistant",
"model": {
"provider": "openai",
"model": "gpt-4",
"version": "1.0",
"parameters": {
"temperature": 0.7,
"maxLength": 100,
"responseFormat": "text",
"tools": {
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web",
}
}
]
}
}
}
}})

prompt = await client.get("test-slug")

assert prompt.tools is not None
assert len(prompt.tools.tools) == 1
assert prompt.tools.tools[0]["function"]["name"] == "search"
assert prompt.tools.tool_choice is None