diff --git a/basalt/prompts/__init__.py b/basalt/prompts/__init__.py index 1b5d257..39c7c0f 100644 --- a/basalt/prompts/__init__.py +++ b/basalt/prompts/__init__.py @@ -18,6 +18,7 @@ PromptModelParameters, PromptParams, PromptResponse, + PromptTools, PublishPromptResponse, ) @@ -33,6 +34,7 @@ "PromptModelParameters", "PromptParams", "PromptResponse", + "PromptTools", "DescribePromptResponse", "PromptListResponse", "PublishPromptResponse", diff --git a/basalt/prompts/models.py b/basalt/prompts/models.py index 87f5d47..27cba56 100644 --- a/basalt/prompts/models.py +++ b/basalt/prompts/models.py @@ -20,6 +20,36 @@ ) +@dataclass(slots=True, frozen=True) +class PromptTools: + """Tools configuration for a prompt. + + Immutable and uses slots to reduce per-instance memory overhead. + """ + tools: list[dict[str, Any]] + tool_choice: dict[str, Any] | None = None + + @classmethod + def from_dict(cls, data: Mapping[str, Any] | None) -> PromptTools | None: + """Create instance from API response mapping. + + Robust against missing keys or wrong types. + """ + if data is None: + return None + + tools_raw = data.get("tools") + tools = list(tools_raw) if isinstance(tools_raw, list) else [] + + tool_choice_raw = data.get("toolChoice") + tool_choice = dict(tool_choice_raw) if isinstance(tool_choice_raw, Mapping) else None + + return cls( + tools=tools, + tool_choice=tool_choice, + ) + + @dataclass(slots=True, frozen=True) class PromptModelParameters: """Model parameters for a prompt. @@ -35,6 +65,7 @@ class PromptModelParameters: frequency_penalty: float | None = None presence_penalty: float | None = None json_object: dict | None = None + tools: PromptTools | None = None @classmethod def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters: @@ -74,6 +105,10 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters: json_object = data.get("jsonObject") json_object = dict(json_object) if isinstance(json_object, Mapping) else None + # Parse tools from parameters.tools + tools_data = data.get("tools") + tools = PromptTools.from_dict(tools_data if isinstance(tools_data, Mapping) else None) + return cls( temperature=temperature, top_k=top_k, @@ -83,6 +118,7 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters: max_length=max_length, response_format=response_format, json_object=json_object, + tools=tools, ) @@ -123,8 +159,6 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModel: version=str(version), parameters=parameters, ) - - @dataclass class PromptParams: """Parameters for creating a new prompt instance.""" @@ -170,6 +204,16 @@ class Prompt: variables: dict[str, Any] | None = None tag: str | None = None + @property + def tools(self) -> PromptTools | None: + """ + Convenience property to access tools from model.parameters.tools. + + Returns: + The PromptTools object if configured, otherwise None. + """ + return self.model.parameters.tools + def compile_variables(self, variables: dict[str, Any]) -> Prompt: """ Compile the prompt variables and render the text and system_text templates. @@ -267,6 +311,7 @@ class PromptContextManager(_PromptContextMixin): raw_system_text: str | None variables: dict[str, Any] | None tag: str | None + tools: PromptTools | None def __init__( self, @@ -355,6 +400,7 @@ class AsyncPromptContextManager(_PromptContextMixin): raw_system_text: str | None variables: dict[str, Any] | None tag: str | None + tools: PromptTools | None def __init__( self, diff --git a/pyproject.toml b/pyproject.toml index 9d7998f..fa5467b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ dev = [ ] +# Pytest configuration # A recommended configuration for Ruff, the modern linter/formatter [tool.ruff] line-length = 160 diff --git a/tests/prompts/test_client.py b/tests/prompts/test_client.py index 2730072..bc54c60 100644 --- a/tests/prompts/test_client.py +++ b/tests/prompts/test_client.py @@ -813,3 +813,125 @@ def test_publish_prompt_sync_parameter_combinations(common_client, slug, new_tag assert body["tag"] == tag else: assert "tag" not in body + + +def test_get_sync_with_tools(common_client): + """Test that tools field is properly parsed and included in the Prompt.""" + client: PromptsClient = common_client["client"] + + with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: + mock_fetch.return_value = make_response({"warning": "", "prompt": { + "text": "Hello {{name}}", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are a helpful assistant", + "model": { + "provider": "openai", + "model": "gpt-4", + "version": "1.0", + "parameters": { + "temperature": 0.7, + "maxLength": 100, + "responseFormat": "text", + "tools": { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + } + ], + "toolChoice": {"type": "auto"} + } + } + } + }}) + + prompt = client.get_sync("test-slug") + + assert prompt.tools is not None + assert len(prompt.tools.tools) == 1 + assert prompt.tools.tools[0]["type"] == "function" + assert prompt.tools.tools[0]["function"]["name"] == "get_weather" + assert prompt.tools.tool_choice is not None + assert prompt.tools.tool_choice["type"] == "auto" + + +def test_get_sync_without_tools(common_client): + """Test that prompts without tools field work correctly.""" + client: PromptsClient = common_client["client"] + + with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: + mock_fetch.return_value = make_response({"warning": "", "prompt": { + "text": "Hello", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are a helpful assistant", + "model": { + "provider": "openai", + "model": "gpt-4", + "version": "1.0", + "parameters": { + "temperature": 0.7, + "maxLength": 100, + "responseFormat": "text", + } + } + }}) + + prompt = client.get_sync("test-slug") + + assert prompt.tools is None + + +@pytest.mark.asyncio +async def test_get_async_with_tools(common_client): + """Test that tools field is properly parsed in async get.""" + client: PromptsClient = common_client["client"] + + with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: + mock_fetch.return_value = make_response({"warning": "", "prompt": { + "text": "Hello {{name}}", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are a helpful assistant", + "model": { + "provider": "openai", + "model": "gpt-4", + "version": "1.0", + "parameters": { + "temperature": 0.7, + "maxLength": 100, + "responseFormat": "text", + "tools": { + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web", + } + } + ] + } + } + } + }}) + + prompt = await client.get("test-slug") + + assert prompt.tools is not None + assert len(prompt.tools.tools) == 1 + assert prompt.tools.tools[0]["function"]["name"] == "search" + assert prompt.tools.tool_choice is None