From efaf42d440dccb9c3714ec6ae55473610cd0024c Mon Sep 17 00:00:00 2001 From: Guillaume Date: Thu, 23 Oct 2025 00:54:37 -0700 Subject: [PATCH 1/2] implement publish method --- basalt/endpoints/publish_prompt.py | 85 ++++++++ basalt/sdk/promptsdk.py | 89 +++++++- basalt/utils/dtos.py | 27 +++ basalt/utils/protocols.py | 13 +- tests/test_publish_prompt.py | 292 ++++++++++++++++++++++++++ tests/test_publish_prompt_endpoint.py | 119 +++++++++++ 6 files changed, 623 insertions(+), 2 deletions(-) create mode 100644 basalt/endpoints/publish_prompt.py create mode 100644 tests/test_publish_prompt.py create mode 100644 tests/test_publish_prompt_endpoint.py diff --git a/basalt/endpoints/publish_prompt.py b/basalt/endpoints/publish_prompt.py new file mode 100644 index 0000000..e0632fd --- /dev/null +++ b/basalt/endpoints/publish_prompt.py @@ -0,0 +1,85 @@ +""" +Endpoint for publishing a prompt with a tag +""" +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +from ..utils.dtos import DeploymentTagResponse, PublishPromptDTO + + +@dataclass +class PublishPromptEndpointResponse: + """ + Response from the publish prompt endpoint + """ + deploymentTag: DeploymentTagResponse + error: Optional[str] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PublishPromptEndpointResponse": + """ + Create an instance of PublishPromptEndpointResponse from a dictionary. + + Args: + data (Dict[str, Any]): The dictionary containing the response data. + + Returns: + PublishPromptEndpointResponse + """ + if "error" in data: + return cls(deploymentTag=None, error=data["error"]) + + return cls( + deploymentTag=DeploymentTagResponse.from_dict(data["deploymentTag"]), + error=None + ) + + +class PublishPromptEndpoint: + """ + Endpoint class for publishing a prompt with a tag. + """ + @staticmethod + def prepare_request(dto: PublishPromptDTO) -> Dict[str, Any]: + """ + Prepare the request dictionary for the PublishPrompt endpoint. + + Args: + dto (PublishPromptDTO): The DTO containing publish prompt data. + + Returns: + The path, method, and body for publishing a prompt on the API. + """ + body = { + "newTag": dto.new_tag + } + + if dto.version: + body["version"] = dto.version + + if dto.tag: + body["tag"] = dto.tag + + return { + "path": f"/prompts/{dto.slug}/publish", + "method": "POST", + "body": body + } + + @staticmethod + def decode_response( + response: dict + ) -> Tuple[Optional[Exception], Optional[PublishPromptEndpointResponse]]: + """ + Decode the response returned from the API + + Args: + response (dict): The JSON response to encode into a PublishPromptEndpointResponse + + Returns: + A tuple containing an optional exception and an optional PublishPromptEndpointResponse. + """ + try: + return None, PublishPromptEndpointResponse.from_dict(response) + except Exception as e: + return e, None diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index 21c4e64..5a0d9fd 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -3,12 +3,25 @@ from ..ressources.monitor.generation_types import GenerationParams, PromptReference from ..ressources.monitor.trace_types import TraceParams from ..ressources.prompts.prompt_types import Prompt as IPrompt, PromptParams -from ..utils.dtos import GetPromptDTO, GetPromptResult, PromptResponse, DescribePromptResponse, DescribePromptDTO, DescribeResult, ListResult, PromptListResponse, PromptListDTO +from ..utils.dtos import ( + GetPromptDTO, + GetPromptResult, + PromptResponse, + DescribePromptResponse, + DescribePromptDTO, + DescribeResult, + ListResult, + PromptListResponse, + PromptListDTO, + PublishPromptDTO, + PublishPromptResult, +) from ..utils.protocols import ICache, IApi, ILogger from ..endpoints.get_prompt import GetPromptEndpoint from ..endpoints.describe_prompt import DescribePromptEndpoint from ..endpoints.list_prompts import ListPromptsEndpoint +from ..endpoints.publish_prompt import PublishPromptEndpoint from ..objects.trace import Trace from ..objects.generation import Generation from ..objects.prompt import Prompt @@ -326,6 +339,80 @@ def list_sync(self, feature_slug: Optional[str] = None) -> ListResult: available_tags=prompt.available_tags ) for prompt in result.prompts] + async def publish( + self, + slug: str, + new_tag: str, + version: Optional[str] = None, + tag: Optional[str] = None + ) -> PublishPromptResult: + """ + Publish a prompt by assigning a tag to a specific version. + + Args: + slug (str): The slug identifier for the prompt. + new_tag (str): The new tag to assign to the prompt version. + version (Optional[str]): The version number to publish. + tag (Optional[str]): The existing tag to publish. + + Returns: + Tuple[Optional[Exception], Optional[DeploymentTagResponse]]: + A tuple containing an optional exception and an optional DeploymentTagResponse. + """ + if not version and not tag: + return ValueError("Either version or tag must be provided"), None + + dto = PublishPromptDTO( + slug=slug, + new_tag=new_tag, + version=version, + tag=tag + ) + + err, result = await self._api.invoke(PublishPromptEndpoint, dto) + + if err is not None: + return err, None + + return None, result.deploymentTag + + def publish_sync( + self, + slug: str, + new_tag: str, + version: Optional[str] = None, + tag: Optional[str] = None + ) -> PublishPromptResult: + """ + Synchronously publish a prompt by assigning a tag to a specific version. + + Args: + slug (str): The slug identifier for the prompt. + new_tag (str): The new tag to assign to the prompt version. + version (Optional[str]): The version number to publish. + tag (Optional[str]): The existing tag to publish. + + Returns: + Tuple[Optional[Exception], Optional[DeploymentTagResponse]]: + A tuple containing an optional exception and an optional DeploymentTagResponse. + """ + if not version and not tag: + return ValueError("Either version or tag must be provided"), None + + dto = PublishPromptDTO( + slug=slug, + new_tag=new_tag, + version=version, + tag=tag + ) + + err, result = self._api.invoke_sync(PublishPromptEndpoint, dto) + + if err is not None: + return err, None + + return None, result.deploymentTag + @staticmethod def _create_prompt_instance( prompt_response: PromptResponse, diff --git a/basalt/utils/dtos.py b/basalt/utils/dtos.py index 45b83fc..1d45335 100644 --- a/basalt/utils/dtos.py +++ b/basalt/utils/dtos.py @@ -204,3 +204,30 @@ class CreateDatasetItemDTO: # Result types for monitor operations CreateExperimentResult = Tuple[Optional[Exception], Optional[Experiment]] + +# ------------------------------ Publish Prompt ----------------------------- # +@dataclass +class PublishPromptDTO: + """DTO for publishing a prompt with a tag""" + slug: str + new_tag: str + version: Optional[str] = None + tag: Optional[str] = None + + +@dataclass +class DeploymentTagResponse: + """Response for a deployment tag""" + id: str + label: str + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + return cls( + id=pick_typed(data, "id", str), + label=pick_typed(data, "label", str), + ) + + +# Result type for publish prompt +PublishPromptResult = Tuple[Optional[Exception], Optional[DeploymentTagResponse]] diff --git a/basalt/utils/protocols.py b/basalt/utils/protocols.py index 652e74a..7f4cc8c 100644 --- a/basalt/utils/protocols.py +++ b/basalt/utils/protocols.py @@ -1,5 +1,14 @@ from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal -from .dtos import GetPromptResult, DescribeResult, ListResult, ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult, CreateExperimentResult +from .dtos import ( + GetPromptResult, + DescribeResult, + ListResult, + ListDatasetsResult, + GetDatasetResult, + CreateDatasetItemResult, + CreateExperimentResult, + PublishPromptResult, +) from ..ressources.monitor.monitorsdk_types import IMonitorSDK @@ -41,6 +50,8 @@ async def describe(self, slug: str, tag: Optional[str] = None, version: Optional def describe_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... async def list(self, feature_slug: Optional[str] = None) -> ListResult: ... def list_sync(self, feature_slug: Optional[str] = None) -> ListResult: ... + async def publish(self, slug: str, new_tag: str, version: Optional[str] = None, tag: Optional[str] = None) -> PublishPromptResult: ... + def publish_sync(self, slug: str, new_tag: str, version: Optional[str] = None, tag: Optional[str] = None) -> PublishPromptResult: ... class IDatasetSDK(Protocol): async def list(self) -> ListDatasetsResult: ... diff --git a/tests/test_publish_prompt.py b/tests/test_publish_prompt.py new file mode 100644 index 0000000..4070733 --- /dev/null +++ b/tests/test_publish_prompt.py @@ -0,0 +1,292 @@ +import unittest +from unittest.mock import MagicMock, AsyncMock +from parameterized import parameterized + +from basalt.sdk.promptsdk import PromptSDK +from basalt.utils.logger import Logger +from basalt.utils.dtos import PublishPromptDTO, DeploymentTagResponse +from basalt.endpoints.publish_prompt import PublishPromptEndpoint, PublishPromptEndpointResponse + +logger = Logger() + + +class TestPublishPromptSync(unittest.TestCase): + """Test suite for the publish_sync method""" + + def setUp(self): + """Set up test fixtures""" + self.mocked_api = MagicMock() + self.mocked_cache = MagicMock() + self.mocked_cache.get.return_value = None + self.fallback_cache = MagicMock() + self.fallback_cache.get.return_value = None + + # Mock successful response + self.mocked_api.invoke_sync.return_value = (None, PublishPromptEndpointResponse( + deploymentTag=DeploymentTagResponse( + id="test-tag-id-123", + label="production" + ), + error=None + )) + + self.prompt_sdk = PromptSDK( + self.mocked_api, + cache=self.mocked_cache, + fallback_cache=self.fallback_cache, + logger=logger + ) + + def test_uses_correct_endpoint(self): + """Test that publish_sync uses the PublishPromptEndpoint""" + self.prompt_sdk.publish_sync("test-slug", "production", version="1.0.0") + + endpoint = self.mocked_api.invoke_sync.call_args[0][0] + self.assertEqual(endpoint, PublishPromptEndpoint) + + @parameterized.expand([ + # (slug, new_tag, version, tag) + ("my-prompt", "production", "1.0.0", None), + ("my-prompt", "staging", None, "latest"), + ("another-prompt", "prod", "2.5.1", None), + ("test-slug", "custom-tag", None, "development"), + ]) + def test_passes_correct_dto(self, slug, new_tag, version, tag): + """Test that the correct DTO is passed to the API""" + self.prompt_sdk.publish_sync(slug, new_tag, version=version, tag=tag) + + dto = self.mocked_api.invoke_sync.call_args[0][1] + + self.assertEqual(dto, PublishPromptDTO( + slug=slug, + new_tag=new_tag, + version=version, + tag=tag + )) + + def test_returns_success_response(self): + """Test that a successful response is properly returned""" + err, deployment_tag = self.prompt_sdk.publish_sync( + "test-slug", + "production", + version="1.0.0" + ) + + self.assertIsNone(err) + self.assertIsNotNone(deployment_tag) + self.assertEqual(deployment_tag.id, "test-tag-id-123") + self.assertEqual(deployment_tag.label, "production") + + def test_forwards_api_error(self): + """Test that API errors are properly forwarded""" + error_message = "Prompt not found" + self.mocked_api.invoke_sync.return_value = (Exception(error_message), None) + + err, deployment_tag = self.prompt_sdk.publish_sync( + "test-slug", + "production", + version="1.0.0" + ) + + self.assertIsInstance(err, Exception) + self.assertEqual(str(err), error_message) + self.assertIsNone(deployment_tag) + + def test_validates_version_or_tag_required(self): + """Test that either version or tag must be provided""" + err, deployment_tag = self.prompt_sdk.publish_sync( + "test-slug", + "production" + # Neither version nor tag provided + ) + + self.assertIsInstance(err, ValueError) + self.assertEqual(str(err), "Either version or tag must be provided") + self.assertIsNone(deployment_tag) + + def test_accepts_version_only(self): + """Test that providing only version works""" + err, deployment_tag = self.prompt_sdk.publish_sync( + "test-slug", + "production", + version="1.0.0" + ) + + self.assertIsNone(err) + self.assertIsNotNone(deployment_tag) + + def test_accepts_tag_only(self): + """Test that providing only tag works""" + err, deployment_tag = self.prompt_sdk.publish_sync( + "test-slug", + "production", + tag="latest" + ) + + self.assertIsNone(err) + self.assertIsNotNone(deployment_tag) + + def test_handles_both_version_and_tag(self): + """Test that providing both version and tag works""" + err, deployment_tag = self.prompt_sdk.publish_sync( + "test-slug", + "production", + version="1.0.0", + tag="latest" + ) + + self.assertIsNone(err) + self.assertIsNotNone(deployment_tag) + + +class TestPublishPromptAsync(unittest.TestCase): + """Test suite for the async publish method""" + + def setUp(self): + """Set up test fixtures""" + self.mocked_api = MagicMock() + self.mocked_api.invoke = AsyncMock() + self.mocked_cache = MagicMock() + self.mocked_cache.get.return_value = None + self.fallback_cache = MagicMock() + self.fallback_cache.get.return_value = None + + # Mock successful response + self.mocked_api.invoke.return_value = (None, PublishPromptEndpointResponse( + deploymentTag=DeploymentTagResponse( + id="async-tag-id-456", + label="staging" + ), + error=None + )) + + self.prompt_sdk = PromptSDK( + self.mocked_api, + cache=self.mocked_cache, + fallback_cache=self.fallback_cache, + logger=logger + ) + + async def test_async_uses_correct_endpoint(self): + """Test that async publish uses the PublishPromptEndpoint""" + await self.prompt_sdk.publish("test-slug", "staging", version="2.0.0") + + endpoint = self.mocked_api.invoke.call_args[0][0] + self.assertEqual(endpoint, PublishPromptEndpoint) + + async def test_async_passes_correct_dto(self): + """Test that the correct DTO is passed to the API in async mode""" + slug = "my-async-prompt" + new_tag = "production" + version = "3.0.0" + + await self.prompt_sdk.publish(slug, new_tag, version=version) + + dto = self.mocked_api.invoke.call_args[0][1] + + self.assertEqual(dto, PublishPromptDTO( + slug=slug, + new_tag=new_tag, + version=version, + tag=None + )) + + async def test_async_returns_success_response(self): + """Test that a successful async response is properly returned""" + err, deployment_tag = await self.prompt_sdk.publish( + "test-slug", + "staging", + version="2.0.0" + ) + + self.assertIsNone(err) + self.assertIsNotNone(deployment_tag) + self.assertEqual(deployment_tag.id, "async-tag-id-456") + self.assertEqual(deployment_tag.label, "staging") + + async def test_async_forwards_api_error(self): + """Test that async API errors are properly forwarded""" + error_message = "Version not found" + self.mocked_api.invoke.return_value = (Exception(error_message), None) + + err, deployment_tag = await self.prompt_sdk.publish( + "test-slug", + "staging", + version="2.0.0" + ) + + self.assertIsInstance(err, Exception) + self.assertEqual(str(err), error_message) + self.assertIsNone(deployment_tag) + + async def test_async_validates_version_or_tag_required(self): + """Test that async method validates version or tag requirement""" + err, deployment_tag = await self.prompt_sdk.publish( + "test-slug", + "production" + # Neither version nor tag provided + ) + + self.assertIsInstance(err, ValueError) + self.assertEqual(str(err), "Either version or tag must be provided") + self.assertIsNone(deployment_tag) + + async def test_async_accepts_tag_only(self): + """Test that async method accepts tag only""" + err, deployment_tag = await self.prompt_sdk.publish( + "test-slug", + "production", + tag="latest" + ) + + self.assertIsNone(err) + self.assertIsNotNone(deployment_tag) + + +# Helper to run async tests +def run_async_test(coro): + """Helper function to run async tests""" + import asyncio + loop = asyncio.get_event_loop() + return loop.run_until_complete(coro) + + +# Add async test wrappers +def test_async_uses_correct_endpoint(): + test_case = TestPublishPromptAsync() + test_case.setUp() + run_async_test(test_case.test_async_uses_correct_endpoint()) + + +def test_async_passes_correct_dto(): + test_case = TestPublishPromptAsync() + test_case.setUp() + run_async_test(test_case.test_async_passes_correct_dto()) + + +def test_async_returns_success_response(): + test_case = TestPublishPromptAsync() + test_case.setUp() + run_async_test(test_case.test_async_returns_success_response()) + + +def test_async_forwards_api_error(): + test_case = TestPublishPromptAsync() + test_case.setUp() + run_async_test(test_case.test_async_forwards_api_error()) + + +def test_async_validates_version_or_tag_required(): + test_case = TestPublishPromptAsync() + test_case.setUp() + run_async_test(test_case.test_async_validates_version_or_tag_required()) + + +def test_async_accepts_tag_only(): + test_case = TestPublishPromptAsync() + test_case.setUp() + run_async_test(test_case.test_async_accepts_tag_only()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_publish_prompt_endpoint.py b/tests/test_publish_prompt_endpoint.py new file mode 100644 index 0000000..bcbf04d --- /dev/null +++ b/tests/test_publish_prompt_endpoint.py @@ -0,0 +1,119 @@ +import unittest +from basalt.endpoints.publish_prompt import PublishPromptEndpoint +from basalt.utils.dtos import PublishPromptDTO + + +class TestPublishPromptEndpoint(unittest.TestCase): + """Test suite for the PublishPromptEndpoint""" + + def test_includes_slug_in_path(self): + """Test that the slug is included in the request path""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="my-unique-prompt-slug", new_tag="production", version="1.0.0") + ) + + self.assertIn("my-unique-prompt-slug", result["path"]) + self.assertEqual(result["path"], "/prompts/my-unique-prompt-slug/publish") + + def test_uses_post_method(self): + """Test that the endpoint uses POST method""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="test-slug", new_tag="prod", version="1.0.0") + ) + + self.assertEqual(result["method"], "POST") + + def test_includes_new_tag_in_body(self): + """Test that newTag is included in the request body""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="slug", new_tag="production", version="1.0.0") + ) + + self.assertEqual(result["body"]["newTag"], "production") + + def test_includes_version_in_body_when_provided(self): + """Test that version is included in body when provided""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="slug", new_tag="prod", version="2.5.1") + ) + + self.assertEqual(result["body"]["version"], "2.5.1") + + def test_includes_tag_in_body_when_provided(self): + """Test that tag is included in body when provided""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="slug", new_tag="prod", tag="latest") + ) + + self.assertEqual(result["body"]["tag"], "latest") + + def test_omits_version_when_not_provided(self): + """Test that version is omitted from body when not provided""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="slug", new_tag="prod", tag="latest") + ) + + self.assertNotIn("version", result["body"]) + + def test_omits_tag_when_not_provided(self): + """Test that tag is omitted from body when not provided""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="slug", new_tag="prod", version="1.0.0") + ) + + self.assertNotIn("tag", result["body"]) + + def test_decodes_valid_response(self): + """Test that a valid response is properly decoded""" + response = { + "deploymentTag": { + "id": "test-deployment-tag-id-123", + "label": "production" + } + } + + exception, decoded = PublishPromptEndpoint.decode_response(response) + + self.assertIsNone(exception) + self.assertIsNotNone(decoded) + self.assertEqual(decoded.deploymentTag.id, "test-deployment-tag-id-123") + self.assertEqual(decoded.deploymentTag.label, "production") + self.assertIsNone(decoded.error) + + def test_decodes_error_response(self): + """Test that an error response is properly decoded""" + response = { + "error": "Prompt not found" + } + + exception, decoded = PublishPromptEndpoint.decode_response(response) + + self.assertIsNone(exception) + self.assertIsNotNone(decoded) + self.assertEqual(decoded.error, "Prompt not found") + self.assertIsNone(decoded.deploymentTag) + + def test_handles_malformed_response(self): + """Test that a malformed response returns an exception""" + response = { + "unexpected": "data" + } + + exception, decoded = PublishPromptEndpoint.decode_response(response) + + self.assertIsNotNone(exception) + self.assertIsNone(decoded) + + def test_handles_both_version_and_tag(self): + """Test that both version and tag can be provided""" + result = PublishPromptEndpoint.prepare_request( + PublishPromptDTO(slug="slug", new_tag="prod", version="1.0.0", tag="latest") + ) + + self.assertEqual(result["body"]["version"], "1.0.0") + self.assertEqual(result["body"]["tag"], "latest") + self.assertEqual(result["body"]["newTag"], "prod") + + +if __name__ == '__main__': + unittest.main() From ae86d85db59dbd558a7c158c89b4c60e31b86e21 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Thu, 23 Oct 2025 16:34:52 -0700 Subject: [PATCH 2/2] handle 422 errors --- basalt/utils/errors.py | 3 +++ basalt/utils/networker.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/basalt/utils/errors.py b/basalt/utils/errors.py index c5032c1..313c79d 100644 --- a/basalt/utils/errors.py +++ b/basalt/utils/errors.py @@ -16,3 +16,6 @@ class NotFound(FetchError): class NetworkBaseError(FetchError): pass + +class UnprocessableEntity(FetchError): + pass diff --git a/basalt/utils/networker.py b/basalt/utils/networker.py index f8c4aad..ceb5028 100644 --- a/basalt/utils/networker.py +++ b/basalt/utils/networker.py @@ -2,7 +2,7 @@ import aiohttp from typing import Any, Dict, Optional, Tuple, Mapping -from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized +from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized, UnprocessableEntity from .protocols import INetworker class Networker(INetworker): @@ -79,6 +79,9 @@ async def fetch( if response.status == 404: return NotFound(json_response.get('error', 'Not Found')), None + if response.status == 422: + return UnprocessableEntity(json_response.get('error', 'Unprocessable Entity')), None + response.raise_for_status() return None, json_response