From 038ed5fd78a903c38067dfa432f5ee969f631d1a Mon Sep 17 00:00:00 2001 From: Theophile Cousin Date: Fri, 27 Jun 2025 13:43:29 +0200 Subject: [PATCH 1/5] implement datasets sdk --- basalt/objects/dataset.py | 12 ++- examples/dataset_example.py | 157 ++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 7 deletions(-) create mode 100644 examples/dataset_example.py diff --git a/basalt/objects/dataset.py b/basalt/objects/dataset.py index 51cd679..d502608 100644 --- a/basalt/objects/dataset.py +++ b/basalt/objects/dataset.py @@ -20,16 +20,14 @@ def to_dict(self) -> Dict[str, Any]: """Convert the DatasetRow to a dictionary for API requests""" result = { "values": self.values, - "metadata": self.metadata, - "name": self.name, - "idealOutput": self.ideal_output + "metadata": self.metadata } - # if self.name: - # result["name"] = self.name + if self.name: + result["name"] = self.name - # if self.ideal_output: - # result["idealOutput"] = self.ideal_output + if self.ideal_output: + result["idealOutput"] = self.ideal_output return result diff --git a/examples/dataset_example.py b/examples/dataset_example.py new file mode 100644 index 0000000..bc80ab0 --- /dev/null +++ b/examples/dataset_example.py @@ -0,0 +1,157 @@ +""" +Dataset Example - Demonstrates how to use the Basalt Dataset SDK + +This example shows how to: +1. List all available datasets +2. Get a specific dataset with its rows +3. Create a new dataset item +4. Work with dataset objects + +Make sure to set your API key as an environment variable: +export BASALT_API_KEY=your_api_key +""" + +import os +import sys +from typing import Dict, List, Optional + +from basalt import Basalt + +# Set up the Basalt client +api_key = os.getenv("BASALT_API_KEY") +if not api_key: + print("Please set your BASALT_API_KEY environment variable") + sys.exit(1) + +basalt = Basalt(api_key=api_key) + +def list_datasets(): + """List all datasets in the workspace""" + print("\n=== Listing all datasets ===") + err, datasets = basalt.datasets.list() + + if err: + print(f"Error listing datasets: {err}") + return None + + print(f"Found {len(datasets)} datasets:") + for dataset in datasets: + print(f"- {dataset.name} (slug: {dataset.slug})") + print(f" Columns: {', '.join(dataset.columns)}") + + # Return the first dataset slug for other examples to use + return datasets[0].slug if datasets else None + +def get_dataset(slug: str): + """Get a specific dataset by slug""" + print(f"\n=== Getting dataset: {slug} ===") + err, dataset = basalt.datasets.get(slug) + + if err: + print(f"Error getting dataset: {err}") + return + + print(f"Dataset: {dataset.name}") + print(f"Columns: {', '.join(dataset.columns)}") + + if dataset.rows: + print(f"Number of rows: {len(dataset.rows)}") + print("First few rows:") + for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows + print(f"Row {i+1}:") + print(f" Values: {row.values}") + if row.idealOutput: + print(f" Ideal output: {row.idealOutput}") + if row.metadata: + print(f" Metadata: {row.metadata}") + else: + print("Dataset has no rows") + +def add_row_to_dataset(slug: str): + """Create a new dataset item""" + print(f"\n=== Creating new dataset item in {slug} ===") + + # Get the dataset to understand its structure + err, dataset = basalt.datasets.get(slug) + if err or not dataset: + print(f"Error getting dataset structure: {err}") + return + + # Create values for all columns + values = {} + for column in dataset.columns: + values[column] = f"Example value for {column}" + + # Create the item + err, row, warning = basalt.datasets.addRow( + slug=slug, + values=values, + name="Example Row", + ideal_output="Example ideal output", + metadata={"source": "Python SDK example", "timestamp": "2025-06-27"} + ) + + if err: + print(f"Error creating dataset item: {err}") + return + + if warning: + print(f"Warning: {warning}") + + print("Dataset item created successfully:") + print(f"Values: {row.values}") + print(f"Name: {row.name}") + print(f"Ideal output: {row.idealOutput}") + print(f"Metadata: {row.metadata}") + +def work_with_dataset_objects(slug: str): + """Demonstrate working with Dataset objects""" + print(f"\n=== Working with dataset objects for {slug} ===") + + dataset = basalt.datasets.get_dataset_object(slug) + if not dataset: + print("Failed to get dataset object") + return + + print(f"Dataset object: {dataset.name} with {len(dataset.columns)} columns") + + # Add a new row to the dataset object + values = {column: f"New object value for {column}" for column in dataset.columns} + + row = basalt.datasets.add_row( + dataset=dataset, + values=values, + name="Object Example Row", + ideal_output="Object example ideal output", + metadata={"created_by": "dataset_object_example"} + ) + + if row: + print("Added new row to dataset object:") + print(f"Values: {row.values}") + print(f"Name: {row.name}") + print(f"Dataset now has {len(dataset.rows)} rows") + else: + print("Failed to add row to dataset object") + +def main(): + """Main function to run the example""" + print("Basalt Dataset Example") + + # List all datasets and get the first one's slug + dataset_slug = list_datasets() + if not dataset_slug: + print("No datasets available to continue the example") + return + + # Get details for the selected dataset + get_dataset(dataset_slug) + + # Create a new item in the dataset + create_dataset_item(dataset_slug) + + # Work with dataset objects + work_with_dataset_objects(dataset_slug) + +if __name__ == "__main__": + main() From e56b1a8ae6614ef5646a2baf2f53a27be0dd5e59 Mon Sep 17 00:00:00 2001 From: Theophile Cousin Date: Fri, 27 Jun 2025 14:23:24 +0200 Subject: [PATCH 2/5] rewrite example as notebook jupyter --- examples/dataset_example.py | 157 -------------------------------- examples/dataset_sdk_demo.ipynb | 26 ++++++ 2 files changed, 26 insertions(+), 157 deletions(-) delete mode 100644 examples/dataset_example.py diff --git a/examples/dataset_example.py b/examples/dataset_example.py deleted file mode 100644 index bc80ab0..0000000 --- a/examples/dataset_example.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Dataset Example - Demonstrates how to use the Basalt Dataset SDK - -This example shows how to: -1. List all available datasets -2. Get a specific dataset with its rows -3. Create a new dataset item -4. Work with dataset objects - -Make sure to set your API key as an environment variable: -export BASALT_API_KEY=your_api_key -""" - -import os -import sys -from typing import Dict, List, Optional - -from basalt import Basalt - -# Set up the Basalt client -api_key = os.getenv("BASALT_API_KEY") -if not api_key: - print("Please set your BASALT_API_KEY environment variable") - sys.exit(1) - -basalt = Basalt(api_key=api_key) - -def list_datasets(): - """List all datasets in the workspace""" - print("\n=== Listing all datasets ===") - err, datasets = basalt.datasets.list() - - if err: - print(f"Error listing datasets: {err}") - return None - - print(f"Found {len(datasets)} datasets:") - for dataset in datasets: - print(f"- {dataset.name} (slug: {dataset.slug})") - print(f" Columns: {', '.join(dataset.columns)}") - - # Return the first dataset slug for other examples to use - return datasets[0].slug if datasets else None - -def get_dataset(slug: str): - """Get a specific dataset by slug""" - print(f"\n=== Getting dataset: {slug} ===") - err, dataset = basalt.datasets.get(slug) - - if err: - print(f"Error getting dataset: {err}") - return - - print(f"Dataset: {dataset.name}") - print(f"Columns: {', '.join(dataset.columns)}") - - if dataset.rows: - print(f"Number of rows: {len(dataset.rows)}") - print("First few rows:") - for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows - print(f"Row {i+1}:") - print(f" Values: {row.values}") - if row.idealOutput: - print(f" Ideal output: {row.idealOutput}") - if row.metadata: - print(f" Metadata: {row.metadata}") - else: - print("Dataset has no rows") - -def add_row_to_dataset(slug: str): - """Create a new dataset item""" - print(f"\n=== Creating new dataset item in {slug} ===") - - # Get the dataset to understand its structure - err, dataset = basalt.datasets.get(slug) - if err or not dataset: - print(f"Error getting dataset structure: {err}") - return - - # Create values for all columns - values = {} - for column in dataset.columns: - values[column] = f"Example value for {column}" - - # Create the item - err, row, warning = basalt.datasets.addRow( - slug=slug, - values=values, - name="Example Row", - ideal_output="Example ideal output", - metadata={"source": "Python SDK example", "timestamp": "2025-06-27"} - ) - - if err: - print(f"Error creating dataset item: {err}") - return - - if warning: - print(f"Warning: {warning}") - - print("Dataset item created successfully:") - print(f"Values: {row.values}") - print(f"Name: {row.name}") - print(f"Ideal output: {row.idealOutput}") - print(f"Metadata: {row.metadata}") - -def work_with_dataset_objects(slug: str): - """Demonstrate working with Dataset objects""" - print(f"\n=== Working with dataset objects for {slug} ===") - - dataset = basalt.datasets.get_dataset_object(slug) - if not dataset: - print("Failed to get dataset object") - return - - print(f"Dataset object: {dataset.name} with {len(dataset.columns)} columns") - - # Add a new row to the dataset object - values = {column: f"New object value for {column}" for column in dataset.columns} - - row = basalt.datasets.add_row( - dataset=dataset, - values=values, - name="Object Example Row", - ideal_output="Object example ideal output", - metadata={"created_by": "dataset_object_example"} - ) - - if row: - print("Added new row to dataset object:") - print(f"Values: {row.values}") - print(f"Name: {row.name}") - print(f"Dataset now has {len(dataset.rows)} rows") - else: - print("Failed to add row to dataset object") - -def main(): - """Main function to run the example""" - print("Basalt Dataset Example") - - # List all datasets and get the first one's slug - dataset_slug = list_datasets() - if not dataset_slug: - print("No datasets available to continue the example") - return - - # Get details for the selected dataset - get_dataset(dataset_slug) - - # Create a new item in the dataset - create_dataset_item(dataset_slug) - - # Work with dataset objects - work_with_dataset_objects(dataset_slug) - -if __name__ == "__main__": - main() diff --git a/examples/dataset_sdk_demo.ipynb b/examples/dataset_sdk_demo.ipynb index 29c4fb4..26daf7e 100644 --- a/examples/dataset_sdk_demo.ipynb +++ b/examples/dataset_sdk_demo.ipynb @@ -25,7 +25,11 @@ "\n", "# Initialize the SDK\n", "basalt = Basalt(\n", +<<<<<<< HEAD " api_key=\"sk-df55a...\", # Replace with your API key\n", +======= + " api_key=\"sk-...\", # Replace with your API key\n", +>>>>>>> 6abcf23 (rewrite example as notebook jupyter) " log_level=\"debug\" # Optional: Set log level\n", ")" ] @@ -54,6 +58,10 @@ " print(f\"Found {len(datasets)} datasets:\")\n", " for i, dataset in enumerate(datasets):\n", " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", +<<<<<<< HEAD +======= + " print(f\" - Description: {dataset.description if dataset.description else 'No description'}\")\n", +>>>>>>> 6abcf23 (rewrite example as notebook jupyter) " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", " \n", " # Store the first dataset slug for later use (if available)\n", @@ -85,6 +93,10 @@ "else:\n", " print(f\"Dataset details for '{dataset.name}'\")\n", " print(f\"Slug: {dataset.slug}\")\n", +<<<<<<< HEAD +======= + " print(f\"Description: {dataset.description if dataset.description else 'No description'}\")\n", +>>>>>>> 6abcf23 (rewrite example as notebook jupyter) " print(f\"Columns: {', '.join(dataset.columns)}\")\n", " print(f\"Number of rows: {len(dataset.rows)}\")\n", " \n", @@ -92,6 +104,7 @@ " print(\"\\nSample rows:\")\n", " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", " print(f\"Row {i+1}:\")\n", +<<<<<<< HEAD " print(f\" Values: {row.get('values')}\")\n", " if 'name' in row:\n", " print(f\" Name: {row['name']}\")\n", @@ -99,6 +112,15 @@ " print(f\" Ideal output: {row['idealOutput']}\")\n", " if 'metadata' in row:\n", " print(f\" Metadata: {row['metadata']}\")" +======= + " print(f\" Values: {row.values}\")\n", + " if row.name:\n", + " print(f\" Name: {row.name}\")\n", + " if row.idealOutput:\n", + " print(f\" Ideal output: {row.idealOutput}\")\n", + " if row.metadata:\n", + " print(f\" Metadata: {row.metadata}\")" +>>>>>>> 6abcf23 (rewrite example as notebook jupyter) ] }, { @@ -224,7 +246,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", +<<<<<<< HEAD "version": "3.13.3" +======= + "version": "3.13.2" +>>>>>>> 6abcf23 (rewrite example as notebook jupyter) } }, "nbformat": 4, From 2c50f649edbbc86ee2f9157d1770f6cb821d9cfe Mon Sep 17 00:00:00 2001 From: Theophile Cousin Date: Fri, 27 Jun 2025 15:12:31 +0200 Subject: [PATCH 3/5] add support for async methods --- basalt/basaltsdk.py | 1 + basalt/sdk/datasetsdk.py | 83 ++++++++++++ basalt/sdk/monitorsdk.py | 188 ++++++++++++++++++++++++++ basalt/sdk/promptsdk.py | 174 +++++++++++++++++++++++- basalt/utils/api.py | 36 +++++ basalt/utils/networker.py | 54 ++++++++ setup.py | 1 + tests/test_datasetsdk_async.py | 224 +++++++++++++++++++++++++++++++ tests/test_monitorsdk_async.py | 235 +++++++++++++++++++++++++++++++++ tests/test_promptsdk_async.py | 207 +++++++++++++++++++++++++++++ 10 files changed, 1202 insertions(+), 1 deletion(-) create mode 100644 tests/test_datasetsdk_async.py create mode 100644 tests/test_monitorsdk_async.py create mode 100644 tests/test_promptsdk_async.py diff --git a/basalt/basaltsdk.py b/basalt/basaltsdk.py index 57a451f..04cab7c 100644 --- a/basalt/basaltsdk.py +++ b/basalt/basaltsdk.py @@ -5,6 +5,7 @@ class BasaltSDK(IBasaltSDK): """ The BasaltSDK class implements the IBasaltSDK interface. It serves as the main entry point for interacting with the Basalt SDK. + """ def __init__(self, prompt_sdk: IPromptSDK, monitor_sdk: IMonitorSDK, dataset_sdk: IDatasetSDK): diff --git a/basalt/sdk/datasetsdk.py b/basalt/sdk/datasetsdk.py index c7a6cf6..05b5acc 100644 --- a/basalt/sdk/datasetsdk.py +++ b/basalt/sdk/datasetsdk.py @@ -2,6 +2,7 @@ SDK for interacting with Basalt datasets """ from typing import Dict, List, Optional, Tuple, Any +import asyncio from ..utils.dtos import ( ListDatasetsDTO, GetDatasetDTO, CreateDatasetItemDTO, @@ -46,6 +47,26 @@ def list(self) -> ListDatasetsResult: name=dataset.name, columns=dataset.columns ) for dataset in result.datasets] + + async def async_list(self) -> ListDatasetsResult: + """ + Asynchronously list all datasets available in the workspace. + + Returns: + Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional + exception and an optional list of DatasetDTO objects. + """ + dto = ListDatasetsDTO() + err, result = await self._api.async_invoke(ListDatasetsEndpoint, dto) + + if err is not None: + return err, None + + return None, [DatasetDTO( + slug=dataset.slug, + name=dataset.name, + columns=dataset.columns + ) for dataset in result.datasets] def get(self, slug: str) -> GetDatasetResult: """ @@ -68,6 +89,28 @@ def get(self, slug: str) -> GetDatasetResult: return Exception(result.error), None return None, result.dataset + + async def async_get(self, slug: str) -> GetDatasetResult: + """ + Asynchronously get a dataset by its slug. + + Args: + slug (str): The slug identifier for the dataset. + + Returns: + Tuple[Optional[Exception], Optional[DatasetDTO]]: A tuple containing an optional + exception and an optional DatasetDTO. + """ + dto = GetDatasetDTO(slug=slug) + err, result = await self._api.async_invoke(GetDatasetEndpoint, dto) + + if err is not None: + return err, None + + if result.error: + return Exception(result.error), None + + return None, result.dataset def addRow( self, @@ -108,3 +151,43 @@ def addRow( return Exception(result.error), None, None return None, result.datasetRow, result.warning + + async def async_addRow( + self, + slug: str, + values: Dict[str, str], + name: Optional[str] = None, + ideal_output: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> CreateDatasetItemResult: + """ + Asynchronously create a new item in a dataset. + + Args: + slug (str): The slug identifier for the dataset. + values (Dict[str, str]): A dictionary of column values for the dataset item. + name (Optional[str]): An optional name for the dataset item. + ideal_output (Optional[str]): An optional ideal output for the dataset item. + metadata (Optional[Dict[str, Any]]): An optional metadata dictionary. + + Returns: + Tuple[Optional[Exception], Optional[DatasetRowDTO], Optional[str]]: A tuple containing + an optional exception, an optional DatasetRowDTO, and an optional warning message. + """ + dto = CreateDatasetItemDTO( + slug=slug, + values=values, + name=name, + idealOutput=ideal_output, + metadata=metadata + ) + + err, result = await self._api.async_invoke(CreateDatasetItemEndpoint, dto) + + if err is not None: + return err, None, None + + if result.error: + return Exception(result.error), None, None + + return None, result.datasetRow, result.warning diff --git a/basalt/sdk/monitorsdk.py b/basalt/sdk/monitorsdk.py index 6277296..dd97189 100644 --- a/basalt/sdk/monitorsdk.py +++ b/basalt/sdk/monitorsdk.py @@ -1,4 +1,5 @@ from typing import Dict, Optional, Any, Tuple +import asyncio from ..utils.protocols import IApi, ILogger from ..ressources.monitor.trace_types import TraceParams @@ -40,6 +41,23 @@ def create_experiment( Experiment: A new Experiment instance. """ return self._create_experiment(feature_slug, params) + + async def async_create_experiment( + self, + feature_slug: str, + params: ExperimentParams + ) -> Tuple[Optional[Exception], Optional[Experiment]]: + """ + Asynchronously creates a new experiment for monitoring. + + Args: + feature_slug (str): The feature slug for the experiment. + params (Dict[str, Any]): Parameters for the experiment. + + Returns: + Experiment: A new Experiment instance. + """ + return await self._async_create_experiment(feature_slug, params) def create_trace( @@ -63,6 +81,28 @@ def create_trace( trace_params = TraceParams(**params) return self._create_trace(slug, trace_params) + + async def async_create_trace( + self, + slug: str, + params: Optional[TraceParams] = None + ) -> Trace: + """ + Asynchronously creates a new trace for monitoring. + + Args: + slug (str): The unique identifier for the trace. + params (TraceParams): Parameters for the trace. + + Returns: + Trace: A new Trace instance. + """ + if params is None: + params = {} + + trace_params = TraceParams(**params) + + return await self._async_create_trace(slug, trace_params) def create_generation( self, @@ -79,6 +119,22 @@ def create_generation( """ generation_params = GenerationParams(**params) return self._create_generation(generation_params) + + async def async_create_generation( + self, + params: Dict[str, Any] + ) -> Generation: + """ + Asynchronously creates a new generation for monitoring. + + Args: + params (Dict[str, Any]): Parameters for the generation. + + Returns: + Generation: A new Generation instance. + """ + generation_params = GenerationParams(**params) + return await self._async_create_generation(generation_params) def create_log( self, @@ -95,6 +151,22 @@ def create_log( """ log_params = LogParams(**params) return self._create_log(log_params) + + async def async_create_log( + self, + params: Dict[str, Any] + ) -> Log: + """ + Asynchronously creates a new log for monitoring. + + Args: + params (Dict[str, Any]): Parameters for the log. + + Returns: + Log: A new Log instance. + """ + log_params = LogParams(**params) + return await self._async_create_log(log_params) def _create_experiment( self, @@ -123,6 +195,34 @@ def _create_experiment( return None, Experiment(result.experiment) return err, None + + async def _async_create_experiment( + self, + feature_slug: str, + params: ExperimentParams + ) -> Tuple[Optional[Exception], Optional[Experiment]]: + """ + Internal implementation for asynchronously creating an experiment. + + Args: + feature_slug (str): The feature slug for the experiment. + params (ExperimentParams): Parameters for the experiment. + + Returns: + Experiment: A new Experiment instance. + """ + dto = CreateExperimentDTO( + feature_slug=feature_slug, + name=params.get("name"), + ) + + # Call the API endpoint + err, result = await self._api.async_invoke(CreateExperimentEndpoint, dto) + + if err is None: + return None, Experiment(result.experiment) + + return err, None def _create_trace( @@ -157,6 +257,39 @@ def _create_trace( } trace = Trace(slug, params_dict, flusher, self._logger) return trace + + async def _async_create_trace( + self, + slug: str, + params: TraceParams + ) -> Trace: + """ + Internal implementation for asynchronously creating a trace. + + Args: + slug (str): The unique identifier for the trace. + params (TraceParams): Parameters for the trace. + + Returns: + Trace: A new Trace instance. + """ + flusher = Flusher(self._api, self._logger) + # Convert TraceParams to a dictionary before passing to Trace + params_dict = { + "input": params.input, + "output": params.output, + "name": params.name, + "start_time": params.start_time, + "end_time": params.end_time, + "user": params.user, + "organization": params.organization, + "metadata": params.metadata, + "experiment": params.experiment, + "evaluators": params.evaluators, + "evaluationConfig": params.evaluation_config + } + trace = Trace(slug, params_dict, flusher, self._logger) + return trace def _create_generation( self, @@ -186,6 +319,35 @@ def _create_generation( "options": params.options } return Generation(params_dict) + + async def _async_create_generation( + self, + params: GenerationParams + ) -> Generation: + """ + Internal implementation for asynchronously creating a generation. + + Args: + params (GenerationParams): Parameters for the generation. + + Returns: + Generation: A new Generation instance. + """ + # Convert GenerationParams to a dictionary before passing to Generation + params_dict = { + "name": params.name, + "trace": params.trace, + "prompt": params.prompt, + "input": params.input, + "output": params.output, + "variables": params.variables, + "parent": params.parent, + "metadata": params.metadata, + "start_time": params.start_time, + "end_time": params.end_time, + "options": params.options + } + return Generation(params_dict) def _create_log( self, @@ -194,6 +356,32 @@ def _create_log( """ Internal implementation for creating a log. + Args: + params (LogParams): Parameters for the log. + + Returns: + Log: A new Log instance. + """ + # Convert LogParams to a dictionary before passing to Log + params_dict = { + "name": params.name, + "trace": params.trace, + "input": params.input, + "output": params.output, + "parent": params.parent, + "metadata": params.metadata, + "start_time": params.start_time, + "end_time": params.end_time + } + return Log(params_dict) + + async def _async_create_log( + self, + params: LogParams + ) -> Log: + """ + Internal implementation for asynchronously creating a log. + Args: params (LogParams): Parameters for the log. diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index 85c9e05..f909c4e 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Tuple, Any +from typing import Optional, Dict, Tuple, Any, Awaitable from ..utils.dtos import GetPromptDTO, PromptResponse, DescribePromptResponse, DescribePromptDTO, GetResult, DescribeResult, ListResult, PromptListResponse, PromptListDTO from ..utils.protocols import ICache, IApi, ILogger @@ -11,6 +11,7 @@ from ..objects.generation import Generation from ..utils.flusher import Flusher from datetime import datetime +import asyncio class PromptSDK: """ @@ -87,6 +88,63 @@ def get( return err, prompt_response, generation return err, None, None + + async def async_get( + self, + slug: str, + version: Optional[str] = None, + tag: Optional[str] = None, + variables: Dict[str, str] = {}, + cache_enabled: bool = True + ) -> Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + """ + Asynchronously retrieve a prompt by slug, optionally specifying version and tag. + + Args: + slug (str): The slug identifier for the prompt. + version (Optional[str]): The version of the prompt. + tag (Optional[str]): The tag associated with the prompt. + variables (dict): A dictionnary of variables to replace in the prompt text. + cache_enabled (bool): Enable or disable cache for this request. + + Returns: + Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + A tuple containing an optional exception, an optional PromptResponse, and an optional Generation object. + """ + dto = GetPromptDTO( + slug=slug, + version=version, + tag=tag + ) + + cached = self._cache.get(dto) if cache_enabled else None + + if cached: + original_prompt_text = cached.text + err, prompt_response = self._replace_vars(cached, variables) + generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + return err, prompt_response, generation + + err, result = await self._api.async_invoke(GetPromptEndpoint, dto) + + if err is None: + original_prompt_text = result.prompt.text + self._cache.put(dto, result.prompt, self._cache_duration) + self._fallback_cache.put(dto, result.prompt) + + err, prompt_response = self._replace_vars(result.prompt, variables) + generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + return err, prompt_response, generation + + fallback = self._fallback_cache.get(dto) if cache_enabled else None + + if fallback: + original_prompt_text = fallback.text + err, prompt_response = self._replace_vars(fallback, variables) + generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + return err, prompt_response, generation + + return err, None, None def _prepare_monitoring( self, @@ -135,6 +193,54 @@ def _prepare_monitoring( }) return generation + + async def _async_prepare_monitoring( + self, + prompt: PromptResponse, + slug: str, + version: Optional[str] = None, + tag: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + original_prompt_text: Optional[str] = None + ) -> Generation: + """ + Asynchronously prepare monitoring by creating a trace and generation object. + + Args: + prompt (PromptResponse): The prompt response. + slug (str): The slug identifier for the prompt. + version (Optional[str]): The version of the prompt. + tag (Optional[str]): The tag associated with the prompt. + variables (Optional[Dict[str, Any]]): Variables used in the prompt. + original_prompt_text (Optional[str]): The original prompt text. + + Returns: + Generation: The generation object. + """ + # Create a flusher + flusher = Flusher(self._api, self._logger) + + # Create a trace + trace = Trace(slug, { + "input": original_prompt_text or prompt.text, + "start_time": datetime.now() + }, flusher, self._logger) + + # Create a generation + generation = Generation({ + "name": slug, + "trace": trace, + "prompt": { + "slug": slug, + "version": version, + "tag": tag + }, + "input": original_prompt_text or prompt.text, + "variables": variables, + "options": {"type": "single"} + }) + + return generation def describe( self, @@ -176,6 +282,46 @@ def describe( ) return err, None + + async def async_describe( + self, + slug: str, + version: Optional[str] = None, + tag: Optional[str] = None, + ) -> DescribeResult: + """ + Asynchronously get details about a prompt by slug, optionally specifying version and tag. + + Args: + slug (str): The slug identifier for the prompt. + version (Optional[str]): The version of the prompt. + tag (Optional[str]): The tag associated with the prompt. + + Returns: + Tuple[Optional[Exception], Optional[DescribePromptResponse]]: A tuple containing an optional exception and an optional DescribePromptResponse. + """ + dto = DescribePromptDTO( + slug=slug, + version=version, + tag=tag + ) + + err, result = await self._api.async_invoke(DescribePromptEndpoint, dto) + + if err is None: + prompt = result.prompt + + return None, DescribePromptResponse( + slug=prompt.slug, + status=prompt.status, + name=prompt.name, + description=prompt.description, + available_versions=prompt.available_versions, + available_tags=prompt.available_tags, + variables=prompt.variables + ) + + return err, None def list(self, feature_slug: Optional[str] = None) -> ListResult: dto = PromptListDTO(featureSlug=feature_slug) @@ -193,6 +339,32 @@ def list(self, feature_slug: Optional[str] = None) -> ListResult: available_versions=prompt.available_versions, available_tags=prompt.available_tags ) for prompt in result.prompts] + + async def async_list(self, feature_slug: Optional[str] = None) -> ListResult: + """ + Asynchronously list prompts, optionally filtering by feature_slug. + + Args: + feature_slug (Optional[str]): Optional feature slug to filter prompts by. + + Returns: + Tuple[Optional[Exception], Optional[List[PromptListResponse]]]: A tuple containing an optional exception and an optional list of PromptListResponse objects. + """ + dto = PromptListDTO(featureSlug=feature_slug) + + err, result = await self._api.async_invoke(ListPromptsEndpoint, dto) + + if err is not None: + return err, None + + return None, [PromptListResponse( + slug=prompt.slug, + status=prompt.status, + name=prompt.name, + description=prompt.description, + available_versions=prompt.available_versions, + available_tags=prompt.available_tags + ) for prompt in result.prompts] def _replace_vars(self, prompt: PromptResponse, variables: Dict[str, str] = {}): missing_vars, replaced = replace_variables(prompt.text, variables) diff --git a/basalt/utils/api.py b/basalt/utils/api.py index 7bb928c..9fd4bf2 100644 --- a/basalt/utils/api.py +++ b/basalt/utils/api.py @@ -1,6 +1,7 @@ from typing import Dict, TypeVar, Optional, Tuple from .protocols import IEndpoint, INetworker, ILogger +import asyncio from .networker import Networker Input = TypeVar('Input') @@ -83,3 +84,38 @@ def _headers(self) -> Dict[str, str]: 'X-BASALT-SDK-TYPE': self._sdk_type, 'Content-Type': 'application/json' } + + async def async_invoke( + self, + endpoint: IEndpoint[Input, Output], + dto: Optional[Input] = None + ) -> Tuple[Optional[Exception], Optional[Output]]: + """ + Asynchronously invoke an API endpoint with the given data transfer object (DTO). + + Args: + endpoint: The endpoint to be invoked. + dto: The data transfer object to be sent to the endpoint. + + Returns: + A tuple containing an optional exception and an optional output. + """ + # Prepare the request information using the endpoint and input data + if dto is None: + request_info = endpoint.prepare_request() + else: + request_info = endpoint.prepare_request(dto) + + # Fetch the result from the network using the prepared request information + error, result = await self._network.async_fetch( + self._root + request_info['path'], + request_info['method'], + request_info.get('body'), + params=request_info.get('query', {}), + headers=self._headers(), + ) + + if error: + return error, None + + return endpoint.decode_response(result) diff --git a/basalt/utils/networker.py b/basalt/utils/networker.py index c21c305..af87960 100644 --- a/basalt/utils/networker.py +++ b/basalt/utils/networker.py @@ -1,4 +1,5 @@ import requests +import aiohttp from typing import Any, Dict, Optional, Tuple from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized @@ -64,3 +65,56 @@ def fetch( except Exception as e: return NetworkBaseError(str(e)), None + + async def async_fetch( + self, + url: str, + method: str, + body = None, + headers = None, + params = None + ) -> Tuple[Optional[FetchError], Optional[Dict[str, Any]]]: + """ + Asynchronously fetch data from a given URL using the specified HTTP method. + + Args: + url (str): The URL to fetch data from. + method (str): The HTTP method to use (e.g., 'GET', 'POST'). + body (Optional[Any]): The request payload to send (default is None). + headers (Optional[Dict[str, str]]): The request headers to send (default is None). + params (Optional[Dict[str, str]]): The query parameters to send (default is None). + + Returns: + A result tuple (err, json_response), possible responses: + - (None, json_response) + - (FetchError, None) + """ + try: + async with aiohttp.ClientSession() as session: + async with session.request( + method, + url, + params=params, + json=body, + headers=headers + ) as response: + json_response = await response.json() + + if response.status == 400: + return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None + + if response.status == 401: + return Unauthorized(json_response.get('error', 'Unauthorized')), None + + if response.status == 403: + return Forbidden(json_response.get('error', 'Forbidden')), None + + if response.status == 404: + return NotFound(json_response.get('error', 'Not Found')), None + + response.raise_for_status() + + return None, json_response + + except Exception as e: + return NetworkBaseError(str(e)), None diff --git a/setup.py b/setup.py index 83d073f..eb0a52a 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ def get_version(): packages=find_packages(), install_requires=[ "requests>=2.32", + "aiohttp>=3.8.0", ], python_requires=">=3.6" ) diff --git a/tests/test_datasetsdk_async.py b/tests/test_datasetsdk_async.py new file mode 100644 index 0000000..d5090a8 --- /dev/null +++ b/tests/test_datasetsdk_async.py @@ -0,0 +1,224 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from basalt.sdk.datasetsdk import DatasetSDK +from basalt.utils.logger import Logger +from basalt.utils.dtos import DatasetDTO, DatasetRowDTO +from basalt.endpoints.list_datasets import ListDatasetsEndpoint, ListDatasetsEndpointResponse +from basalt.endpoints.get_dataset import GetDatasetEndpoint, GetDatasetEndpointResponse +from basalt.endpoints.create_dataset_item import CreateDatasetItemEndpoint, CreateDatasetItemEndpointResponse + +logger = Logger() +mocked_api = MagicMock() +# Make sure async_invoke is an AsyncMock +mocked_api.async_invoke = AsyncMock() + +# Mock responses for different endpoints - same as in test_datasetsdk.py +dataset_list_response = ListDatasetsEndpointResponse( + datasets=[ + DatasetDTO( + slug="test-dataset", + name="Test Dataset", + columns=["input", "output"] + ), + DatasetDTO( + slug="another-dataset", + name="Another Dataset", + columns=["col1", "col2", "col3"] + ) + ] +) + +dataset_get_response = GetDatasetEndpointResponse( + dataset=DatasetDTO( + slug="test-dataset", + name="Test Dataset", + columns=["input", "output"], + rows=[ + { + "values": { + "input": "Sample input", + "output": "Sample output" + }, + "name": "Sample Row", + "idealOutput": "Ideal output", + "metadata": {"source": "test"} + } + ] + ), + error=None +) + +dataset_add_row_response = CreateDatasetItemEndpointResponse( + datasetRow=DatasetRowDTO( + values={"input": "New input", "output": "New output"}, + name="New Row", + idealOutput="New ideal output", + metadata={"source": "test"} + ), + warning=None, + error=None +) + + +class TestDatasetSDKAsync(unittest.TestCase): + def setUp(self): + self.dataset_sdk = DatasetSDK( + api=mocked_api, + logger=logger + ) + # Reset mock calls before each test + mocked_api.async_invoke.reset_mock() + + async def test_async_list_datasets(self): + """Test asynchronously listing all datasets""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_list_response) + + # Call the method + err, datasets = await self.dataset_sdk.async_list() + + # Assertions + self.assertIsNone(err) + self.assertEqual(len(datasets), 2) + self.assertEqual(datasets[0].slug, "test-dataset") + self.assertEqual(datasets[0].name, "Test Dataset") + self.assertEqual(datasets[1].slug, "another-dataset") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, ListDatasetsEndpoint) + + async def test_async_get_dataset(self): + """Test asynchronously getting a dataset by slug""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_get_response) + + # Call the method + err, dataset = await self.dataset_sdk.async_get("test-dataset") + + # Assertions + self.assertIsNone(err) + self.assertEqual(dataset.slug, "test-dataset") + self.assertEqual(dataset.name, "Test Dataset") + self.assertEqual(len(dataset.columns), 2) + self.assertEqual(len(dataset.rows), 1) + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, GetDatasetEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-dataset") + + async def test_async_create_dataset_item(self): + """Test asynchronously creating a dataset item""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_add_row_response) + + # Call the method + values = {"input": "New input", "output": "New output"} + err, row, warning = await self.dataset_sdk.async_addRow( + slug="test-dataset", + values=values, + name="New Row", + ideal_output="New ideal output", + metadata={"source": "test"} + ) + + # Assertions + self.assertIsNone(err) + self.assertIsNone(warning) + self.assertEqual(row.values, values) + self.assertEqual(row.name, "New Row") + self.assertEqual(row.idealOutput, "New ideal output") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateDatasetItemEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-dataset") + self.assertEqual(dto.values, values) + self.assertEqual(dto.name, "New Row") + self.assertEqual(dto.idealOutput, "New ideal output") + + async def test_async_error_handling_get_dataset(self): + """Test error handling when asynchronously getting a dataset""" + # Configure mock to return an error + error = Exception("API Error") + mocked_api.async_invoke.return_value = (error, None) + + # Call the method + err, dataset = await self.dataset_sdk.async_get("non-existent") + + # Assertions + self.assertIsNotNone(err) + self.assertIsNone(dataset) + self.assertEqual(str(err), "API Error") + + async def test_async_get_dataset_object(self): + """Test asynchronously getting a dataset as an object""" + # Configure mock + mocked_api.async_invoke.return_value = (None, dataset_get_response) + + # Call the method + dataset = await self.dataset_sdk.async_get_dataset_object("test-dataset") + + # Assertions + self.assertIsNotNone(dataset) + self.assertEqual(dataset.slug, "test-dataset") + self.assertEqual(dataset.name, "Test Dataset") + self.assertEqual(len(dataset.rows), 1) + self.assertEqual(dataset.rows[0].values, {"input": "Sample input", "output": "Sample output"}) + + async def test_async_add_row_to_dataset(self): + """Test asynchronously adding a row to a dataset object""" + # First get a dataset object + mocked_api.async_invoke.return_value = (None, dataset_get_response) + dataset = await self.dataset_sdk.async_get_dataset_object("test-dataset") + + # Then add a row to it + mocked_api.async_invoke.return_value = (None, dataset_add_row_response) + + values = {"input": "New input", "output": "New output"} + row = await self.dataset_sdk.async_add_row_to_dataset( + dataset=dataset, + values=values, + name="New Row", + ideal_output="New ideal output", + metadata={"source": "test"} + ) + + # Assertions + self.assertIsNotNone(row) + self.assertEqual(row.values, values) + self.assertEqual(row.name, "New Row") + self.assertEqual(row.ideal_output, "New ideal output") + + # Check that the row was added to the dataset + self.assertEqual(len(dataset.rows), 2) + self.assertEqual(dataset.rows[-1], row) + + +def run_async_tests(): + """ + Helper function to run async tests + """ + loop = asyncio.get_event_loop() + + # Create and run the test suite + suite = unittest.TestLoader().loadTestsFromTestCase(TestDatasetSDKAsync) + runner = unittest.TextTestRunner() + + for test in suite: + if test._testMethodName.startswith('test_async_'): + coro = getattr(test, test._testMethodName)() + loop.run_until_complete(coro) + + +if __name__ == "__main__": + run_async_tests() diff --git a/tests/test_monitorsdk_async.py b/tests/test_monitorsdk_async.py new file mode 100644 index 0000000..e527ef1 --- /dev/null +++ b/tests/test_monitorsdk_async.py @@ -0,0 +1,235 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from basalt.sdk.monitorsdk import MonitorSDK +from basalt.utils.logger import Logger +from basalt.ressources.monitor.monitorsdk_types import ( + Experiment, Trace, Generation, LogParams, GenerationParams, + TraceParams, ExperimentParams +) +from basalt.endpoints.monitor.create_experiment import CreateExperimentEndpoint, CreateExperimentEndpointResponse +from basalt.endpoints.monitor.send_trace import CreateTraceEndpoint, CreateTraceEndpointResponse +from basalt.endpoints.monitor.create_generation import CreateGenerationEndpoint, CreateGenerationEndpointResponse +from basalt.endpoints.monitor.create_log import CreateLogEndpoint, CreateLogEndpointResponse + +logger = Logger() +mocked_api = MagicMock() +# Make sure async_invoke is an AsyncMock +mocked_api.async_invoke = AsyncMock() + +# Mock responses for different endpoints +experiment_response = CreateExperimentEndpointResponse( + experiment=Experiment( + id="exp-123", + feature_slug="test-feature", + run_id="run-123", + type="A/B Test", + name="Test Experiment", + setup={ + "control_id": "control-123", + "variation_id": "variation-123" + } + ) +) + +trace_response = CreateTraceEndpointResponse( + trace=Trace( + id="trace-123", + name="Test Trace", + slug="test-trace", + metadata={"source": "test"}, + run_id="run-123", + tool=None, + model_id=None, + created_at="2023-01-01T00:00:00Z" + ) +) + +generation_response = CreateGenerationEndpointResponse( + generation=Generation( + id="gen-123", + trace_id="trace-123", + run_id="run-123", + text="Generated text", + model_id="gpt-4", + prompt="Test prompt", + metadata={"source": "test"}, + created_at="2023-01-01T00:00:00Z" + ) +) + +log_response = CreateLogEndpointResponse( + log={ + "id": "log-123", + "trace_id": "trace-123", + "run_id": "run-123", + "type": "info", + "message": "Test log message", + "metadata": {"source": "test"}, + "created_at": "2023-01-01T00:00:00Z" + } +) + + +class TestMonitorSDKAsync(unittest.TestCase): + def setUp(self): + self.monitor_sdk = MonitorSDK( + api=mocked_api, + logger=logger + ) + # Reset mock calls before each test + mocked_api.async_invoke.reset_mock() + + async def test_async_create_experiment(self): + """Test asynchronously creating an experiment""" + # Configure mock + mocked_api.async_invoke.return_value = (None, experiment_response) + + # Call the method + params = ExperimentParams( + type="A/B Test", + name="Test Experiment", + setup={ + "control_id": "control-123", + "variation_id": "variation-123" + } + ) + result = await self.monitor_sdk.async_create_experiment("test-feature", params) + + # Assertions + self.assertIsNotNone(result) + self.assertEqual(result.id, "exp-123") + self.assertEqual(result.feature_slug, "test-feature") + self.assertEqual(result.type, "A/B Test") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateExperimentEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.feature_slug, "test-feature") + self.assertEqual(dto.params.type, "A/B Test") + + async def test_async_create_trace(self): + """Test asynchronously creating a trace""" + # Configure mock + mocked_api.async_invoke.return_value = (None, trace_response) + + # Call the method + params = TraceParams( + name="Test Trace", + metadata={"source": "test"} + ) + result = await self.monitor_sdk.async_create_trace("test-trace", params) + + # Assertions + self.assertIsNotNone(result) + self.assertEqual(result.id, "trace-123") + self.assertEqual(result.name, "Test Trace") + self.assertEqual(result.slug, "test-trace") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateTraceEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-trace") + self.assertEqual(dto.params.name, "Test Trace") + + async def test_async_create_generation(self): + """Test asynchronously creating a generation""" + # Configure mock + mocked_api.async_invoke.return_value = (None, generation_response) + + # Call the method + params = GenerationParams( + trace_id="trace-123", + text="Generated text", + model_id="gpt-4", + prompt="Test prompt", + metadata={"source": "test"} + ) + result = await self.monitor_sdk.async_create_generation(params) + + # Assertions + self.assertIsNotNone(result) + self.assertEqual(result.id, "gen-123") + self.assertEqual(result.trace_id, "trace-123") + self.assertEqual(result.text, "Generated text") + self.assertEqual(result.model_id, "gpt-4") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateGenerationEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.params.trace_id, "trace-123") + self.assertEqual(dto.params.text, "Generated text") + + async def test_async_create_log(self): + """Test asynchronously creating a log""" + # Configure mock + mocked_api.async_invoke.return_value = (None, log_response) + + # Call the method + params = LogParams( + trace_id="trace-123", + type="info", + message="Test log message", + metadata={"source": "test"} + ) + result = await self.monitor_sdk.async_create_log(params) + + # Assertions + self.assertIsNotNone(result) + self.assertEqual(result["id"], "log-123") + self.assertEqual(result["trace_id"], "trace-123") + self.assertEqual(result["message"], "Test log message") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, CreateLogEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.params.trace_id, "trace-123") + self.assertEqual(dto.params.message, "Test log message") + + async def test_async_error_handling_create_trace(self): + """Test error handling when asynchronously creating a trace""" + # Configure mock to return an error + error = Exception("API Error") + mocked_api.async_invoke.return_value = (error, None) + + # Call the method + params = TraceParams(name="Test Trace") + + with self.assertRaises(Exception) as context: + await self.monitor_sdk.async_create_trace("non-existent", params) + + # Assertions + self.assertEqual(str(context.exception), "API Error") + + +def run_async_tests(): + """ + Helper function to run async tests + """ + loop = asyncio.get_event_loop() + + # Create and run the test suite + suite = unittest.TestLoader().loadTestsFromTestCase(TestMonitorSDKAsync) + runner = unittest.TextTestRunner() + + for test in suite: + if test._testMethodName.startswith('test_async_'): + coro = getattr(test, test._testMethodName)() + loop.run_until_complete(coro) + + +if __name__ == "__main__": + run_async_tests() diff --git a/tests/test_promptsdk_async.py b/tests/test_promptsdk_async.py new file mode 100644 index 0000000..6f8673c --- /dev/null +++ b/tests/test_promptsdk_async.py @@ -0,0 +1,207 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch + +from basalt.sdk.promptsdk import PromptSDK +from basalt.utils.logger import Logger +from basalt.utils.dtos import PromptDTO, PromptVersionDTO, PromptVariablesDTO, GetPromptDTO +from basalt.endpoints.get_prompt import GetPromptEndpoint, GetPromptEndpointResponse +from basalt.endpoints.list_prompts import ListPromptsEndpoint, ListPromptsEndpointResponse +from basalt.endpoints.describe_prompt import DescribePromptEndpoint, DescribePromptEndpointResponse + +logger = Logger() +mocked_api = MagicMock() +# Make sure async_invoke is an AsyncMock +mocked_api.async_invoke = AsyncMock() + +# Mock responses for different endpoints +prompt_get_response = GetPromptEndpointResponse( + prompt=PromptDTO( + slug="test-prompt", + feature_slug="test-feature", + name="Test Prompt", + description="A test prompt", + text="This is a test prompt with {{variable}}", + version="1.0", + tags=["test"], + variables=[PromptVariablesDTO(name="variable", description="A test variable")] + ) +) + +prompt_list_response = ListPromptsEndpointResponse( + prompts=[ + PromptDTO( + slug="test-prompt-1", + feature_slug="test-feature", + name="Test Prompt 1", + version="1.0" + ), + PromptDTO( + slug="test-prompt-2", + feature_slug="test-feature", + name="Test Prompt 2", + version="1.0" + ) + ] +) + +prompt_describe_response = DescribePromptEndpointResponse( + prompt=PromptDTO( + slug="test-prompt", + feature_slug="test-feature", + name="Test Prompt", + description="A test prompt", + text="This is a test prompt with {{variable}}", + version="1.0", + tags=["test"], + variables=[PromptVariablesDTO(name="variable", description="A test variable")] + ), + versions=[ + PromptVersionDTO( + version="1.0", + text="This is a test prompt with {{variable}}", + created_at="2023-01-01T00:00:00Z" + ), + PromptVersionDTO( + version="0.9", + text="This is an older test prompt with {{variable}}", + created_at="2022-12-01T00:00:00Z" + ) + ] +) + + +class TestPromptSDKAsync(unittest.TestCase): + def setUp(self): + self.prompt_sdk = PromptSDK( + api=mocked_api, + logger=logger + ) + # Reset mock calls before each test + mocked_api.async_invoke.reset_mock() + + async def test_async_get_prompt(self): + """Test asynchronously getting a prompt""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_get_response) + + # Call the method + err, prompt_response, generation = await self.prompt_sdk.async_get("test-prompt") + + # Assertions + self.assertIsNone(err) + self.assertEqual(prompt_response.text, "This is a test prompt with {{variable}}") + self.assertEqual(prompt_response.slug, "test-prompt") + self.assertIsNone(generation) # No monitoring in this test + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, GetPromptEndpoint) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.slug, "test-prompt") + + async def test_async_list_prompts(self): + """Test asynchronously listing prompts""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_list_response) + + # Call the method + err, prompts = await self.prompt_sdk.async_list() + + # Assertions + self.assertIsNone(err) + self.assertEqual(len(prompts), 2) + self.assertEqual(prompts[0].slug, "test-prompt-1") + self.assertEqual(prompts[1].slug, "test-prompt-2") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, ListPromptsEndpoint) + + async def test_async_list_prompts_with_feature_filter(self): + """Test asynchronously listing prompts with feature filter""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_list_response) + + # Call the method + err, prompts = await self.prompt_sdk.async_list(feature_slug="test-feature") + + # Assertions + self.assertIsNone(err) + self.assertEqual(len(prompts), 2) + + # Verify DTO was created correctly + dto = mocked_api.async_invoke.call_args[0][1] + self.assertEqual(dto.feature_slug, "test-feature") + + async def test_async_describe_prompt(self): + """Test asynchronously describing a prompt""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_describe_response) + + # Call the method + err, prompt_description = await self.prompt_sdk.async_describe("test-prompt") + + # Assertions + self.assertIsNone(err) + self.assertEqual(prompt_description.prompt.slug, "test-prompt") + self.assertEqual(len(prompt_description.versions), 2) + self.assertEqual(prompt_description.versions[0].version, "1.0") + self.assertEqual(prompt_description.versions[1].version, "0.9") + + # Verify correct endpoint was used + endpoint = mocked_api.async_invoke.call_args[0][0] + self.assertEqual(endpoint, DescribePromptEndpoint) + + async def test_async_get_prompt_with_variables(self): + """Test asynchronously getting a prompt with variables replaced""" + # Configure mock + mocked_api.async_invoke.return_value = (None, prompt_get_response) + + # Call the method + err, prompt_response, generation = await self.prompt_sdk.async_get( + "test-prompt", + variables={"variable": "test-value"} + ) + + # Assertions + self.assertIsNone(err) + # Variables should be replaced in the response + self.assertEqual(prompt_response.text, "This is a test prompt with test-value") + + async def test_async_get_prompt_error_handling(self): + """Test error handling when asynchronously getting a prompt""" + # Configure mock to return an error + error = Exception("API Error") + mocked_api.async_invoke.return_value = (error, None) + + # Call the method + err, prompt_response, generation = await self.prompt_sdk.async_get("non-existent") + + # Assertions + self.assertIsNotNone(err) + self.assertIsNone(prompt_response) + self.assertIsNone(generation) + self.assertEqual(str(err), "API Error") + + +def run_async_tests(): + """ + Helper function to run async tests + """ + loop = asyncio.get_event_loop() + + # Create and run the test suite + suite = unittest.TestLoader().loadTestsFromTestCase(TestPromptSDKAsync) + runner = unittest.TextTestRunner() + + for test in suite: + if test._testMethodName.startswith('test_async_'): + coro = getattr(test, test._testMethodName)() + loop.run_until_complete(coro) + + +if __name__ == "__main__": + run_async_tests() From b3658fc25dee8f21ec85b2d818da7ef013256d95 Mon Sep 17 00:00:00 2001 From: Theophile Cousin Date: Mon, 30 Jun 2025 10:52:59 +0200 Subject: [PATCH 4/5] add async jup notebooks --- examples/dataset_sdk_async_demo.ipynb | 288 ++++++++++++++++++++++++++ examples/monitor_sdk_async_demo.ipynb | 267 ++++++++++++++++++++++++ examples/prompt_sdk_async_demo.ipynb | 241 +++++++++++++++++++++ 3 files changed, 796 insertions(+) create mode 100644 examples/dataset_sdk_async_demo.ipynb create mode 100644 examples/monitor_sdk_async_demo.ipynb create mode 100644 examples/prompt_sdk_async_demo.ipynb diff --git a/examples/dataset_sdk_async_demo.ipynb b/examples/dataset_sdk_async_demo.ipynb new file mode 100644 index 0000000..a11ac85 --- /dev/null +++ b/examples/dataset_sdk_async_demo.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt DatasetSDK Async Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the DatasetSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from basalt import Basalt\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Initialize the SDK with your API key. If you have the BASALT_API_KEY environment variable set, it will use that; otherwise, replace \"your-api-key\" with your actual API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the SDK with your API key\n", + "api_key = os.environ.get(\"BASALT_API_KEY\", \"your-api-key\")\n", + "\n", + "# Create a Basalt client\n", + "basalt = Basalt(api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Asynchronously List All Datasets\n", + "\n", + "This example demonstrates how to list all datasets asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def list_datasets():\n", + " print(\"Listing all datasets asynchronously...\")\n", + " err, datasets = await basalt.datasets.async_list()\n", + " if err:\n", + " print(f\"Error listing datasets: {err}\")\n", + " else:\n", + " print(f\"Found {len(datasets)} datasets\")\n", + " for dataset in datasets:\n", + " print(f\"- {dataset.name} (slug: {dataset.slug})\")\n", + " return datasets\n", + "\n", + "# Run the async function\n", + "datasets = await list_datasets()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Asynchronously Get a Specific Dataset\n", + "\n", + "This example demonstrates how to retrieve a specific dataset by its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_dataset(datasets):\n", + " print(\"\\nGetting a specific dataset asynchronously...\")\n", + " if len(datasets) > 0:\n", + " sample_dataset = datasets[0]\n", + " err, dataset = await basalt.datasets.async_get(sample_dataset.slug)\n", + " if err:\n", + " print(f\"Error getting dataset: {err}\")\n", + " else:\n", + " print(f\"Retrieved dataset: {dataset.name}\")\n", + " print(f\"Columns: {dataset.columns}\")\n", + " print(f\"Number of rows: {len(dataset.rows) if dataset.rows else 0}\")\n", + " return sample_dataset, dataset\n", + " else:\n", + " print(\"No datasets available\")\n", + " return None, None\n", + "\n", + "# Run the async function\n", + "sample_dataset, dataset = await get_dataset(datasets)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Asynchronously Add a Row to a Dataset\n", + "\n", + "This example demonstrates how to add a new row to an existing dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def add_row(sample_dataset):\n", + " print(\"\\nAdding a row to a dataset asynchronously...\")\n", + " if sample_dataset:\n", + " # Create some sample values for the dataset row\n", + " values = {column: f\"Sample {column} value\" for column in sample_dataset.columns}\n", + " \n", + " err, row, warning = await basalt.datasets.async_addRow(\n", + " slug=sample_dataset.slug,\n", + " values=values,\n", + " name=\"Async Sample Row\",\n", + " ideal_output=\"Expected output for this row\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error adding row to dataset: {err}\")\n", + " elif warning:\n", + " print(f\"Row added with warning: {warning}\")\n", + " print(f\"Row values: {row.values}\")\n", + " else:\n", + " print(f\"Row added successfully\")\n", + " print(f\"Row values: {row.values}\")\n", + " print(f\"Row name: {row.name}\")\n", + "\n", + "# Run the async function\n", + "await add_row(sample_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Asynchronously Get a Dataset as an Object\n", + "\n", + "This example demonstrates how to retrieve a dataset as an object for further manipulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_dataset_object(sample_dataset):\n", + " print(\"\\nGetting a dataset as an object asynchronously...\")\n", + " if sample_dataset:\n", + " dataset_obj = await basalt.datasets.async_get_dataset_object(sample_dataset.slug)\n", + " if dataset_obj:\n", + " print(f\"Retrieved dataset object: {dataset_obj.name}\")\n", + " print(f\"Number of rows: {len(dataset_obj.rows)}\")\n", + " else:\n", + " print(\"Failed to get dataset object\")\n", + " return dataset_obj\n", + " else:\n", + " print(\"No sample dataset available\")\n", + " return None\n", + "\n", + "# Run the async function\n", + "dataset_obj = await get_dataset_object(sample_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Asynchronously Add a Row to a Dataset Object\n", + "\n", + "This example demonstrates how to add a row directly to a dataset object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def add_row_to_dataset_object(dataset_obj):\n", + " print(\"\\nAdding a row to a dataset object asynchronously...\")\n", + " if dataset_obj:\n", + " # Create some sample values for the dataset row\n", + " values = {column: f\"Another sample {column} value\" for column in dataset_obj.columns}\n", + " \n", + " row = await basalt.datasets.async_add_row_to_dataset(\n", + " dataset=dataset_obj,\n", + " values=values,\n", + " name=\"Another Async Sample Row\",\n", + " ideal_output=\"Another expected output\",\n", + " metadata={\"source\": \"async_dataset_object_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " if row:\n", + " print(f\"Row added to dataset object successfully\")\n", + " print(f\"Row values: {row.values}\")\n", + " print(f\"Updated number of rows in dataset object: {len(dataset_obj.rows)}\")\n", + " else:\n", + " print(\"Failed to add row to dataset object\")\n", + "\n", + "# Run the async function\n", + "await add_row_to_dataset_object(dataset_obj)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 6: Execute Multiple Dataset Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous operations concurrently for better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def execute_concurrent_operations(sample_dataset):\n", + " print(\"\\nExecuting multiple dataset operations concurrently...\")\n", + " \n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.datasets.async_list(),\n", + " basalt.datasets.async_get(sample_dataset.slug) if sample_dataset else None\n", + " ]\n", + " \n", + " # Filter out None tasks\n", + " tasks = [t for t in tasks if t is not None]\n", + " \n", + " if tasks:\n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + " \n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"Number of datasets listed: {len(results[0][1]) if results[0][1] else 'Error'}\")\n", + " if len(tasks) > 1:\n", + " print(f\"Retrieved dataset: {results[1][1].name if results[1][1] else 'Error'}\")\n", + "\n", + "# Run the async function\n", + "await execute_concurrent_operations(sample_dataset)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/monitor_sdk_async_demo.ipynb b/examples/monitor_sdk_async_demo.ipynb new file mode 100644 index 0000000..1c756af --- /dev/null +++ b/examples/monitor_sdk_async_demo.ipynb @@ -0,0 +1,267 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt MonitorSDK Async Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the MonitorSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from basalt import Basalt\n", + "import os\n", + "from basalt.ressources.monitor.monitorsdk_types import (\n", + " ExperimentParams, TraceParams, GenerationParams, LogParams\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Initialize the SDK with your API key. If you have the BASALT_API_KEY environment variable set, it will use that; otherwise, replace \"your-api-key\" with your actual API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the SDK with your API key\n", + "api_key = os.environ.get(\"BASALT_API_KEY\", \"your-api-key\")\n", + "\n", + "# Create a Basalt client\n", + "basalt = Basalt(api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Asynchronously Create a Trace\n", + "\n", + "This example demonstrates how to create a trace asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_trace():\n", + " print(\"Creating a trace asynchronously...\")\n", + " trace_params = TraceParams(\n", + " name=\"Async Test Trace\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " trace = await basalt.monitor.async_create_trace(\n", + " slug=\"async-test-trace\",\n", + " params=trace_params\n", + " )\n", + " \n", + " print(f\"Created trace: {trace.id}\")\n", + " print(f\"Trace name: {trace.name}\")\n", + " print(f\"Trace slug: {trace.slug}\")\n", + " \n", + " return trace\n", + "\n", + "# Run the async function\n", + "trace = await create_trace()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Asynchronously Create a Generation\n", + "\n", + "This example demonstrates how to create a generation associated with a trace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_generation(trace):\n", + " print(\"\\nCreating a generation asynchronously...\")\n", + " gen_params = GenerationParams(\n", + " trace_id=trace.id,\n", + " text=\"This is an async test generation\",\n", + " model_id=\"gpt-4\",\n", + " prompt=\"Generate a response asynchronously\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " generation = await basalt.monitor.async_create_generation(gen_params)\n", + " \n", + " print(f\"Created generation: {generation.id}\")\n", + " print(f\"Generation text: {generation.text}\")\n", + " print(f\"Generation model: {generation.model_id}\")\n", + " \n", + " return generation\n", + "\n", + "# Run the async function\n", + "generation = await create_generation(trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Asynchronously Create a Log\n", + "\n", + "This example demonstrates how to create a log entry associated with a trace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_log(trace):\n", + " print(\"\\nCreating a log asynchronously...\")\n", + " log_params = LogParams(\n", + " trace_id=trace.id,\n", + " type=\"info\",\n", + " message=\"This is an async test log message\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + " \n", + " log = await basalt.monitor.async_create_log(log_params)\n", + " \n", + " print(f\"Created log: {log['id']}\")\n", + " print(f\"Log message: {log['message']}\")\n", + " print(f\"Log type: {log['type']}\")\n", + " \n", + " return log\n", + "\n", + "# Run the async function\n", + "log = await create_log(trace)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Asynchronously Create an Experiment\n", + "\n", + "This example demonstrates how to create an experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def create_experiment():\n", + " print(\"\\nCreating an experiment asynchronously...\")\n", + " exp_params = ExperimentParams(\n", + " type=\"A/B Test\",\n", + " name=\"Async Test Experiment\",\n", + " setup={\n", + " \"control_id\": \"control-prompt\",\n", + " \"variation_id\": \"test-prompt\"\n", + " }\n", + " )\n", + " \n", + " experiment = await basalt.monitor.async_create_experiment(\n", + " \"async-test-feature\",\n", + " exp_params\n", + " )\n", + " \n", + " print(f\"Created experiment: {experiment.id}\")\n", + " print(f\"Experiment name: {experiment.name}\")\n", + " print(f\"Experiment type: {experiment.type}\")\n", + " \n", + " return experiment\n", + "\n", + "# Run the async function\n", + "experiment = await create_experiment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Execute Multiple Monitoring Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous monitoring operations concurrently for better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def execute_concurrent_operations():\n", + " print(\"\\nExecuting multiple monitoring operations concurrently...\")\n", + " \n", + " # Create trace parameters for concurrent operations\n", + " trace_params1 = TraceParams(\n", + " name=\"Concurrent Trace 1\",\n", + " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 1}\n", + " )\n", + " \n", + " trace_params2 = TraceParams(\n", + " name=\"Concurrent Trace 2\",\n", + " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 2}\n", + " )\n", + " \n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.monitor.async_create_trace(\"concurrent-trace-1\", trace_params1),\n", + " basalt.monitor.async_create_trace(\"concurrent-trace-2\", trace_params2)\n", + " ]\n", + " \n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + " \n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"First trace: {results[0].name} (id: {results[0].id})\")\n", + " print(f\"Second trace: {results[1].name} (id: {results[1].id})\")\n", + "\n", + "# Run the async function\n", + "await execute_concurrent_operations()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/prompt_sdk_async_demo.ipynb b/examples/prompt_sdk_async_demo.ipynb new file mode 100644 index 0000000..ca996ce --- /dev/null +++ b/examples/prompt_sdk_async_demo.ipynb @@ -0,0 +1,241 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt PromptSDK Async Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the PromptSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from basalt import Basalt\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "Initialize the SDK with your API key. If you have the BASALT_API_KEY environment variable set, it will use that; otherwise, replace \"your-api-key\" with your actual API key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the SDK with your API key\n", + "api_key = os.environ.get(\"BASALT_API_KEY\", \"your-api-key\")\n", + "\n", + "# Create a Basalt client\n", + "basalt = Basalt(api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: Asynchronously List All Prompts\n", + "\n", + "This example demonstrates how to list all prompts asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def list_prompts():\n", + " print(\"Listing all prompts asynchronously...\")\n", + " err, prompts = await basalt.prompt.async_list()\n", + " if err:\n", + " print(f\"Error listing prompts: {err}\")\n", + " else:\n", + " print(f\"Found {len(prompts)} prompts\")\n", + " for prompt in prompts:\n", + " print(f\"- {prompt.name} (slug: {prompt.slug})\")\n", + " return prompts\n", + "\n", + "# Run the async function\n", + "prompts = await list_prompts()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Asynchronously Get a Specific Prompt\n", + "\n", + "This example demonstrates how to retrieve a specific prompt by its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_prompt(prompts):\n", + " print(\"\\nGetting a specific prompt asynchronously...\")\n", + " if len(prompts) > 0:\n", + " sample_prompt = prompts[0]\n", + " err, prompt_response, generation = await basalt.prompt.async_get(sample_prompt.slug)\n", + " if err:\n", + " print(f\"Error getting prompt: {err}\")\n", + " else:\n", + " print(f\"Retrieved prompt: {sample_prompt.name}\")\n", + " print(f\"Text: {prompt_response.text}\")\n", + " return sample_prompt, prompt_response, generation\n", + " else:\n", + " print(\"No prompts available\")\n", + " return None, None, None\n", + "\n", + "# Run the async function\n", + "sample_prompt, prompt_response, generation = await get_prompt(prompts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Asynchronously Describe a Prompt\n", + "\n", + "This example demonstrates how to get detailed description information about a prompt." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def describe_prompt(sample_prompt):\n", + " print(\"\\nDescribing a prompt asynchronously...\")\n", + " if sample_prompt:\n", + " err, description = await basalt.prompt.async_describe(sample_prompt.slug)\n", + " if err:\n", + " print(f\"Error describing prompt: {err}\")\n", + " else:\n", + " print(f\"Prompt: {description.prompt.name}\")\n", + " print(f\"Versions available: {len(description.versions)}\")\n", + " for version in description.versions:\n", + " print(f\"- Version {version.version} created at {version.created_at}\")\n", + " return description\n", + " else:\n", + " print(\"No sample prompt available\")\n", + " return None\n", + "\n", + "# Run the async function\n", + "description = await describe_prompt(sample_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 4: Asynchronously Get a Prompt with Variable Substitution\n", + "\n", + "This example demonstrates how to retrieve a prompt with variables substituted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_prompt_with_variables(sample_prompt):\n", + " print(\"\\nGetting a prompt with variable substitution asynchronously...\")\n", + " if sample_prompt:\n", + " err, prompt_response, generation = await basalt.prompt.async_get(\n", + " sample_prompt.slug,\n", + " variables={\"name\": \"John\", \"company\": \"Acme Inc\"}\n", + " )\n", + " if err:\n", + " print(f\"Error getting prompt with variables: {err}\")\n", + " else:\n", + " print(f\"Retrieved prompt with variables: {sample_prompt.name}\")\n", + " print(f\"Text with variables: {prompt_response.text}\")\n", + " return prompt_response, generation\n", + " else:\n", + " print(\"No sample prompt available\")\n", + " return None, None\n", + "\n", + "# Run the async function\n", + "prompt_with_vars, generation_with_vars = await get_prompt_with_variables(sample_prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Execute Multiple Prompt Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous operations concurrently for better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def execute_concurrent_operations(prompts):\n", + " print(\"\\nExecuting multiple prompt operations concurrently...\")\n", + " if len(prompts) >= 2:\n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.prompt.async_get(prompts[0].slug),\n", + " basalt.prompt.async_get(prompts[1].slug),\n", + " basalt.prompt.async_list()\n", + " ]\n", + " \n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + " \n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"First prompt: {results[0][1].slug if results[0][1] else 'Error'}\")\n", + " print(f\"Second prompt: {results[1][1].slug if results[1][1] else 'Error'}\")\n", + " print(f\"Number of prompts listed: {len(results[2][1]) if results[2][1] else 'Error'}\")\n", + " else:\n", + " print(\"Not enough prompts available for concurrent operations example\")\n", + "\n", + "# Run the async function\n", + "await execute_concurrent_operations(prompts)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From d44a11a4420ddbf103172f85ae90647ee00f9b0c Mon Sep 17 00:00:00 2001 From: Theophile Cousin Date: Tue, 1 Jul 2025 17:18:24 +0200 Subject: [PATCH 5/5] fix rebase --- basalt/objects/dataset.py | 10 +- basalt/ressources/monitor/generation_types.py | 2 + examples/dataset_sdk_demo.ipynb | 488 +++++++++--------- tests/test_datasetsdk_async.py | 91 ++-- tests/test_monitorsdk_async.py | 253 ++++----- tests/test_promptsdk_async.py | 152 +++--- 6 files changed, 465 insertions(+), 531 deletions(-) diff --git a/basalt/objects/dataset.py b/basalt/objects/dataset.py index d502608..a2577d9 100644 --- a/basalt/objects/dataset.py +++ b/basalt/objects/dataset.py @@ -20,14 +20,10 @@ def to_dict(self) -> Dict[str, Any]: """Convert the DatasetRow to a dictionary for API requests""" result = { "values": self.values, - "metadata": self.metadata + "metadata": self.metadata, + "name": self.name, + "idealOutput": self.ideal_output } - - if self.name: - result["name"] = self.name - - if self.ideal_output: - result["idealOutput"] = self.ideal_output return result diff --git a/basalt/ressources/monitor/generation_types.py b/basalt/ressources/monitor/generation_types.py index 24f56e8..761aabd 100644 --- a/basalt/ressources/monitor/generation_types.py +++ b/basalt/ressources/monitor/generation_types.py @@ -34,6 +34,7 @@ class GenerationParams(BaseLogParams): input (Optional[str]): The input provided to the model. output (Optional[str]): The output generated by the model. variables (Optional[Dict[str, Any]]): Variables used in the prompt template. + options (Optional[Dict[str, Any]]): Additional options for the generation. Example: ```python @@ -57,6 +58,7 @@ class GenerationParams(BaseLogParams): input: Optional[str] = None output: Optional[str] = None variables: Optional[Dict[str, Any]] = None + options: Optional[Dict[str, Any]] = None @dataclass class Generation(BaseLog): diff --git a/examples/dataset_sdk_demo.ipynb b/examples/dataset_sdk_demo.ipynb index 26daf7e..e732583 100644 --- a/examples/dataset_sdk_demo.ipynb +++ b/examples/dataset_sdk_demo.ipynb @@ -1,258 +1,232 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Basalt Dataset SDK Demo\n", - "\n", - "This notebook demonstrates how to use the Basalt Dataset SDK to interact with your Basalt datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", - "\n", - "os.environ[\"BASALT_BUILD\"] = \"development\"\n", - "\n", - "from basalt import Basalt\n", - "\n", - "# Initialize the SDK\n", - "basalt = Basalt(\n", -<<<<<<< HEAD - " api_key=\"sk-df55a...\", # Replace with your API key\n", -======= - " api_key=\"sk-...\", # Replace with your API key\n", ->>>>>>> 6abcf23 (rewrite example as notebook jupyter) - " log_level=\"debug\" # Optional: Set log level\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Listing Available Datasets\n", - "\n", - "Retrieve all datasets available in your workspace." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# List all datasets in the workspace\n", - "err, datasets = basalt.datasets.list()\n", - "\n", - "if err:\n", - " print(f\"Error listing datasets: {err}\")\n", - "else:\n", - " print(f\"Found {len(datasets)} datasets:\")\n", - " for i, dataset in enumerate(datasets):\n", - " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", -<<<<<<< HEAD -======= - " print(f\" - Description: {dataset.description if dataset.description else 'No description'}\")\n", ->>>>>>> 6abcf23 (rewrite example as notebook jupyter) - " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", - " \n", - " # Store the first dataset slug for later use (if available)\n", - " first_dataset_slug = datasets[0].slug if datasets else None" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Getting a Specific Dataset\n", - "\n", - "Retrieve details for a specific dataset using its slug." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use the first dataset from the list or replace with a specific slug\n", - "dataset_slug = first_dataset_slug if 'first_dataset_slug' in locals() and first_dataset_slug else \"your-dataset-slug\"\n", - "\n", - "err, dataset = basalt.datasets.get(dataset_slug)\n", - "\n", - "if err:\n", - " print(f\"Error getting dataset: {err}\")\n", - "else:\n", - " print(f\"Dataset details for '{dataset.name}'\")\n", - " print(f\"Slug: {dataset.slug}\")\n", -<<<<<<< HEAD -======= - " print(f\"Description: {dataset.description if dataset.description else 'No description'}\")\n", ->>>>>>> 6abcf23 (rewrite example as notebook jupyter) - " print(f\"Columns: {', '.join(dataset.columns)}\")\n", - " print(f\"Number of rows: {len(dataset.rows)}\")\n", - " \n", - " if dataset.rows:\n", - " print(\"\\nSample rows:\")\n", - " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", - " print(f\"Row {i+1}:\")\n", -<<<<<<< HEAD - " print(f\" Values: {row.get('values')}\")\n", - " if 'name' in row:\n", - " print(f\" Name: {row['name']}\")\n", - " if 'idealOutput' in row:\n", - " print(f\" Ideal output: {row['idealOutput']}\")\n", - " if 'metadata' in row:\n", - " print(f\" Metadata: {row['metadata']}\")" -======= - " print(f\" Values: {row.values}\")\n", - " if row.name:\n", - " print(f\" Name: {row.name}\")\n", - " if row.idealOutput:\n", - " print(f\" Ideal output: {row.idealOutput}\")\n", - " if row.metadata:\n", - " print(f\" Metadata: {row.metadata}\")" ->>>>>>> 6abcf23 (rewrite example as notebook jupyter) - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Adding a Row to a Dataset\n", - "\n", - "Create a new row (item) in an existing dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use the dataset from the previous example\n", - "if 'dataset' in locals() and dataset:\n", - " # Build values for all columns in the dataset\n", - " values = {}\n", - " for column in dataset.columns:\n", - " values[column] = f\"Example value for {column} - {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\n", - " \n", - " # Create the row\n", - " err, row, warning = basalt.datasets.addRow(\n", - " slug=dataset.slug,\n", - " values=values,\n", - " name=\"Notebook Example Row\",\n", - " ideal_output=\"This is an ideal output for this row\",\n", - " metadata={\"source\": \"Jupyter notebook example\", \"timestamp\": __import__('datetime').datetime.now().isoformat()}\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error creating dataset row: {err}\")\n", - " else:\n", - " print(\"Successfully created new dataset row:\")\n", - " print(f\"Values: {row.values}\")\n", - " print(f\"Name: {row.name}\")\n", - " print(f\"Ideal output: {row.idealOutput}\")\n", - " print(f\"Metadata: {row.metadata}\")\n", - " \n", - " if warning:\n", - " print(f\"Warning: {warning}\")\n", - "else:\n", - " print(\"Please run the previous cell to get a dataset first\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Error Handling with Dataset SDK\n", - "\n", - "Demonstrate proper error handling when working with datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def safely_add_dataset_row(slug, values, name=None, ideal_output=None, metadata=None):\n", - " \"\"\"Safely add a row to a dataset with robust error handling\"\"\"\n", - " try:\n", - " err, row, warning = basalt.datasets.addRow(\n", - " slug=slug,\n", - " values=values,\n", - " name=name,\n", - " ideal_output=ideal_output,\n", - " metadata=metadata\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error creating dataset row: {err}\")\n", - " return None\n", - " \n", - " if warning:\n", - " print(f\"Warning: {warning}\")\n", - " \n", - " return row\n", - " except Exception as e:\n", - " print(f\"Unexpected error: {str(e)}\")\n", - " return None\n", - "\n", - "# Test with a valid dataset\n", - "if 'dataset_slug' in locals() and dataset_slug:\n", - " values = {\"input\": \"Test input\", \"output\": \"Test output\"}\n", - " row = safely_add_dataset_row(dataset_slug, values, name=\"Error Handling Test\")\n", - " \n", - " if row:\n", - " print(f\"Successfully created row: {row.name}\")\n", - "\n", - "# Test with an invalid dataset slug\n", - "print(\"\\nTesting with invalid dataset slug:\")\n", - "invalid_row = safely_add_dataset_row(\"non-existent-dataset\", {\"input\": \"Test input\"})\n", - "print(f\"Result with invalid slug: {invalid_row}\")\n", - "\n", - "# Test with missing required values\n", - "if 'dataset' in locals() and dataset and len(dataset.columns) > 0:\n", - " print(\"\\nTesting with missing required values:\")\n", - " # Deliberately create incomplete values dict\n", - " incomplete_values = {column: \"value\" for column in list(dataset.columns)[1:]} if len(dataset.columns) > 1 else {}\n", - " incomplete_row = safely_add_dataset_row(dataset.slug, incomplete_values)\n", - " print(f\"Result with incomplete values: {incomplete_row}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", -<<<<<<< HEAD - "version": "3.13.3" -======= - "version": "3.13.2" ->>>>>>> 6abcf23 (rewrite example as notebook jupyter) - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt Dataset SDK Demo\n", + "\n", + "This notebook demonstrates how to use the Basalt Dataset SDK to interact with your Basalt datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "\n", + "os.environ[\"BASALT_BUILD\"] = \"development\"\n", + "\n", + "from basalt import Basalt\n", + "\n", + "# Initialize the SDK\n", + "basalt = Basalt(\n", + " api_key=\"sk-...\", # Replace with your API key\n", + " log_level=\"debug\" # Optional: Set log level\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Listing Available Datasets\n", + "\n", + "Retrieve all datasets available in your workspace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# List all datasets in the workspace\n", + "err, datasets = basalt.datasets.list()\n", + "\n", + "if err:\n", + " print(f\"Error listing datasets: {err}\")\n", + "else:\n", + " print(f\"Found {len(datasets)} datasets:\")\n", + " for i, dataset in enumerate(datasets):\n", + " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", + " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", + " \n", + " # Store the first dataset slug for later use (if available)\n", + " first_dataset_slug = datasets[0].slug if datasets else None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Getting a Specific Dataset\n", + "\n", + "Retrieve details for a specific dataset using its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the first dataset from the list or replace with a specific slug\n", + "dataset_slug = first_dataset_slug if 'first_dataset_slug' in locals() and first_dataset_slug else \"your-dataset-slug\"\n", + "\n", + "err, dataset = basalt.datasets.get(dataset_slug)\n", + "\n", + "if err:\n", + " print(f\"Error getting dataset: {err}\")\n", + "else:\n", + " print(f\"Dataset details for '{dataset.name}'\")\n", + " print(f\"Slug: {dataset.slug}\")\n", + " print(f\"Columns: {', '.join(dataset.columns)}\")\n", + " print(f\"Number of rows: {len(dataset.rows)}\")\n", + " \n", + " if dataset.rows:\n", + " print(\"\\nSample rows:\")\n", + " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", + " print(f\"Row {i+1}:\")\n", + " print(f\" Values: {row.get('values')}\")\n", + " if 'name' in row:\n", + " print(f\" Name: {row['name']}\")\n", + " if 'idealOutput' in row:\n", + " print(f\" Ideal output: {row['idealOutput']}\")\n", + " if 'metadata' in row:\n", + " print(f\" Metadata: {row['metadata']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Adding a Row to a Dataset\n", + "\n", + "Create a new row (item) in an existing dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the dataset from the previous example\n", + "if 'dataset' in locals() and dataset:\n", + " # Build values for all columns in the dataset\n", + " values = {}\n", + " for column in dataset.columns:\n", + " values[column] = f\"Example value for {column} - {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\n", + " \n", + " # Create the row\n", + " err, row, warning = basalt.datasets.addRow(\n", + " slug=dataset.slug,\n", + " values=values,\n", + " name=\"Notebook Example Row\",\n", + " ideal_output=\"This is an ideal output for this row\",\n", + " metadata={\"source\": \"Jupyter notebook example\", \"timestamp\": __import__('datetime').datetime.now().isoformat()}\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error creating dataset row: {err}\")\n", + " else:\n", + " print(\"Successfully created new dataset row:\")\n", + " print(f\"Values: {row.values}\")\n", + " print(f\"Name: {row.name}\")\n", + " print(f\"Ideal output: {row.idealOutput}\")\n", + " print(f\"Metadata: {row.metadata}\")\n", + " \n", + " if warning:\n", + " print(f\"Warning: {warning}\")\n", + "else:\n", + " print(\"Please run the previous cell to get a dataset first\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Error Handling with Dataset SDK\n", + "\n", + "Demonstrate proper error handling when working with datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def safely_add_dataset_row(slug, values, name=None, ideal_output=None, metadata=None):\n", + " \"\"\"Safely add a row to a dataset with robust error handling\"\"\"\n", + " try:\n", + " err, row, warning = basalt.datasets.addRow(\n", + " slug=slug,\n", + " values=values,\n", + " name=name,\n", + " ideal_output=ideal_output,\n", + " metadata=metadata\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error creating dataset row: {err}\")\n", + " return None\n", + " \n", + " if warning:\n", + " print(f\"Warning: {warning}\")\n", + " \n", + " return row\n", + " except Exception as e:\n", + " print(f\"Unexpected error: {str(e)}\")\n", + " return None\n", + "\n", + "# Test with a valid dataset\n", + "if 'dataset_slug' in locals() and dataset_slug:\n", + " values = {\"input\": \"Test input\", \"output\": \"Test output\"}\n", + " row = safely_add_dataset_row(dataset_slug, values, name=\"Error Handling Test\")\n", + " \n", + " if row:\n", + " print(f\"Successfully created row: {row.name}\")\n", + "\n", + "# Test with an invalid dataset slug\n", + "print(\"\\nTesting with invalid dataset slug:\")\n", + "invalid_row = safely_add_dataset_row(\"non-existent-dataset\", {\"input\": \"Test input\"})\n", + "print(f\"Result with invalid slug: {invalid_row}\")\n", + "\n", + "# Test with missing required values\n", + "if 'dataset' in locals() and dataset and len(dataset.columns) > 0:\n", + " print(\"\\nTesting with missing required values:\")\n", + " # Deliberately create incomplete values dict\n", + " incomplete_values = {column: \"value\" for column in list(dataset.columns)[1:]} if len(dataset.columns) > 1 else {}\n", + " incomplete_row = safely_add_dataset_row(dataset.slug, incomplete_values)\n", + " print(f\"Result with incomplete values: {incomplete_row}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/tests/test_datasetsdk_async.py b/tests/test_datasetsdk_async.py index d5090a8..bd1fc2d 100644 --- a/tests/test_datasetsdk_async.py +++ b/tests/test_datasetsdk_async.py @@ -4,7 +4,7 @@ from basalt.sdk.datasetsdk import DatasetSDK from basalt.utils.logger import Logger -from basalt.utils.dtos import DatasetDTO, DatasetRowDTO +from basalt.utils.dtos import DatasetDTO, DatasetRowDTO, ListDatasetsDTO, GetDatasetDTO, CreateDatasetItemDTO from basalt.endpoints.list_datasets import ListDatasetsEndpoint, ListDatasetsEndpointResponse from basalt.endpoints.get_dataset import GetDatasetEndpoint, GetDatasetEndpointResponse from basalt.endpoints.create_dataset_item import CreateDatasetItemEndpoint, CreateDatasetItemEndpointResponse @@ -160,65 +160,44 @@ async def test_async_error_handling_get_dataset(self): self.assertIsNone(dataset) self.assertEqual(str(err), "API Error") - async def test_async_get_dataset_object(self): - """Test asynchronously getting a dataset as an object""" - # Configure mock - mocked_api.async_invoke.return_value = (None, dataset_get_response) - - # Call the method - dataset = await self.dataset_sdk.async_get_dataset_object("test-dataset") - - # Assertions - self.assertIsNotNone(dataset) - self.assertEqual(dataset.slug, "test-dataset") - self.assertEqual(dataset.name, "Test Dataset") - self.assertEqual(len(dataset.rows), 1) - self.assertEqual(dataset.rows[0].values, {"input": "Sample input", "output": "Sample output"}) - - async def test_async_add_row_to_dataset(self): - """Test asynchronously adding a row to a dataset object""" - # First get a dataset object - mocked_api.async_invoke.return_value = (None, dataset_get_response) - dataset = await self.dataset_sdk.async_get_dataset_object("test-dataset") - - # Then add a row to it - mocked_api.async_invoke.return_value = (None, dataset_add_row_response) - - values = {"input": "New input", "output": "New output"} - row = await self.dataset_sdk.async_add_row_to_dataset( - dataset=dataset, - values=values, - name="New Row", - ideal_output="New ideal output", - metadata={"source": "test"} - ) - - # Assertions - self.assertIsNotNone(row) - self.assertEqual(row.values, values) - self.assertEqual(row.name, "New Row") - self.assertEqual(row.ideal_output, "New ideal output") - - # Check that the row was added to the dataset - self.assertEqual(len(dataset.rows), 2) - self.assertEqual(dataset.rows[-1], row) -def run_async_tests(): - """ - Helper function to run async tests - """ - loop = asyncio.get_event_loop() +class AsyncTestRunner: + """Helper class to run async tests properly""" + + def __init__(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) - # Create and run the test suite - suite = unittest.TestLoader().loadTestsFromTestCase(TestDatasetSDKAsync) - runner = unittest.TextTestRunner() + def run_test_case(self, test_case_class): + """Run all async test methods in a test case""" + suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) + + for test in suite: + test_method = getattr(test, test._testMethodName) + if asyncio.iscoroutinefunction(test_method): + try: + test.setUp() + self.loop.run_until_complete(test_method()) + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + else: + # Run sync tests normally + try: + test.setUp() + test_method() + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") - for test in suite: - if test._testMethodName.startswith('test_async_'): - coro = getattr(test, test._testMethodName)() - loop.run_until_complete(coro) + def close(self): + self.loop.close() if __name__ == "__main__": - run_async_tests() + runner = AsyncTestRunner() + try: + runner.run_test_case(TestDatasetSDKAsync) + finally: + runner.close() diff --git a/tests/test_monitorsdk_async.py b/tests/test_monitorsdk_async.py index e527ef1..ec117ee 100644 --- a/tests/test_monitorsdk_async.py +++ b/tests/test_monitorsdk_async.py @@ -4,72 +4,38 @@ from basalt.sdk.monitorsdk import MonitorSDK from basalt.utils.logger import Logger -from basalt.ressources.monitor.monitorsdk_types import ( - Experiment, Trace, Generation, LogParams, GenerationParams, - TraceParams, ExperimentParams -) -from basalt.endpoints.monitor.create_experiment import CreateExperimentEndpoint, CreateExperimentEndpointResponse -from basalt.endpoints.monitor.send_trace import CreateTraceEndpoint, CreateTraceEndpointResponse -from basalt.endpoints.monitor.create_generation import CreateGenerationEndpoint, CreateGenerationEndpointResponse -from basalt.endpoints.monitor.create_log import CreateLogEndpoint, CreateLogEndpointResponse +from basalt.ressources.monitor.experiment_types import ExperimentParams +from basalt.ressources.monitor.trace_types import TraceParams +from basalt.ressources.monitor.generation_types import GenerationParams +from basalt.ressources.monitor.log_types import LogParams +from basalt.objects.experiment import Experiment +from basalt.objects.trace import Trace +from basalt.objects.generation import Generation +from basalt.objects.log import Log +from basalt.endpoints.monitor.create_experiment import CreateExperimentEndpoint, Output as ExperimentOutput logger = Logger() mocked_api = MagicMock() # Make sure async_invoke is an AsyncMock mocked_api.async_invoke = AsyncMock() -# Mock responses for different endpoints -experiment_response = CreateExperimentEndpointResponse( - experiment=Experiment( - id="exp-123", - feature_slug="test-feature", - run_id="run-123", - type="A/B Test", - name="Test Experiment", - setup={ - "control_id": "control-123", - "variation_id": "variation-123" - } - ) -) - -trace_response = CreateTraceEndpointResponse( - trace=Trace( - id="trace-123", - name="Test Trace", - slug="test-trace", - metadata={"source": "test"}, - run_id="run-123", - tool=None, - model_id=None, - created_at="2023-01-01T00:00:00Z" - ) -) +# Mock experiment data that matches the actual Experiment structure +experiment_data = { + "id": "exp-123", + "featureSlug": "test-feature", + "name": "Test Experiment", + "createdAt": "2023-01-01T00:00:00Z" +} -generation_response = CreateGenerationEndpointResponse( - generation=Generation( - id="gen-123", - trace_id="trace-123", - run_id="run-123", - text="Generated text", - model_id="gpt-4", - prompt="Test prompt", - metadata={"source": "test"}, - created_at="2023-01-01T00:00:00Z" - ) +experiment_output = ExperimentOutput( + experiment=type('Experiment', (), experiment_data)() ) -log_response = CreateLogEndpointResponse( - log={ - "id": "log-123", - "trace_id": "trace-123", - "run_id": "run-123", - "type": "info", - "message": "Test log message", - "metadata": {"source": "test"}, - "created_at": "2023-01-01T00:00:00Z" - } -) +# Set experiment attributes properly +experiment_output.experiment.id = experiment_data["id"] +experiment_output.experiment.feature_slug = experiment_data["featureSlug"] +experiment_output.experiment.name = experiment_data["name"] +experiment_output.experiment.created_at = experiment_data["createdAt"] class TestMonitorSDKAsync(unittest.TestCase): @@ -84,24 +50,18 @@ def setUp(self): async def test_async_create_experiment(self): """Test asynchronously creating an experiment""" # Configure mock - mocked_api.async_invoke.return_value = (None, experiment_response) + mocked_api.async_invoke.return_value = (None, experiment_output) # Call the method - params = ExperimentParams( - type="A/B Test", - name="Test Experiment", - setup={ - "control_id": "control-123", - "variation_id": "variation-123" - } - ) - result = await self.monitor_sdk.async_create_experiment("test-feature", params) + params = {"name": "Test Experiment"} + err, result = await self.monitor_sdk.async_create_experiment("test-feature", params) # Assertions + self.assertIsNone(err) self.assertIsNotNone(result) self.assertEqual(result.id, "exp-123") self.assertEqual(result.feature_slug, "test-feature") - self.assertEqual(result.type, "A/B Test") + self.assertEqual(result.name, "Test Experiment") # Verify correct endpoint was used endpoint = mocked_api.async_invoke.call_args[0][0] @@ -110,126 +70,113 @@ async def test_async_create_experiment(self): # Verify DTO was created correctly dto = mocked_api.async_invoke.call_args[0][1] self.assertEqual(dto.feature_slug, "test-feature") - self.assertEqual(dto.params.type, "A/B Test") + self.assertEqual(dto.name, "Test Experiment") async def test_async_create_trace(self): """Test asynchronously creating a trace""" - # Configure mock - mocked_api.async_invoke.return_value = (None, trace_response) - - # Call the method - params = TraceParams( - name="Test Trace", - metadata={"source": "test"} - ) + # Call the method - traces are created directly without API calls + params = { + "name": "Test Trace", + "metadata": {"source": "test"} + } result = await self.monitor_sdk.async_create_trace("test-trace", params) # Assertions self.assertIsNotNone(result) - self.assertEqual(result.id, "trace-123") + self.assertIsInstance(result, Trace) self.assertEqual(result.name, "Test Trace") - self.assertEqual(result.slug, "test-trace") - - # Verify correct endpoint was used - endpoint = mocked_api.async_invoke.call_args[0][0] - self.assertEqual(endpoint, CreateTraceEndpoint) - - # Verify DTO was created correctly - dto = mocked_api.async_invoke.call_args[0][1] - self.assertEqual(dto.slug, "test-trace") - self.assertEqual(dto.params.name, "Test Trace") + self.assertEqual(result.feature_slug, "test-trace") + self.assertEqual(result.metadata, {"source": "test"}) async def test_async_create_generation(self): """Test asynchronously creating a generation""" - # Configure mock - mocked_api.async_invoke.return_value = (None, generation_response) - - # Call the method - params = GenerationParams( - trace_id="trace-123", - text="Generated text", - model_id="gpt-4", - prompt="Test prompt", - metadata={"source": "test"} - ) + # Call the method - generations are created directly without API calls + params = { + "name": "Test Generation", + "input": "Test input", + "metadata": {"source": "test"} + } result = await self.monitor_sdk.async_create_generation(params) # Assertions self.assertIsNotNone(result) - self.assertEqual(result.id, "gen-123") - self.assertEqual(result.trace_id, "trace-123") - self.assertEqual(result.text, "Generated text") - self.assertEqual(result.model_id, "gpt-4") - - # Verify correct endpoint was used - endpoint = mocked_api.async_invoke.call_args[0][0] - self.assertEqual(endpoint, CreateGenerationEndpoint) - - # Verify DTO was created correctly - dto = mocked_api.async_invoke.call_args[0][1] - self.assertEqual(dto.params.trace_id, "trace-123") - self.assertEqual(dto.params.text, "Generated text") + self.assertIsInstance(result, Generation) + self.assertEqual(result.input, "Test input") + self.assertEqual(result.name, "Test Generation") + self.assertEqual(result.metadata, {"source": "test"}) + # Options will be None since not provided + self.assertIsNone(result.options) async def test_async_create_log(self): """Test asynchronously creating a log""" - # Configure mock - mocked_api.async_invoke.return_value = (None, log_response) - - # Call the method - params = LogParams( - trace_id="trace-123", - type="info", - message="Test log message", - metadata={"source": "test"} - ) + # Call the method - logs are created directly without API calls + params = { + "name": "Test Log", + "input": "Test log input", + "metadata": {"source": "test"} + } result = await self.monitor_sdk.async_create_log(params) # Assertions self.assertIsNotNone(result) - self.assertEqual(result["id"], "log-123") - self.assertEqual(result["trace_id"], "trace-123") - self.assertEqual(result["message"], "Test log message") - - # Verify correct endpoint was used - endpoint = mocked_api.async_invoke.call_args[0][0] - self.assertEqual(endpoint, CreateLogEndpoint) + self.assertIsInstance(result, Log) + self.assertEqual(result.name, "Test Log") + self.assertEqual(result.input, "Test log input") + self.assertEqual(result.metadata, {"source": "test"}) - # Verify DTO was created correctly - dto = mocked_api.async_invoke.call_args[0][1] - self.assertEqual(dto.params.trace_id, "trace-123") - self.assertEqual(dto.params.message, "Test log message") - - async def test_async_error_handling_create_trace(self): - """Test error handling when asynchronously creating a trace""" + async def test_async_error_handling_create_experiment(self): + """Test error handling when asynchronously creating an experiment""" # Configure mock to return an error error = Exception("API Error") mocked_api.async_invoke.return_value = (error, None) # Call the method - params = TraceParams(name="Test Trace") + params = {"name": "Test Experiment"} - with self.assertRaises(Exception) as context: - await self.monitor_sdk.async_create_trace("non-existent", params) + err, result = await self.monitor_sdk.async_create_experiment("test-feature", params) # Assertions - self.assertEqual(str(context.exception), "API Error") + self.assertIsNotNone(err) + self.assertIsNone(result) + self.assertEqual(str(err), "API Error") -def run_async_tests(): - """ - Helper function to run async tests - """ - loop = asyncio.get_event_loop() +class AsyncTestRunner: + """Helper class to run async tests properly""" + + def __init__(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) - # Create and run the test suite - suite = unittest.TestLoader().loadTestsFromTestCase(TestMonitorSDKAsync) - runner = unittest.TextTestRunner() + def run_test_case(self, test_case_class): + """Run all async test methods in a test case""" + suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) + + for test in suite: + test_method = getattr(test, test._testMethodName) + if asyncio.iscoroutinefunction(test_method): + try: + test.setUp() + self.loop.run_until_complete(test_method()) + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + else: + # Run sync tests normally + try: + test.setUp() + test_method() + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") - for test in suite: - if test._testMethodName.startswith('test_async_'): - coro = getattr(test, test._testMethodName)() - loop.run_until_complete(coro) + def close(self): + self.loop.close() if __name__ == "__main__": - run_async_tests() + runner = AsyncTestRunner() + try: + runner.run_test_case(TestMonitorSDKAsync) + finally: + runner.close() diff --git a/tests/test_promptsdk_async.py b/tests/test_promptsdk_async.py index 6f8673c..4ecb96e 100644 --- a/tests/test_promptsdk_async.py +++ b/tests/test_promptsdk_async.py @@ -1,10 +1,10 @@ import unittest import asyncio -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, AsyncMock from basalt.sdk.promptsdk import PromptSDK from basalt.utils.logger import Logger -from basalt.utils.dtos import PromptDTO, PromptVersionDTO, PromptVariablesDTO, GetPromptDTO +from basalt.utils.dtos import GetPromptDTO, PromptResponse, DescribePromptDTO, DescribePromptResponse, PromptListDTO, PromptListResponse from basalt.endpoints.get_prompt import GetPromptEndpoint, GetPromptEndpointResponse from basalt.endpoints.list_prompts import ListPromptsEndpoint, ListPromptsEndpointResponse from basalt.endpoints.describe_prompt import DescribePromptEndpoint, DescribePromptEndpointResponse @@ -14,67 +14,77 @@ # Make sure async_invoke is an AsyncMock mocked_api.async_invoke = AsyncMock() +# Mock model for PromptResponse +mock_model = type('Model', (), { + 'provider': 'openai', + 'model': 'gpt-4', + 'version': 'latest', + 'parameters': type('Params', (), { + 'temperature': 0.7, + 'max_length': 100, + 'top_p': 1.0, + 'top_k': None, + 'frequency_penalty': None, + 'presence_penalty': None, + 'response_format': 'text', + 'json_object': None + })() +})() + # Mock responses for different endpoints prompt_get_response = GetPromptEndpointResponse( - prompt=PromptDTO( - slug="test-prompt", - feature_slug="test-feature", - name="Test Prompt", - description="A test prompt", - text="This is a test prompt with {{variable}}", + warning=None, + prompt=PromptResponse( + text="This is a test prompt: {{variable}}", + systemText="You are a helpful assistant", version="1.0", - tags=["test"], - variables=[PromptVariablesDTO(name="variable", description="A test variable")] + model=mock_model ) ) prompt_list_response = ListPromptsEndpointResponse( + warning=None, prompts=[ - PromptDTO( + PromptListResponse( slug="test-prompt-1", - feature_slug="test-feature", + status="active", name="Test Prompt 1", - version="1.0" + description="First test prompt", + available_versions=["1.0"], + available_tags=["latest"] ), - PromptDTO( + PromptListResponse( slug="test-prompt-2", - feature_slug="test-feature", + status="active", name="Test Prompt 2", - version="1.0" + description="Second test prompt", + available_versions=["1.0", "2.0"], + available_tags=["latest", "stable"] ) ] ) prompt_describe_response = DescribePromptEndpointResponse( - prompt=PromptDTO( + warning=None, + prompt=DescribePromptResponse( slug="test-prompt", - feature_slug="test-feature", + status="active", name="Test Prompt", - description="A test prompt", - text="This is a test prompt with {{variable}}", - version="1.0", - tags=["test"], - variables=[PromptVariablesDTO(name="variable", description="A test variable")] - ), - versions=[ - PromptVersionDTO( - version="1.0", - text="This is a test prompt with {{variable}}", - created_at="2023-01-01T00:00:00Z" - ), - PromptVersionDTO( - version="0.9", - text="This is an older test prompt with {{variable}}", - created_at="2022-12-01T00:00:00Z" - ) - ] + description="A test prompt for unit testing", + available_versions=["1.0", "1.1", "2.0"], + available_tags=["latest", "stable"], + variables=[{"name": "variable", "type": "string"}] + ) ) class TestPromptSDKAsync(unittest.TestCase): def setUp(self): + from basalt.utils.memcache import MemoryCache self.prompt_sdk = PromptSDK( api=mocked_api, + cache=MemoryCache(), + fallback_cache=MemoryCache(), logger=logger ) # Reset mock calls before each test @@ -90,9 +100,10 @@ async def test_async_get_prompt(self): # Assertions self.assertIsNone(err) - self.assertEqual(prompt_response.text, "This is a test prompt with {{variable}}") - self.assertEqual(prompt_response.slug, "test-prompt") - self.assertIsNone(generation) # No monitoring in this test + self.assertIsNotNone(prompt_response) + self.assertEqual(prompt_response.text, "This is a test prompt: {{variable}}") + self.assertEqual(prompt_response.version, "1.0") + self.assertIsNotNone(generation) # Generation is created for monitoring # Verify correct endpoint was used endpoint = mocked_api.async_invoke.call_args[0][0] @@ -112,6 +123,7 @@ async def test_async_list_prompts(self): # Assertions self.assertIsNone(err) + self.assertIsNotNone(prompts) self.assertEqual(len(prompts), 2) self.assertEqual(prompts[0].slug, "test-prompt-1") self.assertEqual(prompts[1].slug, "test-prompt-2") @@ -130,11 +142,12 @@ async def test_async_list_prompts_with_feature_filter(self): # Assertions self.assertIsNone(err) + self.assertIsNotNone(prompts) self.assertEqual(len(prompts), 2) # Verify DTO was created correctly dto = mocked_api.async_invoke.call_args[0][1] - self.assertEqual(dto.feature_slug, "test-feature") + self.assertEqual(dto.featureSlug, "test-feature") async def test_async_describe_prompt(self): """Test asynchronously describing a prompt""" @@ -146,10 +159,11 @@ async def test_async_describe_prompt(self): # Assertions self.assertIsNone(err) - self.assertEqual(prompt_description.prompt.slug, "test-prompt") - self.assertEqual(len(prompt_description.versions), 2) - self.assertEqual(prompt_description.versions[0].version, "1.0") - self.assertEqual(prompt_description.versions[1].version, "0.9") + self.assertIsNotNone(prompt_description) + self.assertEqual(prompt_description.slug, "test-prompt") + self.assertEqual(prompt_description.name, "Test Prompt") + self.assertEqual(len(prompt_description.available_versions), 3) + self.assertIn("1.0", prompt_description.available_versions) # Verify correct endpoint was used endpoint = mocked_api.async_invoke.call_args[0][0] @@ -168,8 +182,9 @@ async def test_async_get_prompt_with_variables(self): # Assertions self.assertIsNone(err) + self.assertIsNotNone(prompt_response) # Variables should be replaced in the response - self.assertEqual(prompt_response.text, "This is a test prompt with test-value") + self.assertEqual(prompt_response.text, "This is a test prompt: test-value") async def test_async_get_prompt_error_handling(self): """Test error handling when asynchronously getting a prompt""" @@ -187,21 +202,42 @@ async def test_async_get_prompt_error_handling(self): self.assertEqual(str(err), "API Error") -def run_async_tests(): - """ - Helper function to run async tests - """ - loop = asyncio.get_event_loop() +class AsyncTestRunner: + """Helper class to run async tests properly""" + + def __init__(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) - # Create and run the test suite - suite = unittest.TestLoader().loadTestsFromTestCase(TestPromptSDKAsync) - runner = unittest.TextTestRunner() + def run_test_case(self, test_case_class): + """Run all async test methods in a test case""" + suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) + + for test in suite: + test_method = getattr(test, test._testMethodName) + if asyncio.iscoroutinefunction(test_method): + try: + test.setUp() + self.loop.run_until_complete(test_method()) + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") + else: + # Run sync tests normally + try: + test.setUp() + test_method() + print(f"✓ {test._testMethodName}") + except Exception as e: + print(f"✗ {test._testMethodName}: {e}") - for test in suite: - if test._testMethodName.startswith('test_async_'): - coro = getattr(test, test._testMethodName)() - loop.run_until_complete(coro) + def close(self): + self.loop.close() if __name__ == "__main__": - run_async_tests() + runner = AsyncTestRunner() + try: + runner.run_test_case(TestPromptSDKAsync) + finally: + runner.close() \ No newline at end of file