diff --git a/basalt/basaltsdk.py b/basalt/basaltsdk.py index 57a451f..04cab7c 100644 --- a/basalt/basaltsdk.py +++ b/basalt/basaltsdk.py @@ -5,6 +5,7 @@ class BasaltSDK(IBasaltSDK): """ The BasaltSDK class implements the IBasaltSDK interface. It serves as the main entry point for interacting with the Basalt SDK. + """ def __init__(self, prompt_sdk: IPromptSDK, monitor_sdk: IMonitorSDK, dataset_sdk: IDatasetSDK): diff --git a/basalt/objects/dataset.py b/basalt/objects/dataset.py index 51cd679..a2577d9 100644 --- a/basalt/objects/dataset.py +++ b/basalt/objects/dataset.py @@ -21,15 +21,9 @@ def to_dict(self) -> Dict[str, Any]: result = { "values": self.values, "metadata": self.metadata, - "name": self.name, - "idealOutput": self.ideal_output + "name": self.name, + "idealOutput": self.ideal_output } - - # if self.name: - # result["name"] = self.name - - # if self.ideal_output: - # result["idealOutput"] = self.ideal_output return result diff --git a/basalt/ressources/monitor/generation_types.py b/basalt/ressources/monitor/generation_types.py index 24f56e8..761aabd 100644 --- a/basalt/ressources/monitor/generation_types.py +++ b/basalt/ressources/monitor/generation_types.py @@ -34,6 +34,7 @@ class GenerationParams(BaseLogParams): input (Optional[str]): The input provided to the model. output (Optional[str]): The output generated by the model. variables (Optional[Dict[str, Any]]): Variables used in the prompt template. + options (Optional[Dict[str, Any]]): Additional options for the generation. Example: ```python @@ -57,6 +58,7 @@ class GenerationParams(BaseLogParams): input: Optional[str] = None output: Optional[str] = None variables: Optional[Dict[str, Any]] = None + options: Optional[Dict[str, Any]] = None @dataclass class Generation(BaseLog): diff --git a/basalt/sdk/datasetsdk.py b/basalt/sdk/datasetsdk.py index c7a6cf6..05b5acc 100644 --- a/basalt/sdk/datasetsdk.py +++ b/basalt/sdk/datasetsdk.py @@ -2,6 +2,7 @@ SDK for interacting with Basalt datasets """ from typing import Dict, List, Optional, Tuple, Any +import asyncio from ..utils.dtos import ( ListDatasetsDTO, GetDatasetDTO, CreateDatasetItemDTO, @@ -46,6 +47,26 @@ def list(self) -> ListDatasetsResult: name=dataset.name, columns=dataset.columns ) for dataset in result.datasets] + + async def async_list(self) -> ListDatasetsResult: + """ + Asynchronously list all datasets available in the workspace. + + Returns: + Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional + exception and an optional list of DatasetDTO objects. + """ + dto = ListDatasetsDTO() + err, result = await self._api.async_invoke(ListDatasetsEndpoint, dto) + + if err is not None: + return err, None + + return None, [DatasetDTO( + slug=dataset.slug, + name=dataset.name, + columns=dataset.columns + ) for dataset in result.datasets] def get(self, slug: str) -> GetDatasetResult: """ @@ -68,6 +89,28 @@ def get(self, slug: str) -> GetDatasetResult: return Exception(result.error), None return None, result.dataset + + async def async_get(self, slug: str) -> GetDatasetResult: + """ + Asynchronously get a dataset by its slug. + + Args: + slug (str): The slug identifier for the dataset. + + Returns: + Tuple[Optional[Exception], Optional[DatasetDTO]]: A tuple containing an optional + exception and an optional DatasetDTO. + """ + dto = GetDatasetDTO(slug=slug) + err, result = await self._api.async_invoke(GetDatasetEndpoint, dto) + + if err is not None: + return err, None + + if result.error: + return Exception(result.error), None + + return None, result.dataset def addRow( self, @@ -108,3 +151,43 @@ def addRow( return Exception(result.error), None, None return None, result.datasetRow, result.warning + + async def async_addRow( + self, + slug: str, + values: Dict[str, str], + name: Optional[str] = None, + ideal_output: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> CreateDatasetItemResult: + """ + Asynchronously create a new item in a dataset. + + Args: + slug (str): The slug identifier for the dataset. + values (Dict[str, str]): A dictionary of column values for the dataset item. + name (Optional[str]): An optional name for the dataset item. + ideal_output (Optional[str]): An optional ideal output for the dataset item. + metadata (Optional[Dict[str, Any]]): An optional metadata dictionary. + + Returns: + Tuple[Optional[Exception], Optional[DatasetRowDTO], Optional[str]]: A tuple containing + an optional exception, an optional DatasetRowDTO, and an optional warning message. + """ + dto = CreateDatasetItemDTO( + slug=slug, + values=values, + name=name, + idealOutput=ideal_output, + metadata=metadata + ) + + err, result = await self._api.async_invoke(CreateDatasetItemEndpoint, dto) + + if err is not None: + return err, None, None + + if result.error: + return Exception(result.error), None, None + + return None, result.datasetRow, result.warning diff --git a/basalt/sdk/monitorsdk.py b/basalt/sdk/monitorsdk.py index 6277296..dd97189 100644 --- a/basalt/sdk/monitorsdk.py +++ b/basalt/sdk/monitorsdk.py @@ -1,4 +1,5 @@ from typing import Dict, Optional, Any, Tuple +import asyncio from ..utils.protocols import IApi, ILogger from ..ressources.monitor.trace_types import TraceParams @@ -40,6 +41,23 @@ def create_experiment( Experiment: A new Experiment instance. """ return self._create_experiment(feature_slug, params) + + async def async_create_experiment( + self, + feature_slug: str, + params: ExperimentParams + ) -> Tuple[Optional[Exception], Optional[Experiment]]: + """ + Asynchronously creates a new experiment for monitoring. + + Args: + feature_slug (str): The feature slug for the experiment. + params (Dict[str, Any]): Parameters for the experiment. + + Returns: + Experiment: A new Experiment instance. + """ + return await self._async_create_experiment(feature_slug, params) def create_trace( @@ -63,6 +81,28 @@ def create_trace( trace_params = TraceParams(**params) return self._create_trace(slug, trace_params) + + async def async_create_trace( + self, + slug: str, + params: Optional[TraceParams] = None + ) -> Trace: + """ + Asynchronously creates a new trace for monitoring. + + Args: + slug (str): The unique identifier for the trace. + params (TraceParams): Parameters for the trace. + + Returns: + Trace: A new Trace instance. + """ + if params is None: + params = {} + + trace_params = TraceParams(**params) + + return await self._async_create_trace(slug, trace_params) def create_generation( self, @@ -79,6 +119,22 @@ def create_generation( """ generation_params = GenerationParams(**params) return self._create_generation(generation_params) + + async def async_create_generation( + self, + params: Dict[str, Any] + ) -> Generation: + """ + Asynchronously creates a new generation for monitoring. + + Args: + params (Dict[str, Any]): Parameters for the generation. + + Returns: + Generation: A new Generation instance. + """ + generation_params = GenerationParams(**params) + return await self._async_create_generation(generation_params) def create_log( self, @@ -95,6 +151,22 @@ def create_log( """ log_params = LogParams(**params) return self._create_log(log_params) + + async def async_create_log( + self, + params: Dict[str, Any] + ) -> Log: + """ + Asynchronously creates a new log for monitoring. + + Args: + params (Dict[str, Any]): Parameters for the log. + + Returns: + Log: A new Log instance. + """ + log_params = LogParams(**params) + return await self._async_create_log(log_params) def _create_experiment( self, @@ -123,6 +195,34 @@ def _create_experiment( return None, Experiment(result.experiment) return err, None + + async def _async_create_experiment( + self, + feature_slug: str, + params: ExperimentParams + ) -> Tuple[Optional[Exception], Optional[Experiment]]: + """ + Internal implementation for asynchronously creating an experiment. + + Args: + feature_slug (str): The feature slug for the experiment. + params (ExperimentParams): Parameters for the experiment. + + Returns: + Experiment: A new Experiment instance. + """ + dto = CreateExperimentDTO( + feature_slug=feature_slug, + name=params.get("name"), + ) + + # Call the API endpoint + err, result = await self._api.async_invoke(CreateExperimentEndpoint, dto) + + if err is None: + return None, Experiment(result.experiment) + + return err, None def _create_trace( @@ -157,6 +257,39 @@ def _create_trace( } trace = Trace(slug, params_dict, flusher, self._logger) return trace + + async def _async_create_trace( + self, + slug: str, + params: TraceParams + ) -> Trace: + """ + Internal implementation for asynchronously creating a trace. + + Args: + slug (str): The unique identifier for the trace. + params (TraceParams): Parameters for the trace. + + Returns: + Trace: A new Trace instance. + """ + flusher = Flusher(self._api, self._logger) + # Convert TraceParams to a dictionary before passing to Trace + params_dict = { + "input": params.input, + "output": params.output, + "name": params.name, + "start_time": params.start_time, + "end_time": params.end_time, + "user": params.user, + "organization": params.organization, + "metadata": params.metadata, + "experiment": params.experiment, + "evaluators": params.evaluators, + "evaluationConfig": params.evaluation_config + } + trace = Trace(slug, params_dict, flusher, self._logger) + return trace def _create_generation( self, @@ -186,6 +319,35 @@ def _create_generation( "options": params.options } return Generation(params_dict) + + async def _async_create_generation( + self, + params: GenerationParams + ) -> Generation: + """ + Internal implementation for asynchronously creating a generation. + + Args: + params (GenerationParams): Parameters for the generation. + + Returns: + Generation: A new Generation instance. + """ + # Convert GenerationParams to a dictionary before passing to Generation + params_dict = { + "name": params.name, + "trace": params.trace, + "prompt": params.prompt, + "input": params.input, + "output": params.output, + "variables": params.variables, + "parent": params.parent, + "metadata": params.metadata, + "start_time": params.start_time, + "end_time": params.end_time, + "options": params.options + } + return Generation(params_dict) def _create_log( self, @@ -194,6 +356,32 @@ def _create_log( """ Internal implementation for creating a log. + Args: + params (LogParams): Parameters for the log. + + Returns: + Log: A new Log instance. + """ + # Convert LogParams to a dictionary before passing to Log + params_dict = { + "name": params.name, + "trace": params.trace, + "input": params.input, + "output": params.output, + "parent": params.parent, + "metadata": params.metadata, + "start_time": params.start_time, + "end_time": params.end_time + } + return Log(params_dict) + + async def _async_create_log( + self, + params: LogParams + ) -> Log: + """ + Internal implementation for asynchronously creating a log. + Args: params (LogParams): Parameters for the log. diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index 85c9e05..f909c4e 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Tuple, Any +from typing import Optional, Dict, Tuple, Any, Awaitable from ..utils.dtos import GetPromptDTO, PromptResponse, DescribePromptResponse, DescribePromptDTO, GetResult, DescribeResult, ListResult, PromptListResponse, PromptListDTO from ..utils.protocols import ICache, IApi, ILogger @@ -11,6 +11,7 @@ from ..objects.generation import Generation from ..utils.flusher import Flusher from datetime import datetime +import asyncio class PromptSDK: """ @@ -87,6 +88,63 @@ def get( return err, prompt_response, generation return err, None, None + + async def async_get( + self, + slug: str, + version: Optional[str] = None, + tag: Optional[str] = None, + variables: Dict[str, str] = {}, + cache_enabled: bool = True + ) -> Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + """ + Asynchronously retrieve a prompt by slug, optionally specifying version and tag. + + Args: + slug (str): The slug identifier for the prompt. + version (Optional[str]): The version of the prompt. + tag (Optional[str]): The tag associated with the prompt. + variables (dict): A dictionnary of variables to replace in the prompt text. + cache_enabled (bool): Enable or disable cache for this request. + + Returns: + Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + A tuple containing an optional exception, an optional PromptResponse, and an optional Generation object. + """ + dto = GetPromptDTO( + slug=slug, + version=version, + tag=tag + ) + + cached = self._cache.get(dto) if cache_enabled else None + + if cached: + original_prompt_text = cached.text + err, prompt_response = self._replace_vars(cached, variables) + generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + return err, prompt_response, generation + + err, result = await self._api.async_invoke(GetPromptEndpoint, dto) + + if err is None: + original_prompt_text = result.prompt.text + self._cache.put(dto, result.prompt, self._cache_duration) + self._fallback_cache.put(dto, result.prompt) + + err, prompt_response = self._replace_vars(result.prompt, variables) + generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + return err, prompt_response, generation + + fallback = self._fallback_cache.get(dto) if cache_enabled else None + + if fallback: + original_prompt_text = fallback.text + err, prompt_response = self._replace_vars(fallback, variables) + generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + return err, prompt_response, generation + + return err, None, None def _prepare_monitoring( self, @@ -135,6 +193,54 @@ def _prepare_monitoring( }) return generation + + async def _async_prepare_monitoring( + self, + prompt: PromptResponse, + slug: str, + version: Optional[str] = None, + tag: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + original_prompt_text: Optional[str] = None + ) -> Generation: + """ + Asynchronously prepare monitoring by creating a trace and generation object. + + Args: + prompt (PromptResponse): The prompt response. + slug (str): The slug identifier for the prompt. + version (Optional[str]): The version of the prompt. + tag (Optional[str]): The tag associated with the prompt. + variables (Optional[Dict[str, Any]]): Variables used in the prompt. + original_prompt_text (Optional[str]): The original prompt text. + + Returns: + Generation: The generation object. + """ + # Create a flusher + flusher = Flusher(self._api, self._logger) + + # Create a trace + trace = Trace(slug, { + "input": original_prompt_text or prompt.text, + "start_time": datetime.now() + }, flusher, self._logger) + + # Create a generation + generation = Generation({ + "name": slug, + "trace": trace, + "prompt": { + "slug": slug, + "version": version, + "tag": tag + }, + "input": original_prompt_text or prompt.text, + "variables": variables, + "options": {"type": "single"} + }) + + return generation def describe( self, @@ -176,6 +282,46 @@ def describe( ) return err, None + + async def async_describe( + self, + slug: str, + version: Optional[str] = None, + tag: Optional[str] = None, + ) -> DescribeResult: + """ + Asynchronously get details about a prompt by slug, optionally specifying version and tag. + + Args: + slug (str): The slug identifier for the prompt. + version (Optional[str]): The version of the prompt. + tag (Optional[str]): The tag associated with the prompt. + + Returns: + Tuple[Optional[Exception], Optional[DescribePromptResponse]]: A tuple containing an optional exception and an optional DescribePromptResponse. + """ + dto = DescribePromptDTO( + slug=slug, + version=version, + tag=tag + ) + + err, result = await self._api.async_invoke(DescribePromptEndpoint, dto) + + if err is None: + prompt = result.prompt + + return None, DescribePromptResponse( + slug=prompt.slug, + status=prompt.status, + name=prompt.name, + description=prompt.description, + available_versions=prompt.available_versions, + available_tags=prompt.available_tags, + variables=prompt.variables + ) + + return err, None def list(self, feature_slug: Optional[str] = None) -> ListResult: dto = PromptListDTO(featureSlug=feature_slug) @@ -193,6 +339,32 @@ def list(self, feature_slug: Optional[str] = None) -> ListResult: available_versions=prompt.available_versions, available_tags=prompt.available_tags ) for prompt in result.prompts] + + async def async_list(self, feature_slug: Optional[str] = None) -> ListResult: + """ + Asynchronously list prompts, optionally filtering by feature_slug. + + Args: + feature_slug (Optional[str]): Optional feature slug to filter prompts by. + + Returns: + Tuple[Optional[Exception], Optional[List[PromptListResponse]]]: A tuple containing an optional exception and an optional list of PromptListResponse objects. + """ + dto = PromptListDTO(featureSlug=feature_slug) + + err, result = await self._api.async_invoke(ListPromptsEndpoint, dto) + + if err is not None: + return err, None + + return None, [PromptListResponse( + slug=prompt.slug, + status=prompt.status, + name=prompt.name, + description=prompt.description, + available_versions=prompt.available_versions, + available_tags=prompt.available_tags + ) for prompt in result.prompts] def _replace_vars(self, prompt: PromptResponse, variables: Dict[str, str] = {}): missing_vars, replaced = replace_variables(prompt.text, variables) diff --git a/basalt/utils/api.py b/basalt/utils/api.py index 7bb928c..9fd4bf2 100644 --- a/basalt/utils/api.py +++ b/basalt/utils/api.py @@ -1,6 +1,7 @@ from typing import Dict, TypeVar, Optional, Tuple from .protocols import IEndpoint, INetworker, ILogger +import asyncio from .networker import Networker Input = TypeVar('Input') @@ -83,3 +84,38 @@ def _headers(self) -> Dict[str, str]: 'X-BASALT-SDK-TYPE': self._sdk_type, 'Content-Type': 'application/json' } + + async def async_invoke( + self, + endpoint: IEndpoint[Input, Output], + dto: Optional[Input] = None + ) -> Tuple[Optional[Exception], Optional[Output]]: + """ + Asynchronously invoke an API endpoint with the given data transfer object (DTO). + + Args: + endpoint: The endpoint to be invoked. + dto: The data transfer object to be sent to the endpoint. + + Returns: + A tuple containing an optional exception and an optional output. + """ + # Prepare the request information using the endpoint and input data + if dto is None: + request_info = endpoint.prepare_request() + else: + request_info = endpoint.prepare_request(dto) + + # Fetch the result from the network using the prepared request information + error, result = await self._network.async_fetch( + self._root + request_info['path'], + request_info['method'], + request_info.get('body'), + params=request_info.get('query', {}), + headers=self._headers(), + ) + + if error: + return error, None + + return endpoint.decode_response(result) diff --git a/basalt/utils/networker.py b/basalt/utils/networker.py index c21c305..af87960 100644 --- a/basalt/utils/networker.py +++ b/basalt/utils/networker.py @@ -1,4 +1,5 @@ import requests +import aiohttp from typing import Any, Dict, Optional, Tuple from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized @@ -64,3 +65,56 @@ def fetch( except Exception as e: return NetworkBaseError(str(e)), None + + async def async_fetch( + self, + url: str, + method: str, + body = None, + headers = None, + params = None + ) -> Tuple[Optional[FetchError], Optional[Dict[str, Any]]]: + """ + Asynchronously fetch data from a given URL using the specified HTTP method. + + Args: + url (str): The URL to fetch data from. + method (str): The HTTP method to use (e.g., 'GET', 'POST'). + body (Optional[Any]): The request payload to send (default is None). + headers (Optional[Dict[str, str]]): The request headers to send (default is None). + params (Optional[Dict[str, str]]): The query parameters to send (default is None). + + Returns: + A result tuple (err, json_response), possible responses: + - (None, json_response) + - (FetchError, None) + """ + try: + async with aiohttp.ClientSession() as session: + async with session.request( + method, + url, + params=params, + json=body, + headers=headers + ) as response: + json_response = await response.json() + + if response.status == 400: + return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None + + if response.status == 401: + return Unauthorized(json_response.get('error', 'Unauthorized')), None + + if response.status == 403: + return Forbidden(json_response.get('error', 'Forbidden')), None + + if response.status == 404: + return NotFound(json_response.get('error', 'Not Found')), None + + response.raise_for_status() + + return None, json_response + + except Exception as e: + return NetworkBaseError(str(e)), None diff --git a/examples/dataset_sdk_async_demo.ipynb b/examples/dataset_sdk_async_demo.ipynb new file mode 100644 index 0000000..a11ac85 --- /dev/null +++ b/examples/dataset_sdk_async_demo.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt DatasetSDK Async Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the DatasetSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from basalt import Basalt\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Initialize the SDK with your API key. If you have the BASALT_API_KEY environment variable set, it will use that; otherwise, replace \"your-api-key\" with your actual API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the SDK with your API key\n", + "api_key = os.environ.get(\"BASALT_API_KEY\", \"your-api-key\")\n", + "\n", + "# Create a Basalt client\n", + "basalt = Basalt(api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Asynchronously List All Datasets\n", + "\n", + "This example demonstrates how to list all datasets asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def list_datasets():\n", + " print(\"Listing all datasets asynchronously...\")\n", + " err, datasets = await basalt.datasets.async_list()\n", + " if err:\n", + " print(f\"Error listing datasets: {err}\")\n", + " else:\n", + " print(f\"Found {len(datasets)} datasets\")\n", + " for dataset in datasets:\n", + " print(f\"- {dataset.name} (slug: {dataset.slug})\")\n", + " return datasets\n", + "\n", + "# Run the async function\n", + "datasets = await list_datasets()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Asynchronously Get a Specific Dataset\n", + "\n", + "This example demonstrates how to retrieve a specific dataset by its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_dataset(datasets):\n", + " print(\"\\nGetting a specific dataset asynchronously...\")\n", + " if len(datasets) > 0:\n", + " sample_dataset = datasets[0]\n", + " err, dataset = await basalt.datasets.async_get(sample_dataset.slug)\n", + " if err:\n", + " print(f\"Error getting dataset: {err}\")\n", + " else:\n", + " print(f\"Retrieved dataset: {dataset.name}\")\n", + " print(f\"Columns: {dataset.columns}\")\n", + " print(f\"Number of rows: {len(dataset.rows) if dataset.rows else 0}\")\n", + " return sample_dataset, dataset\n", + " else:\n", + " print(\"No datasets available\")\n", + " return None, None\n", + "\n", + "# Run the async function\n", + "sample_dataset, dataset = await get_dataset(datasets)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Asynchronously Add a Row to a Dataset\n", + "\n", + "This example demonstrates how to add a new row to an existing dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def add_row(sample_dataset):\n", + " print(\"\\nAdding a row to a dataset asynchronously...\")\n", + " if sample_dataset:\n", + " # Create some sample values for the dataset row\n", + " values = {column: f\"Sample {column} value\" for column in sample_dataset.columns}\n", + " \n", + " err, row, warning = await basalt.datasets.async_addRow(\n", + " slug=sample_dataset.slug,\n", + " values=values,\n", + " name=\"Async Sample Row\",\n", + " ideal_output=\"Expected output for this row\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error adding row to dataset: {err}\")\n", + " elif warning:\n", + " print(f\"Row added with warning: {warning}\")\n", + " print(f\"Row values: {row.values}\")\n", + " else:\n", + " print(f\"Row added successfully\")\n", + " print(f\"Row values: {row.values}\")\n", + " print(f\"Row name: {row.name}\")\n", + "\n", + "# Run the async function\n", + "await add_row(sample_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Asynchronously Get a Dataset as an Object\n", + "\n", + "This example demonstrates how to retrieve a dataset as an object for further manipulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_dataset_object(sample_dataset):\n", + " print(\"\\nGetting a dataset as an object asynchronously...\")\n", + " if sample_dataset:\n", + " dataset_obj = await basalt.datasets.async_get_dataset_object(sample_dataset.slug)\n", + " if dataset_obj:\n", + " print(f\"Retrieved dataset object: {dataset_obj.name}\")\n", + " print(f\"Number of rows: {len(dataset_obj.rows)}\")\n", + " else:\n", + " print(\"Failed to get dataset object\")\n", + " return dataset_obj\n", + " else:\n", + " print(\"No sample dataset available\")\n", + " return None\n", + "\n", + "# Run the async function\n", + "dataset_obj = await get_dataset_object(sample_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Asynchronously Add a Row to a Dataset Object\n", + "\n", + "This example demonstrates how to add a row directly to a dataset object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def add_row_to_dataset_object(dataset_obj):\n", + " print(\"\\nAdding a row to a dataset object asynchronously...\")\n", + " if dataset_obj:\n", + " # Create some sample values for the dataset row\n", + " values = {column: f\"Another sample {column} value\" for column in dataset_obj.columns}\n", + " \n", + " row = await basalt.datasets.async_add_row_to_dataset(\n", + " dataset=dataset_obj,\n", + " values=values,\n", + " name=\"Another Async Sample Row\",\n", + " ideal_output=\"Another expected output\",\n", + " metadata={\"source\": \"async_dataset_object_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " if row:\n", + " print(f\"Row added to dataset object successfully\")\n", + " print(f\"Row values: {row.values}\")\n", + " print(f\"Updated number of rows in dataset object: {len(dataset_obj.rows)}\")\n", + " else:\n", + " print(\"Failed to add row to dataset object\")\n", + "\n", + "# Run the async function\n", + "await add_row_to_dataset_object(dataset_obj)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 6: Execute Multiple Dataset Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous operations concurrently for better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def execute_concurrent_operations(sample_dataset):\n", + " print(\"\\nExecuting multiple dataset operations concurrently...\")\n", + " \n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.datasets.async_list(),\n", + " basalt.datasets.async_get(sample_dataset.slug) if sample_dataset else None\n", + " ]\n", + " \n", + " # Filter out None tasks\n", + " tasks = [t for t in tasks if t is not None]\n", + " \n", + " if tasks:\n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + " \n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"Number of datasets listed: {len(results[0][1]) if results[0][1] else 'Error'}\")\n", + " if len(tasks) > 1:\n", + " print(f\"Retrieved dataset: {results[1][1].name if results[1][1] else 'Error'}\")\n", + "\n", + "# Run the async function\n", + "await execute_concurrent_operations(sample_dataset)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/dataset_sdk_demo.ipynb b/examples/dataset_sdk_demo.ipynb index 29c4fb4..e732583 100644 --- a/examples/dataset_sdk_demo.ipynb +++ b/examples/dataset_sdk_demo.ipynb @@ -1,232 +1,232 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Basalt Dataset SDK Demo\n", - "\n", - "This notebook demonstrates how to use the Basalt Dataset SDK to interact with your Basalt datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", - "\n", - "os.environ[\"BASALT_BUILD\"] = \"development\"\n", - "\n", - "from basalt import Basalt\n", - "\n", - "# Initialize the SDK\n", - "basalt = Basalt(\n", - " api_key=\"sk-df55a...\", # Replace with your API key\n", - " log_level=\"debug\" # Optional: Set log level\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Listing Available Datasets\n", - "\n", - "Retrieve all datasets available in your workspace." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# List all datasets in the workspace\n", - "err, datasets = basalt.datasets.list()\n", - "\n", - "if err:\n", - " print(f\"Error listing datasets: {err}\")\n", - "else:\n", - " print(f\"Found {len(datasets)} datasets:\")\n", - " for i, dataset in enumerate(datasets):\n", - " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", - " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", - " \n", - " # Store the first dataset slug for later use (if available)\n", - " first_dataset_slug = datasets[0].slug if datasets else None" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Getting a Specific Dataset\n", - "\n", - "Retrieve details for a specific dataset using its slug." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use the first dataset from the list or replace with a specific slug\n", - "dataset_slug = first_dataset_slug if 'first_dataset_slug' in locals() and first_dataset_slug else \"your-dataset-slug\"\n", - "\n", - "err, dataset = basalt.datasets.get(dataset_slug)\n", - "\n", - "if err:\n", - " print(f\"Error getting dataset: {err}\")\n", - "else:\n", - " print(f\"Dataset details for '{dataset.name}'\")\n", - " print(f\"Slug: {dataset.slug}\")\n", - " print(f\"Columns: {', '.join(dataset.columns)}\")\n", - " print(f\"Number of rows: {len(dataset.rows)}\")\n", - " \n", - " if dataset.rows:\n", - " print(\"\\nSample rows:\")\n", - " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", - " print(f\"Row {i+1}:\")\n", - " print(f\" Values: {row.get('values')}\")\n", - " if 'name' in row:\n", - " print(f\" Name: {row['name']}\")\n", - " if 'idealOutput' in row:\n", - " print(f\" Ideal output: {row['idealOutput']}\")\n", - " if 'metadata' in row:\n", - " print(f\" Metadata: {row['metadata']}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Adding a Row to a Dataset\n", - "\n", - "Create a new row (item) in an existing dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use the dataset from the previous example\n", - "if 'dataset' in locals() and dataset:\n", - " # Build values for all columns in the dataset\n", - " values = {}\n", - " for column in dataset.columns:\n", - " values[column] = f\"Example value for {column} - {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\n", - " \n", - " # Create the row\n", - " err, row, warning = basalt.datasets.addRow(\n", - " slug=dataset.slug,\n", - " values=values,\n", - " name=\"Notebook Example Row\",\n", - " ideal_output=\"This is an ideal output for this row\",\n", - " metadata={\"source\": \"Jupyter notebook example\", \"timestamp\": __import__('datetime').datetime.now().isoformat()}\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error creating dataset row: {err}\")\n", - " else:\n", - " print(\"Successfully created new dataset row:\")\n", - " print(f\"Values: {row.values}\")\n", - " print(f\"Name: {row.name}\")\n", - " print(f\"Ideal output: {row.idealOutput}\")\n", - " print(f\"Metadata: {row.metadata}\")\n", - " \n", - " if warning:\n", - " print(f\"Warning: {warning}\")\n", - "else:\n", - " print(\"Please run the previous cell to get a dataset first\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Error Handling with Dataset SDK\n", - "\n", - "Demonstrate proper error handling when working with datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def safely_add_dataset_row(slug, values, name=None, ideal_output=None, metadata=None):\n", - " \"\"\"Safely add a row to a dataset with robust error handling\"\"\"\n", - " try:\n", - " err, row, warning = basalt.datasets.addRow(\n", - " slug=slug,\n", - " values=values,\n", - " name=name,\n", - " ideal_output=ideal_output,\n", - " metadata=metadata\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error creating dataset row: {err}\")\n", - " return None\n", - " \n", - " if warning:\n", - " print(f\"Warning: {warning}\")\n", - " \n", - " return row\n", - " except Exception as e:\n", - " print(f\"Unexpected error: {str(e)}\")\n", - " return None\n", - "\n", - "# Test with a valid dataset\n", - "if 'dataset_slug' in locals() and dataset_slug:\n", - " values = {\"input\": \"Test input\", \"output\": \"Test output\"}\n", - " row = safely_add_dataset_row(dataset_slug, values, name=\"Error Handling Test\")\n", - " \n", - " if row:\n", - " print(f\"Successfully created row: {row.name}\")\n", - "\n", - "# Test with an invalid dataset slug\n", - "print(\"\\nTesting with invalid dataset slug:\")\n", - "invalid_row = safely_add_dataset_row(\"non-existent-dataset\", {\"input\": \"Test input\"})\n", - "print(f\"Result with invalid slug: {invalid_row}\")\n", - "\n", - "# Test with missing required values\n", - "if 'dataset' in locals() and dataset and len(dataset.columns) > 0:\n", - " print(\"\\nTesting with missing required values:\")\n", - " # Deliberately create incomplete values dict\n", - " incomplete_values = {column: \"value\" for column in list(dataset.columns)[1:]} if len(dataset.columns) > 1 else {}\n", - " incomplete_row = safely_add_dataset_row(dataset.slug, incomplete_values)\n", - " print(f\"Result with incomplete values: {incomplete_row}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt Dataset SDK Demo\n", + "\n", + "This notebook demonstrates how to use the Basalt Dataset SDK to interact with your Basalt datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "\n", + "os.environ[\"BASALT_BUILD\"] = \"development\"\n", + "\n", + "from basalt import Basalt\n", + "\n", + "# Initialize the SDK\n", + "basalt = Basalt(\n", + " api_key=\"sk-...\", # Replace with your API key\n", + " log_level=\"debug\" # Optional: Set log level\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Listing Available Datasets\n", + "\n", + "Retrieve all datasets available in your workspace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# List all datasets in the workspace\n", + "err, datasets = basalt.datasets.list()\n", + "\n", + "if err:\n", + " print(f\"Error listing datasets: {err}\")\n", + "else:\n", + " print(f\"Found {len(datasets)} datasets:\")\n", + " for i, dataset in enumerate(datasets):\n", + " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", + " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", + " \n", + " # Store the first dataset slug for later use (if available)\n", + " first_dataset_slug = datasets[0].slug if datasets else None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Getting a Specific Dataset\n", + "\n", + "Retrieve details for a specific dataset using its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the first dataset from the list or replace with a specific slug\n", + "dataset_slug = first_dataset_slug if 'first_dataset_slug' in locals() and first_dataset_slug else \"your-dataset-slug\"\n", + "\n", + "err, dataset = basalt.datasets.get(dataset_slug)\n", + "\n", + "if err:\n", + " print(f\"Error getting dataset: {err}\")\n", + "else:\n", + " print(f\"Dataset details for '{dataset.name}'\")\n", + " print(f\"Slug: {dataset.slug}\")\n", + " print(f\"Columns: {', '.join(dataset.columns)}\")\n", + " print(f\"Number of rows: {len(dataset.rows)}\")\n", + " \n", + " if dataset.rows:\n", + " print(\"\\nSample rows:\")\n", + " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", + " print(f\"Row {i+1}:\")\n", + " print(f\" Values: {row.get('values')}\")\n", + " if 'name' in row:\n", + " print(f\" Name: {row['name']}\")\n", + " if 'idealOutput' in row:\n", + " print(f\" Ideal output: {row['idealOutput']}\")\n", + " if 'metadata' in row:\n", + " print(f\" Metadata: {row['metadata']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Adding a Row to a Dataset\n", + "\n", + "Create a new row (item) in an existing dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the dataset from the previous example\n", + "if 'dataset' in locals() and dataset:\n", + " # Build values for all columns in the dataset\n", + " values = {}\n", + " for column in dataset.columns:\n", + " values[column] = f\"Example value for {column} - {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\n", + " \n", + " # Create the row\n", + " err, row, warning = basalt.datasets.addRow(\n", + " slug=dataset.slug,\n", + " values=values,\n", + " name=\"Notebook Example Row\",\n", + " ideal_output=\"This is an ideal output for this row\",\n", + " metadata={\"source\": \"Jupyter notebook example\", \"timestamp\": __import__('datetime').datetime.now().isoformat()}\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error creating dataset row: {err}\")\n", + " else:\n", + " print(\"Successfully created new dataset row:\")\n", + " print(f\"Values: {row.values}\")\n", + " print(f\"Name: {row.name}\")\n", + " print(f\"Ideal output: {row.idealOutput}\")\n", + " print(f\"Metadata: {row.metadata}\")\n", + " \n", + " if warning:\n", + " print(f\"Warning: {warning}\")\n", + "else:\n", + " print(\"Please run the previous cell to get a dataset first\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Error Handling with Dataset SDK\n", + "\n", + "Demonstrate proper error handling when working with datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def safely_add_dataset_row(slug, values, name=None, ideal_output=None, metadata=None):\n", + " \"\"\"Safely add a row to a dataset with robust error handling\"\"\"\n", + " try:\n", + " err, row, warning = basalt.datasets.addRow(\n", + " slug=slug,\n", + " values=values,\n", + " name=name,\n", + " ideal_output=ideal_output,\n", + " metadata=metadata\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error creating dataset row: {err}\")\n", + " return None\n", + " \n", + " if warning:\n", + " print(f\"Warning: {warning}\")\n", + " \n", + " return row\n", + " except Exception as e:\n", + " print(f\"Unexpected error: {str(e)}\")\n", + " return None\n", + "\n", + "# Test with a valid dataset\n", + "if 'dataset_slug' in locals() and dataset_slug:\n", + " values = {\"input\": \"Test input\", \"output\": \"Test output\"}\n", + " row = safely_add_dataset_row(dataset_slug, values, name=\"Error Handling Test\")\n", + " \n", + " if row:\n", + " print(f\"Successfully created row: {row.name}\")\n", + "\n", + "# Test with an invalid dataset slug\n", + "print(\"\\nTesting with invalid dataset slug:\")\n", + "invalid_row = safely_add_dataset_row(\"non-existent-dataset\", {\"input\": \"Test input\"})\n", + "print(f\"Result with invalid slug: {invalid_row}\")\n", + "\n", + "# Test with missing required values\n", + "if 'dataset' in locals() and dataset and len(dataset.columns) > 0:\n", + " print(\"\\nTesting with missing required values:\")\n", + " # Deliberately create incomplete values dict\n", + " incomplete_values = {column: \"value\" for column in list(dataset.columns)[1:]} if len(dataset.columns) > 1 else {}\n", + " incomplete_row = safely_add_dataset_row(dataset.slug, incomplete_values)\n", + " print(f\"Result with incomplete values: {incomplete_row}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/examples/monitor_sdk_async_demo.ipynb b/examples/monitor_sdk_async_demo.ipynb new file mode 100644 index 0000000..1c756af --- /dev/null +++ b/examples/monitor_sdk_async_demo.ipynb @@ -0,0 +1,267 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt MonitorSDK Async Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the MonitorSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from basalt import Basalt\n", + "import os\n", + "from basalt.ressources.monitor.monitorsdk_types import (\n", + " ExperimentParams, TraceParams, GenerationParams, LogParams\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Initialize the SDK with your API key. If you have the BASALT_API_KEY environment variable set, it will use that; otherwise, replace \"your-api-key\" with your actual API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the SDK with your API key\n", + "api_key = os.environ.get(\"BASALT_API_KEY\", \"your-api-key\")\n", + "\n", + "# Create a Basalt client\n", + "basalt = Basalt(api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Asynchronously Create a Trace\n", + "\n", + "This example demonstrates how to create a trace asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_trace():\n", + " print(\"Creating a trace asynchronously...\")\n", + " trace_params = TraceParams(\n", + " name=\"Async Test Trace\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " trace = await basalt.monitor.async_create_trace(\n", + " slug=\"async-test-trace\",\n", + " params=trace_params\n", + " )\n", + " \n", + " print(f\"Created trace: {trace.id}\")\n", + " print(f\"Trace name: {trace.name}\")\n", + " print(f\"Trace slug: {trace.slug}\")\n", + " \n", + " return trace\n", + "\n", + "# Run the async function\n", + "trace = await create_trace()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Asynchronously Create a Generation\n", + "\n", + "This example demonstrates how to create a generation associated with a trace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_generation(trace):\n", + " print(\"\\nCreating a generation asynchronously...\")\n", + " gen_params = GenerationParams(\n", + " trace_id=trace.id,\n", + " text=\"This is an async test generation\",\n", + " model_id=\"gpt-4\",\n", + " prompt=\"Generate a response asynchronously\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " generation = await basalt.monitor.async_create_generation(gen_params)\n", + " \n", + " print(f\"Created generation: {generation.id}\")\n", + " print(f\"Generation text: {generation.text}\")\n", + " print(f\"Generation model: {generation.model_id}\")\n", + " \n", + " return generation\n", + "\n", + "# Run the async function\n", + "generation = await create_generation(trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Asynchronously Create a Log\n", + "\n", + "This example demonstrates how to create a log entry associated with a trace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_log(trace):\n", + " print(\"\\nCreating a log asynchronously...\")\n", + " log_params = LogParams(\n", + " trace_id=trace.id,\n", + " type=\"info\",\n", + " message=\"This is an async test log message\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " log = await basalt.monitor.async_create_log(log_params)\n", + " \n", + " print(f\"Created log: {log['id']}\")\n", + " print(f\"Log message: {log['message']}\")\n", + " print(f\"Log type: {log['type']}\")\n", + " \n", + " return log\n", + "\n", + "# Run the async function\n", + "log = await create_log(trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Asynchronously Create an Experiment\n", + "\n", + "This example demonstrates how to create an experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_experiment():\n", + " print(\"\\nCreating an experiment asynchronously...\")\n", + " exp_params = ExperimentParams(\n", + " type=\"A/B Test\",\n", + " name=\"Async Test Experiment\",\n", + " setup={\n", + " \"control_id\": \"control-prompt\",\n", + " \"variation_id\": \"test-prompt\"\n", + " }\n", + " )\n", + " \n", + " experiment = await basalt.monitor.async_create_experiment(\n", + " \"async-test-feature\",\n", + " exp_params\n", + " )\n", + " \n", + " print(f\"Created experiment: {experiment.id}\")\n", + " print(f\"Experiment name: {experiment.name}\")\n", + " print(f\"Experiment type: {experiment.type}\")\n", + " \n", + " return experiment\n", + "\n", + "# Run the async function\n", + "experiment = await create_experiment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Execute Multiple Monitoring Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous monitoring operations concurrently for better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def execute_concurrent_operations():\n", + " print(\"\\nExecuting multiple monitoring operations concurrently...\")\n", + " \n", + " # Create trace parameters for concurrent operations\n", + " trace_params1 = TraceParams(\n", + " name=\"Concurrent Trace 1\",\n", + " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 1}\n", + " )\n", + " \n", + " trace_params2 = TraceParams(\n", + " name=\"Concurrent Trace 2\",\n", + " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 2}\n", + " )\n", + " \n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.monitor.async_create_trace(\"concurrent-trace-1\", trace_params1),\n", + " basalt.monitor.async_create_trace(\"concurrent-trace-2\", trace_params2)\n", + " ]\n", + " \n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + " \n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"First trace: {results[0].name} (id: {results[0].id})\")\n", + " print(f\"Second trace: {results[1].name} (id: {results[1].id})\")\n", + "\n", + "# Run the async function\n", + "await execute_concurrent_operations()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/prompt_sdk_async_demo.ipynb b/examples/prompt_sdk_async_demo.ipynb new file mode 100644 index 0000000..ca996ce --- /dev/null +++ b/examples/prompt_sdk_async_demo.ipynb @@ -0,0 +1,241 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt PromptSDK Async Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the PromptSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from basalt import Basalt\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Initialize the SDK with your API key. If you have the BASALT_API_KEY environment variable set, it will use that; otherwise, replace \"your-api-key\" with your actual API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the SDK with your API key\n", + "api_key = os.environ.get(\"BASALT_API_KEY\", \"your-api-key\")\n", + "\n", + "# Create a Basalt client\n", + "basalt = Basalt(api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Asynchronously List All Prompts\n", + "\n", + "This example demonstrates how to list all prompts asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def list_prompts():\n", + " print(\"Listing all prompts asynchronously...\")\n", + " err, prompts = await basalt.prompt.async_list()\n", + " if err:\n", + " print(f\"Error listing prompts: {err}\")\n", + " else:\n", + " print(f\"Found {len(prompts)} prompts\")\n", + " for prompt in prompts:\n", + " print(f\"- {prompt.name} (slug: {prompt.slug})\")\n", + " return prompts\n", + "\n", + "# Run the async function\n", + "prompts = await list_prompts()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Asynchronously Get a Specific Prompt\n", + "\n", + "This example demonstrates how to retrieve a specific prompt by its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_prompt(prompts):\n", + " print(\"\\nGetting a specific prompt asynchronously...\")\n", + " if len(prompts) > 0:\n", + " sample_prompt = prompts[0]\n", + " err, prompt_response, generation = await basalt.prompt.async_get(sample_prompt.slug)\n", + " if err:\n", + " print(f\"Error getting prompt: {err}\")\n", + " else:\n", + " print(f\"Retrieved prompt: {sample_prompt.name}\")\n", + " print(f\"Text: {prompt_response.text}\")\n", + " return sample_prompt, prompt_response, generation\n", + " else:\n", + " print(\"No prompts available\")\n", + " return None, None, None\n", + "\n", + "# Run the async function\n", + "sample_prompt, prompt_response, generation = await get_prompt(prompts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Asynchronously Describe a Prompt\n", + "\n", + "This example demonstrates how to get detailed description information about a prompt." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def describe_prompt(sample_prompt):\n", + " print(\"\\nDescribing a prompt asynchronously...\")\n", + " if sample_prompt:\n", + " err, description = await basalt.prompt.async_describe(sample_prompt.slug)\n", + " if err:\n", + " print(f\"Error describing prompt: {err}\")\n", + " else:\n", + " print(f\"Prompt: {description.prompt.name}\")\n", + " print(f\"Versions available: {len(description.versions)}\")\n", + " for version in description.versions:\n", + " print(f\"- Version {version.version} created at {version.created_at}\")\n", + " return description\n", + " else:\n", + " print(\"No sample prompt available\")\n", + " return None\n", + "\n", + "# Run the async function\n", + "description = await describe_prompt(sample_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Asynchronously Get a Prompt with Variable Substitution\n", + "\n", + "This example demonstrates how to retrieve a prompt with variables substituted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_prompt_with_variables(sample_prompt):\n", + " print(\"\\nGetting a prompt with variable substitution asynchronously...\")\n", + " if sample_prompt:\n", + " err, prompt_response, generation = await basalt.prompt.async_get(\n", + " sample_prompt.slug,\n", + " variables={\"name\": \"John\", \"company\": \"Acme Inc\"}\n", + " )\n", + " if err:\n", + " print(f\"Error getting prompt with variables: {err}\")\n", + " else:\n", + " print(f\"Retrieved prompt with variables: {sample_prompt.name}\")\n", + " print(f\"Text with variables: {prompt_response.text}\")\n", + " return prompt_response, generation\n", + " else:\n", + " print(\"No sample prompt available\")\n", + " return None, None\n", + "\n", + "# Run the async function\n", + "prompt_with_vars, generation_with_vars = await get_prompt_with_variables(sample_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Execute Multiple Prompt Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous operations concurrently for better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def execute_concurrent_operations(prompts):\n", + " print(\"\\nExecuting multiple prompt operations concurrently...\")\n", + " if len(prompts) >= 2:\n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.prompt.async_get(prompts[0].slug),\n", + " basalt.prompt.async_get(prompts[1].slug),\n", + " basalt.prompt.async_list()\n", + " ]\n", + " \n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + " \n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"First prompt: {results[0][1].slug if results[0][1] else 'Error'}\")\n", + " print(f\"Second prompt: {results[1][1].slug if results[1][1] else 'Error'}\")\n", + " print(f\"Number of prompts listed: {len(results[2][1]) if results[2][1] else 'Error'}\")\n", + " else:\n", + " print(\"Not enough prompts available for concurrent operations example\")\n", + "\n", + "# Run the async function\n", + "await execute_concurrent_operations(prompts)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/setup.py b/setup.py index 83d073f..eb0a52a 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ def get_version(): packages=find_packages(), install_requires=[ "requests>=2.32", + "aiohttp>=3.8.0", ], python_requires=">=3.6" ) diff --git a/tests/test_datasetsdk_async.py b/tests/test_datasetsdk_async.py new file mode 100644 index 0000000..bd1fc2d --- /dev/null +++ b/tests/test_datasetsdk_async.py @@ -0,0 +1,203 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from basalt.sdk.datasetsdk import DatasetSDK +from basalt.utils.logger import Logger +from basalt.utils.dtos import DatasetDTO, DatasetRowDTO, ListDatasetsDTO, GetDatasetDTO, CreateDatasetItemDTO +from basalt.endpoints.list_datasets import ListDatasetsEndpoint, ListDatasetsEndpointResponse +from basalt.endpoints.get_dataset import GetDatasetEndpoint, GetDatasetEndpointResponse +from basalt.endpoints.create_dataset_item import CreateDatasetItemEndpoint, CreateDatasetItemEndpointResponse + +logger = Logger() +mocked_api = MagicMock() +# Make sure async_invoke is an AsyncMock +mocked_api.async_invoke = AsyncMock() + +# Mock responses for different endpoints - same as in test_datasetsdk.py +dataset_list_response = ListDatasetsEndpointResponse( + datasets=[ + DatasetDTO( + slug="test-dataset", + name="Test Dataset", + columns=["input", "output"] + ), + DatasetDTO( + slug="another-dataset", + name="Another Dataset", + columns=["col1", "col2", "col3"] + ) + ] +) + +dataset_get_response = GetDatasetEndpointResponse( + dataset=DatasetDTO( + slug="test-dataset", + name="Test Dataset", + columns=["input", "output"], + rows=[ + { + "values": { + "input": "Sample input", + "output": "Sample output" + }, + "name": "Sample Row", + "idealOutput": "Ideal output", + "metadata": {"source": "test"} + } + ] + ), + error=None +) + +dataset_add_row_response = CreateDatasetItemEndpointResponse( + datasetRow=DatasetRowDTO( + values={"input": "New input", "output": "New output"}, + name="New Row", + idealOutput="New ideal output", + metadata={"source": "test"} + ), + warning=None, + error=None +) + + +class TestDatasetSDKAsync(unittest.TestCase): + def setUp(self): + self.dataset_sdk = DatasetSDK( + api=mocked_api, + logger=logger + ) + # Reset mock calls before each test + mocked_api.async_invoke.reset_mock() + + async def test_async_list_datasets(self): + """Test asynchronously listing all datasets""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_list_response) + + # Call the method + err, datasets = await self.dataset_sdk.async_list() + + # Assertions + self.assertIsNone(err) + self.assertEqual(len(datasets), 2) + self.assertEqual(datasets[0].slug, "test-dataset") + self.assertEqual(datasets[0].name, "Test Dataset") + self.assertEqual(datasets[1].slug, "another-dataset") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, ListDatasetsEndpoint) + + async def test_async_get_dataset(self): + """Test asynchronously getting a dataset by slug""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_get_response) + + # Call the method + err, dataset = await self.dataset_sdk.async_get("test-dataset") + + # Assertions + self.assertIsNone(err) + self.assertEqual(dataset.slug, "test-dataset") + self.assertEqual(dataset.name, "Test Dataset") + self.assertEqual(len(dataset.columns), 2) + self.assertEqual(len(dataset.rows), 1) + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, GetDatasetEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-dataset") + + async def test_async_create_dataset_item(self): + """Test asynchronously creating a dataset item""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_add_row_response) + + # Call the method + values = {"input": "New input", "output": "New output"} + err, row, warning = await self.dataset_sdk.async_addRow( + slug="test-dataset", + values=values, + name="New Row", + ideal_output="New ideal output", + metadata={"source": "test"} + ) + + # Assertions + self.assertIsNone(err) + self.assertIsNone(warning) + self.assertEqual(row.values, values) + self.assertEqual(row.name, "New Row") + self.assertEqual(row.idealOutput, "New ideal output") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateDatasetItemEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-dataset") + self.assertEqual(dto.values, values) + self.assertEqual(dto.name, "New Row") + self.assertEqual(dto.idealOutput, "New ideal output") + + async def test_async_error_handling_get_dataset(self): + """Test error handling when asynchronously getting a dataset""" + # Configure mock to return an error + error = Exception("API Error") + mocked_api.async_invoke.return_value = (error, None) + + # Call the method + err, dataset = await self.dataset_sdk.async_get("non-existent") + + # Assertions + self.assertIsNotNone(err) + self.assertIsNone(dataset) + self.assertEqual(str(err), "API Error") + + + +class AsyncTestRunner: + """Helper class to run async tests properly""" + + def __init__(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def run_test_case(self, test_case_class): + """Run all async test methods in a test case""" + suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) + + for test in suite: + test_method = getattr(test, test._testMethodName) + if asyncio.iscoroutinefunction(test_method): + try: + test.setUp() + self.loop.run_until_complete(test_method()) + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + else: + # Run sync tests normally + try: + test.setUp() + test_method() + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + + def close(self): + self.loop.close() + + +if __name__ == "__main__": + runner = AsyncTestRunner() + try: + runner.run_test_case(TestDatasetSDKAsync) + finally: + runner.close() diff --git a/tests/test_monitorsdk_async.py b/tests/test_monitorsdk_async.py new file mode 100644 index 0000000..ec117ee --- /dev/null +++ b/tests/test_monitorsdk_async.py @@ -0,0 +1,182 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from basalt.sdk.monitorsdk import MonitorSDK +from basalt.utils.logger import Logger +from basalt.ressources.monitor.experiment_types import ExperimentParams +from basalt.ressources.monitor.trace_types import TraceParams +from basalt.ressources.monitor.generation_types import GenerationParams +from basalt.ressources.monitor.log_types import LogParams +from basalt.objects.experiment import Experiment +from basalt.objects.trace import Trace +from basalt.objects.generation import Generation +from basalt.objects.log import Log +from basalt.endpoints.monitor.create_experiment import CreateExperimentEndpoint, Output as ExperimentOutput + +logger = Logger() +mocked_api = MagicMock() +# Make sure async_invoke is an AsyncMock +mocked_api.async_invoke = AsyncMock() + +# Mock experiment data that matches the actual Experiment structure +experiment_data = { + "id": "exp-123", + "featureSlug": "test-feature", + "name": "Test Experiment", + "createdAt": "2023-01-01T00:00:00Z" +} + +experiment_output = ExperimentOutput( + experiment=type('Experiment', (), experiment_data)() +) + +# Set experiment attributes properly +experiment_output.experiment.id = experiment_data["id"] +experiment_output.experiment.feature_slug = experiment_data["featureSlug"] +experiment_output.experiment.name = experiment_data["name"] +experiment_output.experiment.created_at = experiment_data["createdAt"] + + +class TestMonitorSDKAsync(unittest.TestCase): + def setUp(self): + self.monitor_sdk = MonitorSDK( + api=mocked_api, + logger=logger + ) + # Reset mock calls before each test + mocked_api.async_invoke.reset_mock() + + async def test_async_create_experiment(self): + """Test asynchronously creating an experiment""" + # Configure mock + mocked_api.async_invoke.return_value = (None, experiment_output) + + # Call the method + params = {"name": "Test Experiment"} + err, result = await self.monitor_sdk.async_create_experiment("test-feature", params) + + # Assertions + self.assertIsNone(err) + self.assertIsNotNone(result) + self.assertEqual(result.id, "exp-123") + self.assertEqual(result.feature_slug, "test-feature") + self.assertEqual(result.name, "Test Experiment") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateExperimentEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.feature_slug, "test-feature") + self.assertEqual(dto.name, "Test Experiment") + + async def test_async_create_trace(self): + """Test asynchronously creating a trace""" + # Call the method - traces are created directly without API calls + params = { + "name": "Test Trace", + "metadata": {"source": "test"} + } + result = await self.monitor_sdk.async_create_trace("test-trace", params) + + # Assertions + self.assertIsNotNone(result) + self.assertIsInstance(result, Trace) + self.assertEqual(result.name, "Test Trace") + self.assertEqual(result.feature_slug, "test-trace") + self.assertEqual(result.metadata, {"source": "test"}) + + async def test_async_create_generation(self): + """Test asynchronously creating a generation""" + # Call the method - generations are created directly without API calls + params = { + "name": "Test Generation", + "input": "Test input", + "metadata": {"source": "test"} + } + result = await self.monitor_sdk.async_create_generation(params) + + # Assertions + self.assertIsNotNone(result) + self.assertIsInstance(result, Generation) + self.assertEqual(result.input, "Test input") + self.assertEqual(result.name, "Test Generation") + self.assertEqual(result.metadata, {"source": "test"}) + # Options will be None since not provided + self.assertIsNone(result.options) + + async def test_async_create_log(self): + """Test asynchronously creating a log""" + # Call the method - logs are created directly without API calls + params = { + "name": "Test Log", + "input": "Test log input", + "metadata": {"source": "test"} + } + result = await self.monitor_sdk.async_create_log(params) + + # Assertions + self.assertIsNotNone(result) + self.assertIsInstance(result, Log) + self.assertEqual(result.name, "Test Log") + self.assertEqual(result.input, "Test log input") + self.assertEqual(result.metadata, {"source": "test"}) + + async def test_async_error_handling_create_experiment(self): + """Test error handling when asynchronously creating an experiment""" + # Configure mock to return an error + error = Exception("API Error") + mocked_api.async_invoke.return_value = (error, None) + + # Call the method + params = {"name": "Test Experiment"} + + err, result = await self.monitor_sdk.async_create_experiment("test-feature", params) + + # Assertions + self.assertIsNotNone(err) + self.assertIsNone(result) + self.assertEqual(str(err), "API Error") + + +class AsyncTestRunner: + """Helper class to run async tests properly""" + + def __init__(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def run_test_case(self, test_case_class): + """Run all async test methods in a test case""" + suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) + + for test in suite: + test_method = getattr(test, test._testMethodName) + if asyncio.iscoroutinefunction(test_method): + try: + test.setUp() + self.loop.run_until_complete(test_method()) + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + else: + # Run sync tests normally + try: + test.setUp() + test_method() + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + + def close(self): + self.loop.close() + + +if __name__ == "__main__": + runner = AsyncTestRunner() + try: + runner.run_test_case(TestMonitorSDKAsync) + finally: + runner.close() diff --git a/tests/test_promptsdk_async.py b/tests/test_promptsdk_async.py new file mode 100644 index 0000000..4ecb96e --- /dev/null +++ b/tests/test_promptsdk_async.py @@ -0,0 +1,243 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from basalt.sdk.promptsdk import PromptSDK +from basalt.utils.logger import Logger +from basalt.utils.dtos import GetPromptDTO, PromptResponse, DescribePromptDTO, DescribePromptResponse, PromptListDTO, PromptListResponse +from basalt.endpoints.get_prompt import GetPromptEndpoint, GetPromptEndpointResponse +from basalt.endpoints.list_prompts import ListPromptsEndpoint, ListPromptsEndpointResponse +from basalt.endpoints.describe_prompt import DescribePromptEndpoint, DescribePromptEndpointResponse + +logger = Logger() +mocked_api = MagicMock() +# Make sure async_invoke is an AsyncMock +mocked_api.async_invoke = AsyncMock() + +# Mock model for PromptResponse +mock_model = type('Model', (), { + 'provider': 'openai', + 'model': 'gpt-4', + 'version': 'latest', + 'parameters': type('Params', (), { + 'temperature': 0.7, + 'max_length': 100, + 'top_p': 1.0, + 'top_k': None, + 'frequency_penalty': None, + 'presence_penalty': None, + 'response_format': 'text', + 'json_object': None + })() +})() + +# Mock responses for different endpoints +prompt_get_response = GetPromptEndpointResponse( + warning=None, + prompt=PromptResponse( + text="This is a test prompt: {{variable}}", + systemText="You are a helpful assistant", + version="1.0", + model=mock_model + ) +) + +prompt_list_response = ListPromptsEndpointResponse( + warning=None, + prompts=[ + PromptListResponse( + slug="test-prompt-1", + status="active", + name="Test Prompt 1", + description="First test prompt", + available_versions=["1.0"], + available_tags=["latest"] + ), + PromptListResponse( + slug="test-prompt-2", + status="active", + name="Test Prompt 2", + description="Second test prompt", + available_versions=["1.0", "2.0"], + available_tags=["latest", "stable"] + ) + ] +) + +prompt_describe_response = DescribePromptEndpointResponse( + warning=None, + prompt=DescribePromptResponse( + slug="test-prompt", + status="active", + name="Test Prompt", + description="A test prompt for unit testing", + available_versions=["1.0", "1.1", "2.0"], + available_tags=["latest", "stable"], + variables=[{"name": "variable", "type": "string"}] + ) +) + + +class TestPromptSDKAsync(unittest.TestCase): + def setUp(self): + from basalt.utils.memcache import MemoryCache + self.prompt_sdk = PromptSDK( + api=mocked_api, + cache=MemoryCache(), + fallback_cache=MemoryCache(), + logger=logger + ) + # Reset mock calls before each test + mocked_api.async_invoke.reset_mock() + + async def test_async_get_prompt(self): + """Test asynchronously getting a prompt""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_get_response) + + # Call the method + err, prompt_response, generation = await self.prompt_sdk.async_get("test-prompt") + + # Assertions + self.assertIsNone(err) + self.assertIsNotNone(prompt_response) + self.assertEqual(prompt_response.text, "This is a test prompt: {{variable}}") + self.assertEqual(prompt_response.version, "1.0") + self.assertIsNotNone(generation) # Generation is created for monitoring + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, GetPromptEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-prompt") + + async def test_async_list_prompts(self): + """Test asynchronously listing prompts""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_list_response) + + # Call the method + err, prompts = await self.prompt_sdk.async_list() + + # Assertions + self.assertIsNone(err) + self.assertIsNotNone(prompts) + self.assertEqual(len(prompts), 2) + self.assertEqual(prompts[0].slug, "test-prompt-1") + self.assertEqual(prompts[1].slug, "test-prompt-2") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, ListPromptsEndpoint) + + async def test_async_list_prompts_with_feature_filter(self): + """Test asynchronously listing prompts with feature filter""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_list_response) + + # Call the method + err, prompts = await self.prompt_sdk.async_list(feature_slug="test-feature") + + # Assertions + self.assertIsNone(err) + self.assertIsNotNone(prompts) + self.assertEqual(len(prompts), 2) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.featureSlug, "test-feature") + + async def test_async_describe_prompt(self): + """Test asynchronously describing a prompt""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_describe_response) + + # Call the method + err, prompt_description = await self.prompt_sdk.async_describe("test-prompt") + + # Assertions + self.assertIsNone(err) + self.assertIsNotNone(prompt_description) + self.assertEqual(prompt_description.slug, "test-prompt") + self.assertEqual(prompt_description.name, "Test Prompt") + self.assertEqual(len(prompt_description.available_versions), 3) + self.assertIn("1.0", prompt_description.available_versions) + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, DescribePromptEndpoint) + + async def test_async_get_prompt_with_variables(self): + """Test asynchronously getting a prompt with variables replaced""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_get_response) + + # Call the method + err, prompt_response, generation = await self.prompt_sdk.async_get( + "test-prompt", + variables={"variable": "test-value"} + ) + + # Assertions + self.assertIsNone(err) + self.assertIsNotNone(prompt_response) + # Variables should be replaced in the response + self.assertEqual(prompt_response.text, "This is a test prompt: test-value") + + async def test_async_get_prompt_error_handling(self): + """Test error handling when asynchronously getting a prompt""" + # Configure mock to return an error + error = Exception("API Error") + mocked_api.async_invoke.return_value = (error, None) + + # Call the method + err, prompt_response, generation = await self.prompt_sdk.async_get("non-existent") + + # Assertions + self.assertIsNotNone(err) + self.assertIsNone(prompt_response) + self.assertIsNone(generation) + self.assertEqual(str(err), "API Error") + + +class AsyncTestRunner: + """Helper class to run async tests properly""" + + def __init__(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def run_test_case(self, test_case_class): + """Run all async test methods in a test case""" + suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) + + for test in suite: + test_method = getattr(test, test._testMethodName) + if asyncio.iscoroutinefunction(test_method): + try: + test.setUp() + self.loop.run_until_complete(test_method()) + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + else: + # Run sync tests normally + try: + test.setUp() + test_method() + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + + def close(self): + self.loop.close() + + +if __name__ == "__main__": + runner = AsyncTestRunner() + try: + runner.run_test_case(TestPromptSDKAsync) + finally: + runner.close() \ No newline at end of file