Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions basalt/endpoints/publish_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Endpoint for publishing a prompt with a tag
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

from ..utils.dtos import DeploymentTagResponse, PublishPromptDTO


@dataclass
class PublishPromptEndpointResponse:
"""
Response from the publish prompt endpoint
"""
deploymentTag: DeploymentTagResponse
error: Optional[str] = None

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PublishPromptEndpointResponse":
"""
Create an instance of PublishPromptEndpointResponse from a dictionary.

Args:
data (Dict[str, Any]): The dictionary containing the response data.

Returns:
PublishPromptEndpointResponse
"""
if "error" in data:
return cls(deploymentTag=None, error=data["error"])

return cls(
deploymentTag=DeploymentTagResponse.from_dict(data["deploymentTag"]),
error=None
)


class PublishPromptEndpoint:
"""
Endpoint class for publishing a prompt with a tag.
"""
@staticmethod
def prepare_request(dto: PublishPromptDTO) -> Dict[str, Any]:
"""
Prepare the request dictionary for the PublishPrompt endpoint.

Args:
dto (PublishPromptDTO): The DTO containing publish prompt data.

Returns:
The path, method, and body for publishing a prompt on the API.
"""
body = {
"newTag": dto.new_tag
}

if dto.version:
body["version"] = dto.version

if dto.tag:
body["tag"] = dto.tag

return {
"path": f"/prompts/{dto.slug}/publish",
"method": "POST",
"body": body
}

@staticmethod
def decode_response(
response: dict
) -> Tuple[Optional[Exception], Optional[PublishPromptEndpointResponse]]:
"""
Decode the response returned from the API

Args:
response (dict): The JSON response to encode into a PublishPromptEndpointResponse

Returns:
A tuple containing an optional exception and an optional PublishPromptEndpointResponse.
"""
try:
return None, PublishPromptEndpointResponse.from_dict(response)
except Exception as e:
return e, None
89 changes: 88 additions & 1 deletion basalt/sdk/promptsdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,25 @@
from ..ressources.monitor.generation_types import GenerationParams, PromptReference
from ..ressources.monitor.trace_types import TraceParams
from ..ressources.prompts.prompt_types import Prompt as IPrompt, PromptParams
from ..utils.dtos import GetPromptDTO, GetPromptResult, PromptResponse, DescribePromptResponse, DescribePromptDTO, DescribeResult, ListResult, PromptListResponse, PromptListDTO
from ..utils.dtos import (
GetPromptDTO,
GetPromptResult,
PromptResponse,
DescribePromptResponse,
DescribePromptDTO,
DescribeResult,
ListResult,
PromptListResponse,
PromptListDTO,
PublishPromptDTO,
PublishPromptResult,
)
from ..utils.protocols import ICache, IApi, ILogger

from ..endpoints.get_prompt import GetPromptEndpoint
from ..endpoints.describe_prompt import DescribePromptEndpoint
from ..endpoints.list_prompts import ListPromptsEndpoint
from ..endpoints.publish_prompt import PublishPromptEndpoint
from ..objects.trace import Trace
from ..objects.generation import Generation
from ..objects.prompt import Prompt
Expand Down Expand Up @@ -326,6 +339,80 @@ def list_sync(self, feature_slug: Optional[str] = None) -> ListResult:
available_tags=prompt.available_tags
) for prompt in result.prompts]

async def publish(
self,
slug: str,
new_tag: str,
version: Optional[str] = None,
tag: Optional[str] = None
) -> PublishPromptResult:
"""
Publish a prompt by assigning a tag to a specific version.

Args:
slug (str): The slug identifier for the prompt.
new_tag (str): The new tag to assign to the prompt version.
version (Optional[str]): The version number to publish.
tag (Optional[str]): The existing tag to publish.

Returns:
Tuple[Optional[Exception], Optional[DeploymentTagResponse]]:
A tuple containing an optional exception and an optional DeploymentTagResponse.
"""
if not version and not tag:
return ValueError("Either version or tag must be provided"), None

dto = PublishPromptDTO(
slug=slug,
new_tag=new_tag,
version=version,
tag=tag
)

err, result = await self._api.invoke(PublishPromptEndpoint, dto)

if err is not None:
return err, None

return None, result.deploymentTag

def publish_sync(
self,
slug: str,
new_tag: str,
version: Optional[str] = None,
tag: Optional[str] = None
) -> PublishPromptResult:
"""
Synchronously publish a prompt by assigning a tag to a specific version.

Args:
slug (str): The slug identifier for the prompt.
new_tag (str): The new tag to assign to the prompt version.
version (Optional[str]): The version number to publish.
tag (Optional[str]): The existing tag to publish.

Returns:
Tuple[Optional[Exception], Optional[DeploymentTagResponse]]:
A tuple containing an optional exception and an optional DeploymentTagResponse.
"""
if not version and not tag:
return ValueError("Either version or tag must be provided"), None

dto = PublishPromptDTO(
slug=slug,
new_tag=new_tag,
version=version,
tag=tag
)

err, result = self._api.invoke_sync(PublishPromptEndpoint, dto)

if err is not None:
return err, None

return None, result.deploymentTag

@staticmethod
def _create_prompt_instance(
prompt_response: PromptResponse,
Expand Down
27 changes: 27 additions & 0 deletions basalt/utils/dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,30 @@ class CreateDatasetItemDTO:

# Result types for monitor operations
CreateExperimentResult = Tuple[Optional[Exception], Optional[Experiment]]

# ------------------------------ Publish Prompt ----------------------------- #
@dataclass
class PublishPromptDTO:
"""DTO for publishing a prompt with a tag"""
slug: str
new_tag: str
version: Optional[str] = None
tag: Optional[str] = None


@dataclass
class DeploymentTagResponse:
"""Response for a deployment tag"""
id: str
label: str

@classmethod
def from_dict(cls, data: Dict[str, Any]):
return cls(
id=pick_typed(data, "id", str),
label=pick_typed(data, "label", str),
)


# Result type for publish prompt
PublishPromptResult = Tuple[Optional[Exception], Optional[DeploymentTagResponse]]
3 changes: 3 additions & 0 deletions basalt/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ class NotFound(FetchError):

class NetworkBaseError(FetchError):
pass

class UnprocessableEntity(FetchError):
pass
5 changes: 4 additions & 1 deletion basalt/utils/networker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import aiohttp
from typing import Any, Dict, Optional, Tuple, Mapping

from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized
from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized, UnprocessableEntity
from .protocols import INetworker

class Networker(INetworker):
Expand Down Expand Up @@ -79,6 +79,9 @@ async def fetch(
if response.status == 404:
return NotFound(json_response.get('error', 'Not Found')), None

if response.status == 422:
return UnprocessableEntity(json_response.get('error', 'Unprocessable Entity')), None

response.raise_for_status()

return None, json_response
Expand Down
13 changes: 12 additions & 1 deletion basalt/utils/protocols.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal
from .dtos import GetPromptResult, DescribeResult, ListResult, ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult, CreateExperimentResult
from .dtos import (
GetPromptResult,
DescribeResult,
ListResult,
ListDatasetsResult,
GetDatasetResult,
CreateDatasetItemResult,
CreateExperimentResult,
PublishPromptResult,
)

from ..ressources.monitor.monitorsdk_types import IMonitorSDK

Expand Down Expand Up @@ -41,6 +50,8 @@ async def describe(self, slug: str, tag: Optional[str] = None, version: Optional
def describe_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ...
async def list(self, feature_slug: Optional[str] = None) -> ListResult: ...
def list_sync(self, feature_slug: Optional[str] = None) -> ListResult: ...
async def publish(self, slug: str, new_tag: str, version: Optional[str] = None, tag: Optional[str] = None) -> PublishPromptResult: ...
def publish_sync(self, slug: str, new_tag: str, version: Optional[str] = None, tag: Optional[str] = None) -> PublishPromptResult: ...

class IDatasetSDK(Protocol):
async def list(self) -> ListDatasetsResult: ...
Expand Down
Loading