diff --git a/basalt/_version.py b/basalt/_version.py index 3ced358..b5fdc75 100644 --- a/basalt/_version.py +++ b/basalt/_version.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/basalt/basalt_facade.py b/basalt/basalt_facade.py index 9c0e93f..925773d 100644 --- a/basalt/basalt_facade.py +++ b/basalt/basalt_facade.py @@ -1,7 +1,8 @@ from .utils.api import Api -from .utils.protocols import IPromptSDK, IBasaltSDK, LogLevel +from .utils.protocols import IPromptSDK, IBasaltSDK, LogLevel, IDatasetSDK from .sdk.promptsdk import PromptSDK from .sdk.monitorsdk import MonitorSDK +from .sdk.datasetsdk import DatasetSDK from .basaltsdk import BasaltSDK from .utils.memcache import MemoryCache from .utils.networker import Networker @@ -40,8 +41,9 @@ def __init__(self, api_key: str, log_level: LogLevel = 'all'): prompt = PromptSDK(api, cache, global_fallback_cache, logger) monitor = MonitorSDK(api, logger) + datasets = DatasetSDK(api, logger) - self._basalt = BasaltSDK(prompt, monitor) + self._basalt = BasaltSDK(prompt, monitor, datasets) @property def prompt(self) -> IPromptSDK: @@ -56,3 +58,10 @@ def monitor(self) -> IMonitorSDK: Read-only access to the MonitorSDK instance. """ return self._basalt.monitor + + @property + def datasets(self) -> IDatasetSDK: + """ + Read-only access to the DatasetSDK instance. + """ + return self._basalt.datasets diff --git a/basalt/basaltsdk.py b/basalt/basaltsdk.py index b3ef744..57a451f 100644 --- a/basalt/basaltsdk.py +++ b/basalt/basaltsdk.py @@ -1,4 +1,4 @@ -from .utils.protocols import IPromptSDK, IBasaltSDK +from .utils.protocols import IPromptSDK, IBasaltSDK, IDatasetSDK from .ressources.monitor.monitorsdk_types import IMonitorSDK class BasaltSDK(IBasaltSDK): @@ -7,9 +7,10 @@ class BasaltSDK(IBasaltSDK): It serves as the main entry point for interacting with the Basalt SDK. 
@dataclass
class CreateDatasetItemEndpointResponse:
    """
    Decoded payload of the create dataset item endpoint.

    ``from_dict`` fills either ``error`` (the API rejected the request) or
    ``datasetRow`` plus an optional ``warning`` (the row was created).
    """
    # Created row; None when the API reported an error.
    datasetRow: Optional["DatasetRowDTO"]
    # Non-fatal message returned together with a created row.
    warning: Optional[str] = None
    # Error message reported by the API, if any.
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CreateDatasetItemEndpointResponse":
        """
        Create an instance of CreateDatasetItemEndpointResponse from a dictionary.

        Args:
            data (Dict[str, Any]): The dictionary containing the response data.

        Returns:
            CreateDatasetItemEndpointResponse

        Raises:
            KeyError: If the payload carries no error and no ``datasetRow``.
        """
        # Only a non-empty error marks a failure. A bare ``"error": null``
        # next to a valid row must not discard the row.
        error = data.get("error")
        if error:
            return cls(datasetRow=None, error=error)

        # Read the payload first so a malformed response fails with KeyError
        # before any DTO parsing is attempted.
        row_payload = data["datasetRow"]
        return cls(
            datasetRow=DatasetRowDTO.from_dict(row_payload),
            warning=data.get("warning"),
            error=None,
        )
@dataclass
class GetDatasetEndpointResponse:
    """
    Decoded payload of the get dataset endpoint.

    ``from_dict`` fills either ``error`` (the API rejected the request) or
    ``dataset`` (the dataset was found).
    """
    # Fetched dataset; None when the API reported an error.
    dataset: Optional["DatasetDTO"]
    # Error message reported by the API, if any.
    error: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GetDatasetEndpointResponse":
        """
        Create an instance of GetDatasetEndpointResponse from a dictionary.

        Args:
            data (Dict[str, Any]): The dictionary containing the response data.

        Returns:
            GetDatasetEndpointResponse

        Raises:
            KeyError: If the payload carries no error and no ``dataset``.
        """
        # Only a non-empty error marks a failure. A bare ``"error": null``
        # next to a valid dataset must not discard the dataset.
        error = data.get("error")
        if error:
            return cls(dataset=None, error=error)

        # Read the payload first so a malformed response fails with KeyError
        # before any DTO parsing is attempted.
        dataset_payload = data["dataset"]
        return cls(dataset=DatasetDTO.from_dict(dataset_payload), error=None)


class GetDatasetEndpoint:
    """
    Endpoint class for fetching a specific dataset.
    """
    @staticmethod
    def prepare_request(dto: "GetDatasetDTO") -> Dict[str, Any]:
        """
        Prepare the request dictionary for the GetDataset endpoint.

        Args:
            dto (GetDatasetDTO): The DTO containing dataset slug.

        Returns:
            The path, method, and query parameters for getting a dataset on the API.
        """
        return {
            "path": f"/datasets/{dto.slug}",
            "method": "GET",
            "query": {},
        }

    @staticmethod
    def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[GetDatasetEndpointResponse]]:
        """
        Decode the response returned from the API

        Args:
            response (dict): The JSON response to decode into a GetDatasetEndpointResponse

        Returns:
            A tuple containing an optional exception and an optional GetDatasetEndpointResponse.
        """
        # Decoding failures are returned as the first tuple element instead
        # of being raised.
        try:
            return None, GetDatasetEndpointResponse.from_dict(response)
        except Exception as e:
            return e, None
class ListDatasetsEndpoint:
    """
    Describes the "list datasets" API call: how to build the request and how
    to decode the JSON answer.
    """
    @staticmethod
    def prepare_request(dto: "ListDatasetsDTO") -> Dict[str, Any]:
        """
        Build the request description for listing datasets.

        Args:
            dto (ListDatasetsDTO): Accepted for signature parity with the
                other endpoints; none of its content is read.

        Returns:
            The path, method, and (empty) query parameters of the request.
        """
        request: Dict[str, Any] = dict(path="/datasets", method="GET", query={})
        return request

    @staticmethod
    def decode_response(response: dict) -> Tuple[Optional[Exception], Optional["ListDatasetsEndpointResponse"]]:
        """
        Turn the raw JSON response into a ListDatasetsEndpointResponse.

        Args:
            response (dict): The JSON body returned by the API.

        Returns:
            ``(None, decoded)`` on success, ``(exception, None)`` when the
            body could not be decoded.
        """
        try:
            decoded = ListDatasetsEndpointResponse.from_dict(response)
        except Exception as exc:
            return exc, None
        return None, decoded
@dataclass
class DatasetRowValue:
    """
    A value in a dataset row
    """
    # Label under which the value is stored (the key when built from a dict).
    label: str
    # Cell content.
    value: str


@dataclass
class DatasetRow:
    """
    A row in a dataset
    """
    values: List[DatasetRowValue]
    name: Optional[str] = None
    idealOutput: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _coerce_value(item: Any) -> DatasetRowValue:
        """Convert one element of a ``values`` list into a DatasetRowValue."""
        # Already-built values are passed through untouched.
        if isinstance(item, DatasetRowValue):
            return item
        if isinstance(item, dict):
            return DatasetRowValue(**item)
        # Any other mapping-like object exposing "label" and "value" keys.
        return DatasetRowValue(label=item["label"], value=item["value"])

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatasetRow":
        """
        Create a DatasetRow from a dictionary

        ``data["values"]`` may be a ``{label: value}`` mapping or a list whose
        items are ``{"label": ..., "value": ...}`` mappings or DatasetRowValue
        instances. A missing, null or differently typed ``values`` entry
        yields an empty list.

        Args:
            data: Dictionary containing the row data.

        Returns:
            DatasetRow: A new DatasetRow instance
        """
        raw_values = data.get("values")
        if isinstance(raw_values, list):
            values_list = [cls._coerce_value(item) for item in raw_values]
        elif isinstance(raw_values, dict):
            values_list = [
                DatasetRowValue(label=key, value=val)
                for key, val in raw_values.items()
            ]
        else:
            values_list = []

        return cls(
            values=values_list,
            name=data.get("name"),
            idealOutput=data.get("idealOutput"),
            # An explicit ``"metadata": null`` must not leak None into a field
            # declared as a dict.
            metadata=data.get("metadata") or {},
        )


@dataclass
class Dataset:
    """
    A dataset in the Basalt system
    """
    slug: str
    name: str
    columns: List[str]
    rows: List[DatasetRow] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Dataset":
        """
        Create a Dataset from a dictionary

        Args:
            data: Dictionary with ``slug``, ``name``, ``columns`` and an
                optional ``rows`` list.

        Returns:
            Dataset: A new Dataset instance

        Raises:
            KeyError: If ``slug``, ``name`` or ``columns`` is missing.
        """
        # A missing or null ``rows`` entry is treated as "no rows".
        rows = [DatasetRow.from_dict(row) for row in data.get("rows") or []]

        return cls(
            slug=data["slug"],
            name=data["name"],
            columns=data["columns"],
            rows=rows,
        )
import Dataset, DatasetRow + + +class DatasetSDK: + """ + SDK for interacting with Basalt datasets. + """ + def __init__( + self, + api: IApi, + logger: ILogger + ): + self._api = api + self._logger = logger + + def list(self) -> ListDatasetsResult: + """ + List all datasets available in the workspace. + + Returns: + Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional + exception and an optional list of DatasetDTO objects. + """ + dto = ListDatasetsDTO() + err, result = self._api.invoke(ListDatasetsEndpoint, dto) + + if err is not None: + return err, None + + return None, [DatasetDTO( + slug=dataset.slug, + name=dataset.name, + columns=dataset.columns + ) for dataset in result.datasets] + + def get(self, slug: str) -> GetDatasetResult: + """ + Get a dataset by its slug. + + Args: + slug (str): The slug identifier for the dataset. + + Returns: + Tuple[Optional[Exception], Optional[DatasetDTO]]: A tuple containing an optional + exception and an optional DatasetDTO. + """ + dto = GetDatasetDTO(slug=slug) + err, result = self._api.invoke(GetDatasetEndpoint, dto) + + if err is not None: + return err, None + + if result.error: + return Exception(result.error), None + + return None, result.dataset + + def addRow( + self, + slug: str, + values: Dict[str, str], + name: Optional[str] = None, + ideal_output: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> CreateDatasetItemResult: + """ + Create a new item in a dataset. + + Args: + slug (str): The slug identifier for the dataset. + values (Dict[str, str]): A dictionary of column values for the dataset item. + name (Optional[str]): An optional name for the dataset item. + ideal_output (Optional[str]): An optional ideal output for the dataset item. + metadata (Optional[Dict[str, Any]]): An optional metadata dictionary. 
# ------------------------------ Datasets ----------------------------- #
@dataclass
class DatasetDTO:
    """Dataset data transfer object"""
    slug: str
    name: str
    columns: List[str]
    # None is accepted for backward compatibility and normalised to an empty
    # list in __post_init__, so callers can always iterate or len() it.
    rows: Optional[List["DatasetRowDTO"]] = None

    def __post_init__(self) -> None:
        # A None default (instead of a shared list literal) keeps instances
        # from sharing one list; each one gets its own here.
        if self.rows is None:
            self.rows = []

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatasetDTO":
        """
        Build a DatasetDTO from an API payload.

        Args:
            data: Dictionary with ``slug``, ``name``, ``columns`` and an
                optional ``rows`` list of row dictionaries.

        Returns:
            DatasetDTO

        Raises:
            KeyError: If ``slug``, ``name`` or ``columns`` is missing.
        """
        return cls(
            slug=data["slug"],
            name=data["name"],
            columns=data["columns"],
            # A missing or null ``rows`` entry is treated as "no rows".
            rows=[DatasetRowDTO.from_dict(row) for row in data.get("rows") or []],
        )


@dataclass
class DatasetRowDTO:
    """Dataset row data transfer object"""
    values: Dict[str, str]
    name: Optional[str] = None
    idealOutput: Optional[str] = None
    # None is accepted for backward compatibility and normalised to an empty
    # dict in __post_init__.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        if self.metadata is None:
            self.metadata = {}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatasetRowDTO":
        """
        Build a DatasetRowDTO from an API payload.

        Args:
            data: Dictionary with ``values`` and optional ``name``,
                ``idealOutput`` and ``metadata`` entries.

        Returns:
            DatasetRowDTO

        Raises:
            KeyError: If ``values`` is missing.
        """
        return cls(
            values=data["values"],
            name=data.get("name"),
            idealOutput=data.get("idealOutput"),
            metadata=data.get("metadata"),
        )


@dataclass
class ListDatasetsDTO:
    """DTO for listing datasets"""


@dataclass
class GetDatasetDTO:
    """DTO for getting a specific dataset"""
    slug: str
+@dataclass +class CreateDatasetItemDTO: + """DTO for creating a dataset item""" + slug: str + values: Dict[str, str] + name: Optional[str] = None + idealOutput: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +# Result types for dataset operations +ListDatasetsResult = Tuple[Optional[Exception], Optional[List[DatasetDTO]]] +GetDatasetResult = Tuple[Optional[Exception], Optional[DatasetDTO]] +CreateDatasetItemResult = Tuple[Optional[Exception], Optional[DatasetRowDTO], Optional[str]] \ No newline at end of file diff --git a/basalt/utils/protocols.py b/basalt/utils/protocols.py index 5e8ac8a..58b16fb 100644 --- a/basalt/utils/protocols.py +++ b/basalt/utils/protocols.py @@ -1,5 +1,5 @@ -from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal -from .dtos import GetResult, DescribeResult, ListResult +from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal, List +from .dtos import GetResult, DescribeResult, ListResult, ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult from ..ressources.monitor.monitorsdk_types import IMonitorSDK @@ -31,11 +31,19 @@ def get(self, slug: str, tag: Optional[str] = None, version: Optional[str] = Non def describe(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... def list(self, feature_slug: Optional[str] = None) -> ListResult: ... +class IDatasetSDK(Protocol): + def list(self) -> ListDatasetsResult: ... + def get(self, slug: str) -> GetDatasetResult: ... + def addRow(self, slug: str, values: Dict[str, str], name: Optional[str] = None, + ideal_output: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> CreateDatasetItemResult: ... + class IBasaltSDK(Protocol): @property def prompt(self) -> IPromptSDK: ... @property def monitor(self) -> IMonitorSDK: ... + @property + def datasets(self) -> IDatasetSDK: ... class ILogger: def warn(self, message: str): ... 
diff --git a/examples/dataset_sdk_demo.ipynb b/examples/dataset_sdk_demo.ipynb new file mode 100644 index 0000000..29c4fb4 --- /dev/null +++ b/examples/dataset_sdk_demo.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt Dataset SDK Demo\n", + "\n", + "This notebook demonstrates how to use the Basalt Dataset SDK to interact with your Basalt datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "\n", + "os.environ[\"BASALT_BUILD\"] = \"development\"\n", + "\n", + "from basalt import Basalt\n", + "\n", + "# Initialize the SDK\n", + "basalt = Basalt(\n", + " api_key=\"sk-df55a...\", # Replace with your API key\n", + " log_level=\"debug\" # Optional: Set log level\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Listing Available Datasets\n", + "\n", + "Retrieve all datasets available in your workspace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# List all datasets in the workspace\n", + "err, datasets = basalt.datasets.list()\n", + "\n", + "if err:\n", + " print(f\"Error listing datasets: {err}\")\n", + "else:\n", + " print(f\"Found {len(datasets)} datasets:\")\n", + " for i, dataset in enumerate(datasets):\n", + " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", + " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", + " \n", + " # Store the first dataset slug for later use (if available)\n", + " first_dataset_slug = datasets[0].slug if datasets else None" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Getting a Specific Dataset\n", + "\n", + "Retrieve details for a specific dataset using its slug." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the first dataset from the list or replace with a specific slug\n", + "dataset_slug = first_dataset_slug if 'first_dataset_slug' in locals() and first_dataset_slug else \"your-dataset-slug\"\n", + "\n", + "err, dataset = basalt.datasets.get(dataset_slug)\n", + "\n", + "if err:\n", + " print(f\"Error getting dataset: {err}\")\n", + "else:\n", + " print(f\"Dataset details for '{dataset.name}'\")\n", + " print(f\"Slug: {dataset.slug}\")\n", + " print(f\"Columns: {', '.join(dataset.columns)}\")\n", + " print(f\"Number of rows: {len(dataset.rows)}\")\n", + " \n", + " if dataset.rows:\n", + " print(\"\\nSample rows:\")\n", + " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", + " print(f\"Row {i+1}:\")\n", + " print(f\" Values: {row.get('values')}\")\n", + " if 'name' in row:\n", + " print(f\" Name: {row['name']}\")\n", + " if 'idealOutput' in row:\n", + " print(f\" Ideal output: {row['idealOutput']}\")\n", + " if 'metadata' in row:\n", + " print(f\" Metadata: {row['metadata']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Adding a Row to a Dataset\n", + "\n", + "Create a new row (item) in an existing dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use the dataset from the previous example\n", + "if 'dataset' in locals() and dataset:\n", + " # Build values for all columns in the dataset\n", + " values = {}\n", + " for column in dataset.columns:\n", + " values[column] = f\"Example value for {column} - {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\n", + " \n", + " # Create the row\n", + " err, row, warning = basalt.datasets.addRow(\n", + " slug=dataset.slug,\n", + " values=values,\n", + " name=\"Notebook Example Row\",\n", + " ideal_output=\"This is an ideal output for this row\",\n", + " metadata={\"source\": \"Jupyter notebook example\", \"timestamp\": __import__('datetime').datetime.now().isoformat()}\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error creating dataset row: {err}\")\n", + " else:\n", + " print(\"Successfully created new dataset row:\")\n", + " print(f\"Values: {row.values}\")\n", + " print(f\"Name: {row.name}\")\n", + " print(f\"Ideal output: {row.idealOutput}\")\n", + " print(f\"Metadata: {row.metadata}\")\n", + " \n", + " if warning:\n", + " print(f\"Warning: {warning}\")\n", + "else:\n", + " print(\"Please run the previous cell to get a dataset first\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Error Handling with Dataset SDK\n", + "\n", + "Demonstrate proper error handling when working with datasets." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def safely_add_dataset_row(slug, values, name=None, ideal_output=None, metadata=None):\n", + " \"\"\"Safely add a row to a dataset with robust error handling\"\"\"\n", + " try:\n", + " err, row, warning = basalt.datasets.addRow(\n", + " slug=slug,\n", + " values=values,\n", + " name=name,\n", + " ideal_output=ideal_output,\n", + " metadata=metadata\n", + " )\n", + " \n", + " if err:\n", + " print(f\"Error creating dataset row: {err}\")\n", + " return None\n", + " \n", + " if warning:\n", + " print(f\"Warning: {warning}\")\n", + " \n", + " return row\n", + " except Exception as e:\n", + " print(f\"Unexpected error: {str(e)}\")\n", + " return None\n", + "\n", + "# Test with a valid dataset\n", + "if 'dataset_slug' in locals() and dataset_slug:\n", + " values = {\"input\": \"Test input\", \"output\": \"Test output\"}\n", + " row = safely_add_dataset_row(dataset_slug, values, name=\"Error Handling Test\")\n", + " \n", + " if row:\n", + " print(f\"Successfully created row: {row.name}\")\n", + "\n", + "# Test with an invalid dataset slug\n", + "print(\"\\nTesting with invalid dataset slug:\")\n", + "invalid_row = safely_add_dataset_row(\"non-existent-dataset\", {\"input\": \"Test input\"})\n", + "print(f\"Result with invalid slug: {invalid_row}\")\n", + "\n", + "# Test with missing required values\n", + "if 'dataset' in locals() and dataset and len(dataset.columns) > 0:\n", + " print(\"\\nTesting with missing required values:\")\n", + " # Deliberately create incomplete values dict\n", + " incomplete_values = {column: \"value\" for column in list(dataset.columns)[1:]} if len(dataset.columns) > 1 else {}\n", + " incomplete_row = safely_add_dataset_row(dataset.slug, incomplete_values)\n", + " print(f\"Result with incomplete values: {incomplete_row}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/test_datasetsdk.py b/tests/test_datasetsdk.py new file mode 100644 index 0000000..9d6ea3a --- /dev/null +++ b/tests/test_datasetsdk.py @@ -0,0 +1,160 @@ +import unittest +from unittest.mock import MagicMock +from parameterized import parameterized + +from basalt.sdk.datasetsdk import DatasetSDK +from basalt.utils.logger import Logger +from basalt.utils.dtos import DatasetDTO, DatasetRowDTO +from basalt.endpoints.list_datasets import ListDatasetsEndpoint, ListDatasetsEndpointResponse +from basalt.endpoints.get_dataset import GetDatasetEndpoint, GetDatasetEndpointResponse +from basalt.endpoints.create_dataset_item import CreateDatasetItemEndpoint, CreateDatasetItemEndpointResponse + +logger = Logger() +mocked_api = MagicMock() + +# Mock responses for different endpoints +dataset_list_response = ListDatasetsEndpointResponse( + datasets=[ + DatasetDTO( + slug="test-dataset", + name="Test Dataset", + columns=["input", "output"] + ), + DatasetDTO( + slug="another-dataset", + name="Another Dataset", + columns=["col1", "col2", "col3"] + ) + ] +) + +dataset_get_response = GetDatasetEndpointResponse( + dataset=DatasetDTO( + slug="test-dataset", + name="Test Dataset", + columns=["input", "output"], + rows=[ + { + "values": { + "input": "Sample input", + "output": "Sample output" + }, + "name": "Sample Row", + "idealOutput": "Ideal output", + "metadata": {"source": "test"} + } + ] + ), + error=None +) + +dataset_add_row_response = CreateDatasetItemEndpointResponse( + datasetRow=DatasetRowDTO( + values={"input": "New input", "output": "New output"}, + name="New Row", + idealOutput="New ideal output", + metadata={"source": "test"} + ), + 
class TestDatasetSDK(unittest.TestCase):
    """Unit tests for DatasetSDK running against a mocked API client."""

    def setUp(self):
        # The mock lives at module level and is shared by every test. Wipe
        # recorded calls and configured return values so no test can pass or
        # fail because of what a previous one left behind.
        mocked_api.reset_mock(return_value=True, side_effect=True)
        self.dataset_sdk = DatasetSDK(
            api=mocked_api,
            logger=logger
        )

    def test_list_datasets(self):
        """Test listing all datasets"""
        mocked_api.invoke.return_value = (None, dataset_list_response)

        err, datasets = self.dataset_sdk.list()

        self.assertIsNone(err)
        self.assertEqual(len(datasets), 2)
        self.assertEqual(datasets[0].slug, "test-dataset")
        self.assertEqual(datasets[0].name, "Test Dataset")
        self.assertEqual(datasets[1].slug, "another-dataset")

        # Verify correct endpoint was used
        mocked_api.invoke.assert_called_once()
        endpoint = mocked_api.invoke.call_args[0][0]
        self.assertIs(endpoint, ListDatasetsEndpoint)

    def test_get_dataset(self):
        """Test getting a dataset by slug"""
        mocked_api.invoke.return_value = (None, dataset_get_response)

        err, dataset = self.dataset_sdk.get("test-dataset")

        self.assertIsNone(err)
        self.assertEqual(dataset.slug, "test-dataset")
        self.assertEqual(dataset.name, "Test Dataset")
        self.assertEqual(len(dataset.columns), 2)
        self.assertEqual(len(dataset.rows), 1)

        # Verify correct endpoint was used
        mocked_api.invoke.assert_called_once()
        endpoint = mocked_api.invoke.call_args[0][0]
        self.assertIs(endpoint, GetDatasetEndpoint)

        # Verify DTO was created correctly
        dto = mocked_api.invoke.call_args[0][1]
        self.assertEqual(dto.slug, "test-dataset")

    def test_create_dataset_item(self):
        """Test creating a dataset item"""
        mocked_api.invoke.return_value = (None, dataset_add_row_response)

        values = {"input": "New input", "output": "New output"}
        err, row, warning = self.dataset_sdk.addRow(
            slug="test-dataset",
            values=values,
            name="New Row",
            ideal_output="New ideal output",
            metadata={"source": "test"}
        )

        self.assertIsNone(err)
        self.assertIsNone(warning)
        self.assertEqual(row.values, values)
        self.assertEqual(row.name, "New Row")
        self.assertEqual(row.idealOutput, "New ideal output")

        # Verify correct endpoint was used
        mocked_api.invoke.assert_called_once()
        endpoint = mocked_api.invoke.call_args[0][0]
        self.assertIs(endpoint, CreateDatasetItemEndpoint)

        # Verify DTO was created correctly
        dto = mocked_api.invoke.call_args[0][1]
        self.assertEqual(dto.slug, "test-dataset")
        self.assertEqual(dto.values, values)
        self.assertEqual(dto.name, "New Row")
        self.assertEqual(dto.idealOutput, "New ideal output")
        self.assertEqual(dto.metadata, {"source": "test"})

    def test_error_handling_get_dataset(self):
        """Test that a transport-level error from the API client is passed through"""
        mocked_api.invoke.return_value = (Exception("API Error"), None)

        err, dataset = self.dataset_sdk.get("non-existent")

        self.assertIsNotNone(err)
        self.assertIsNone(dataset)
        self.assertEqual(str(err), "API Error")

    def test_get_dataset_returns_api_reported_error(self):
        """Test that an ``error`` field in a decoded response becomes an exception"""
        mocked_api.invoke.return_value = (
            None,
            GetDatasetEndpointResponse(dataset=None, error="Dataset not found"),
        )

        err, dataset = self.dataset_sdk.get("missing")

        self.assertIsInstance(err, Exception)
        self.assertEqual(str(err), "Dataset not found")
        self.assertIsNone(dataset)

    def test_create_dataset_item_returns_api_reported_error(self):
        """Test that addRow surfaces an ``error`` field and returns no row or warning"""
        mocked_api.invoke.return_value = (
            None,
            CreateDatasetItemEndpointResponse(datasetRow=None, error="Invalid values"),
        )

        err, row, warning = self.dataset_sdk.addRow(slug="test-dataset", values={})

        self.assertIsInstance(err, Exception)
        self.assertEqual(str(err), "Invalid values")
        self.assertIsNone(row)
        self.assertIsNone(warning)


if __name__ == "__main__":
    unittest.main()