Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion basalt/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
if TYPE_CHECKING: # pragma: no cover
from .client import DatasetsClient

from .file_upload import FileAttachment
from .models import Dataset, DatasetRow

__all__ = ["DatasetsClient", "Dataset", "DatasetRow"]
__all__ = ["DatasetsClient", "Dataset", "DatasetRow", "FileAttachment"]


def __getattr__(name: str) -> Any:
Expand Down
90 changes: 84 additions & 6 deletions basalt/datasets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .._internal.http import HTTPClient
from ..config import config
from ..types.exceptions import BasaltAPIError
from .file_upload import FileAttachment, FileUploadHandler
from .models import Dataset, DatasetRow


Expand Down Expand Up @@ -40,6 +41,13 @@ def __init__(
self._base_url = base_url or config.get("api_url")
super().__init__(client_name="datasets", http_client=http_client, log_level=log_level)

# Initialize file upload handler
self._file_upload_handler = FileUploadHandler(
http_client=self._http_client,
base_url=self._base_url,
api_key=self._api_key,
)

async def list(self) -> list[Dataset]:
"""
List all datasets available in the workspace.
Expand Down Expand Up @@ -176,10 +184,66 @@ def get_sync(self, slug: str) -> Dataset:

return dataset

async def _process_file_uploads(
    self, values: dict[str, str | FileAttachment]
) -> dict[str, str]:
    """
    Replace every FileAttachment in ``values`` with the key returned by its upload.

    Entries are handled one at a time, in the mapping's iteration order;
    each attachment is awaited before the next entry is looked at. Entries
    that are not FileAttachment instances are copied over unchanged. The
    input mapping itself is never modified.

    Args:
        values: Column values, each either a plain string or a
            FileAttachment to be uploaded.

    Returns:
        A new dictionary with the same keys, where each FileAttachment has
        been swapped for the S3 key produced by the upload handler.

    Raises:
        FileValidationError: If file validation fails.
        FileUploadError: If file upload fails.
    """
    resolved: dict[str, str] = {}

    for column, entry in values.items():
        # Plain values need no work; pass them straight through.
        if not isinstance(entry, FileAttachment):
            resolved[column] = entry
            continue

        # Attachment: upload it and keep only the resulting S3 key.
        resolved[column] = await self._file_upload_handler.upload_file(entry)

    return resolved

def _process_file_uploads_sync(
    self, values: dict[str, str | FileAttachment]
) -> dict[str, str]:
    """
    Blocking counterpart of the upload pre-processing step.

    Builds a new dictionary with the same keys as ``values``. Each
    FileAttachment is uploaded through the handler's synchronous entry
    point and replaced by the S3 key it returns; every other value is
    carried over untouched. Entries are processed in the mapping's
    iteration order and the input mapping is left unmodified.

    Args:
        values: Column values, each either a plain string or a
            FileAttachment to be uploaded.

    Returns:
        A new dictionary in which all FileAttachments have been replaced
        by their S3 keys.

    Raises:
        FileValidationError: If file validation fails.
        FileUploadError: If file upload fails.
    """
    # The handler is only touched for attachment entries, so a mapping of
    # plain strings never reaches the upload machinery.
    return {
        column: (
            self._file_upload_handler.upload_file_sync(entry)
            if isinstance(entry, FileAttachment)
            else entry
        )
        for column, entry in values.items()
    }

async def add_row(
self,
slug: str,
values: dict[str, str],
values: dict[str, str | FileAttachment],
name: str | None = None,
ideal_output: str | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -190,21 +254,28 @@ async def add_row(
Args:
slug: The slug identifier for the dataset.
values: A dictionary of column values for the dataset item.
Values can be strings or FileAttachment objects for file columns.
Files are automatically uploaded to S3 before creating the row.
name: An optional name for the dataset item.
ideal_output: An optional ideal output for the dataset item.
metadata: An optional metadata dictionary.

Returns:
The created DatasetRow.
The created DatasetRow. File values will contain S3 keys.

Raises:
FileValidationError: If file validation fails.
FileUploadError: If file upload fails.
BasaltAPIError: If the API request fails.
NetworkError: If a network error occurs.
"""
# Process file uploads first
processed_values = await self._process_file_uploads(values)

url = f"{self._base_url}/datasets/{slug}/items"

body: dict[str, Any] = {
"values": values,
"values": processed_values,
}
if name is not None:
body["name"] = name
Expand Down Expand Up @@ -241,7 +312,7 @@ async def add_row(
def add_row_sync(
self,
slug: str,
values: dict[str, str],
values: dict[str, str | FileAttachment],
name: str | None = None,
ideal_output: str | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -252,21 +323,28 @@ def add_row_sync(
Args:
slug: The slug identifier for the dataset.
values: A dictionary of column values for the dataset item.
Values can be strings or FileAttachment objects for file columns.
Files are automatically uploaded to S3 before creating the row.
name: An optional name for the dataset item.
ideal_output: An optional ideal output for the dataset item.
metadata: An optional metadata dictionary.

Returns:
The created DatasetRow.
The created DatasetRow. File values will contain S3 keys.

Raises:
FileValidationError: If file validation fails.
FileUploadError: If file upload fails.
BasaltAPIError: If the API request fails.
NetworkError: If a network error occurs.
"""
# Process file uploads first
processed_values = self._process_file_uploads_sync(values)

url = f"{self._base_url}/datasets/{slug}/items"

body: dict[str, Any] = {
"values": values,
"values": processed_values,
}
if name is not None:
body["name"] = name
Expand Down
Loading