Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ basalt*.egg-info/
build/
dist/
__pycache__/
test.py
test.py
.DS_Store
2 changes: 1 addition & 1 deletion basalt/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.2.0"
10 changes: 6 additions & 4 deletions basalt/basalt_facade.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .utils.api import Api
from .utils.protocols import IPromptSDK, IBasaltSDK, IMonitorSDK
from .utils.protocols import IPromptSDK, IBasaltSDK, LogLevel
from .sdk.promptsdk import PromptSDK
from .sdk.monitorsdk import MonitorSDK
from .basaltsdk import BasaltSDK
Expand All @@ -8,14 +8,16 @@
from .config import config
from .utils.logger import Logger

from .ressources.monitor.monitorsdk_types import IMonitorSDK

global_fallback_cache = MemoryCache()

class BasaltFacade(IBasaltSDK):
"""
The Basalt client.
"""

def __init__(self, api_key: str, log_level: str = 'all'):
def __init__(self, api_key: str, log_level: LogLevel = 'all'):
"""
Initializes the Basalt client with the given API key and log level.

Expand All @@ -25,8 +27,8 @@ def __init__(self, api_key: str, log_level: str = 'all'):
"""
cache = MemoryCache()
logger = Logger(log_level=log_level)
networker = Networker(logger=logger)
networker = Networker()

api = Api(
networker=networker,
root_url=config["api_url"],
Expand Down
5 changes: 3 additions & 2 deletions basalt/basaltsdk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .utils.protocols import IPromptSDK, IBasaltSDK, IMonitorSDK
from .utils.protocols import IPromptSDK, IBasaltSDK
from .ressources.monitor.monitorsdk_types import IMonitorSDK

class BasaltSDK(IBasaltSDK):
"""
Expand All @@ -9,7 +10,7 @@ class BasaltSDK(IBasaltSDK):
def __init__(self, prompt_sdk: IPromptSDK, monitor_sdk: IMonitorSDK):
self._prompt = prompt_sdk
self._monitor = monitor_sdk

@property
def prompt(self) -> IPromptSDK:
"""Read-only access to the PromptSDK instance"""
Expand Down
9 changes: 6 additions & 3 deletions basalt/endpoints/list_prompts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, List

from ..utils.dtos import PromptListResponse
from ..utils.dtos import PromptListResponse, PromptListDTO

@dataclass
class ListPromptsEndpointResponse:
Expand Down Expand Up @@ -29,7 +29,7 @@ class ListPromptsEndpoint:
Endpoint class for fetching a prompt.
"""
@staticmethod
def prepare_request() -> Dict[str, Any]:
def prepare_request(dto: PromptListDTO) -> Dict[str, Any]:
"""
Prepare the request dictionary for the ListPrompts endpoint.

Expand All @@ -38,7 +38,10 @@ def prepare_request() -> Dict[str, Any]:
"""
return {
"path": "/prompts",
"method": "GET"
"method": "GET",
"query": {
"featureSlug": dto.featureSlug
}
}

@staticmethod
Expand Down
55 changes: 55 additions & 0 deletions basalt/endpoints/monitor/create_experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from datetime import datetime

# Minimal Experiment model (expand as needed)
@dataclass
class Experiment:
feature_slug: str
name: str
id: str
created_at: datetime
# Add more fields as needed

@classmethod
def from_dict(cls, data: Dict[str, Any]):
return cls(
feature_slug=data.get("featureSlug") or data.get("feature_slug"),
name=data.get("name"),
id=data.get("id"),
created_at=data.get("createdAt"),
)

Comment thread
MarquisG marked this conversation as resolved.
@dataclass
class CreateExperimentDTO:
feature_slug: str
name: str

@dataclass
class Output:
experiment: Experiment

class CreateExperimentEndpoint:
"""
Endpoint for creating an experiment
"""
@staticmethod
def prepare_request(dto: CreateExperimentDTO) -> Dict[str, Any]:
body = {
"featureSlug": dto.feature_slug,
"name": dto.name,
}

return {
"method": "post",
"path": "/monitor/experiments",
"body": body,
}

@staticmethod
def decode_response(body: Any) -> Tuple[Optional[Exception], Optional[Output]]:
if not isinstance(body, dict):
return Exception("Failed to decode response (invalid body format)"), None

experiment = Experiment.from_dict(body)
return None, Output(experiment=experiment)
81 changes: 34 additions & 47 deletions basalt/endpoints/monitor/send_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ class SendTraceEndpoint:
def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]:
"""
Prepares the request for sending a trace.

Args:
dto (Optional[Dict[str, Any]]): The data transfer object containing the trace.

Returns:
Dict[str, Any]: The request information.
"""
Expand All @@ -28,57 +28,40 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]:
"path": "/monitor/trace",
"body": {}
}

trace = dto["trace"]

# Check if trace is already a dictionary or an object
if isinstance(trace, dict):
trace_data = trace
logs = trace_data.get("logs", [])
logs = trace.get("logs", [])
else:
trace_data = trace.to_dict()
# Convert logs to a format suitable for the API
logs = []
for log in trace_data["logs"]:
log_data = log.to_dict()

# Convert dates to ISO format
log_data["startTime"] = log_data["start_time"].isoformat() if isinstance(log_data["start_time"], datetime) else log_data["start_time"]
log_data["endTime"] = log_data["end_time"].isoformat() if isinstance(log_data["end_time"], datetime) and log_data["end_time"] else None

# Remove old format keys
del log_data["start_time"]
del log_data["end_time"]
dict_log = log.to_dict()

# Convert dates and handle parent ID
log_data = {
"startTime": dict_log["start_time"].isoformat() if isinstance(dict_log["start_time"], datetime) else dict_log["start_time"],
"endTime": dict_log["end_time"].isoformat() if isinstance(dict_log["end_time"], datetime) and dict_log["end_time"] else None,
"parentId": dict_log["parent"]["id"] if dict_log["parent"] else None,
"inputTokens": dict_log["input_tokens"] if "input_tokens" in dict_log else None,
"outputTokens": dict_log["output_tokens"] if "output_tokens" in dict_log else None,
"cost": dict_log["cost"] if "cost" in dict_log else None,
"variables": [{"label": k, "value": v} for k, v in dict_log["variables"].items()] if "variables" in dict_log else None,
"input": dict_log["input"] if "input" in dict_log else None,
"output": dict_log["output"] if "output" in dict_log else None,
"prompt": dict_log["prompt"] if "prompt" in dict_log else None,
"evaluators": dict_log["evaluators"] if "evaluators" in dict_log else None
}

# Add input and output if they exist
if hasattr(log, "input"):
log_data["input"] = log.input
if hasattr(log, "output"):
log_data["output"] = log.output

# Add prompt and variables if it's a generation
if hasattr(log, "prompt"):
log_data["prompt"] = log.prompt
if hasattr(log, "variables") and log.variables:
log_data["variables"] = [{"label": key, "value": value} for key, value in log.variables.items()]
if hasattr(log, "inputTokens"):
log_data["inputTokens"] = log.inputTokens
if hasattr(log, "outputTokens"):
log_data["outputTokens"] = log.outputTokens
if hasattr(log, "cost"):
log_data["cost"] = log.cost

# Extract parent ID
if log_data["parent"]:
log_data["parentId"] = log_data["parent"]["id"]
del log_data["parent"]
else:
log_data["parentId"] = None

logs.append(log_data)

# Process logs if they're already in dictionary format
processed_logs = []

for log_data in logs:
# If log_data is already processed by the flusher, it will have these keys
if "startTime" in log_data and "endTime" in log_data:
Expand All @@ -96,35 +79,39 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]:
if "end_time" in processed_log:
processed_log["endTime"] = processed_log["end_time"].isoformat() if isinstance(processed_log["end_time"], datetime) and processed_log["end_time"] else None
del processed_log["end_time"]

# Extract parent ID
if "parent" in processed_log and processed_log["parent"]:
processed_log["parentId"] = processed_log["parent"]["id"]
del processed_log["parent"]
else:
processed_log["parentId"] = None

processed_logs.append(processed_log)

# Create the request body
body = {
"chainSlug": trace_data.get("chain_slug", trace_data.get("chainSlug")),
"featureSlug": trace_data.get("feature_slug", trace_data.get("featureSlug")),
"name": trace_data.get("name", trace_data.get("name")),
"experiment": {"id": trace_data.get("experiment", {}).id} if trace_data.get("experiment") else None,
"input": trace_data.get("input"),
"output": trace_data.get("output"),
"metadata": trace_data.get("metadata"),
"organization": trace_data.get("organization"),
"user": trace_data.get("user"),
"startTime": trace_data.get("start_time", trace_data.get("startTime")),
"endTime": trace_data.get("end_time", trace_data.get("endTime")),
"logs": processed_logs
"logs": processed_logs,
"evaluators": trace_data.get("evaluators"),
"evaluationConfig": trace_data.get("evaluationConfig")
}

# Convert dates to ISO format if they're datetime objects
if isinstance(body["startTime"], datetime):
body["startTime"] = body["startTime"].isoformat()
if isinstance(body["endTime"], datetime):
body["endTime"] = body["endTime"].isoformat()

return {
"method": "post",
"path": "/monitor/trace",
Expand Down
24 changes: 19 additions & 5 deletions basalt/objects/base_log.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from datetime import datetime
from typing import Dict, Optional, Any, TYPE_CHECKING
from typing import Dict, Optional, Any, List
import uuid

if TYPE_CHECKING:
from .log import Log
from .trace import Trace
from ..ressources.monitor.base_log_types import BaseLogParams
from ..ressources.monitor.evaluator_types import Evaluator
from ..ressources.monitor.trace_types import Trace
from ..ressources.monitor.log_types import Log

class BaseLog:
"""
Base class for logs and generations.
"""
def __init__(self, params: Dict[str, Any]):
def __init__(self, params: BaseLogParams):
self._id = f"log-{uuid.uuid4().hex[:8]}"
self._type = params.get("type")
self._name = params.get("name")
Expand All @@ -19,6 +20,7 @@ def __init__(self, params: Dict[str, Any]):
self._metadata = params.get("metadata")
self._trace = params.get("trace")
self._parent = params.get("parent")
self._evaluators = params.get("evaluators")

# Add to trace's logs list if trace exists
if self._trace:
Expand Down Expand Up @@ -68,6 +70,11 @@ def metadata(self) -> Optional[Dict[str, Any]]:
def trace(self) -> 'Trace':
"""Get the trace."""
return self._trace

@property
def evaluators(self) -> List[Evaluator]:
"""Get the evaluators."""
return self._evaluators

@trace.setter
def trace(self, trace: 'Trace'):
Expand All @@ -84,6 +91,13 @@ def set_metadata(self, metadata: Dict[str, Any]) -> 'BaseLog':
self._metadata = metadata
return self

def add_evaluator(self, evaluator: Evaluator) -> 'BaseLog':
if self._evaluators is None:
self._evaluators = []

self._evaluators.append(evaluator)
return self

def update(self, params: Dict[str, Any]) -> 'BaseLog':
"""Update the log."""
self._name = params.get("name", self._name)
Expand Down
22 changes: 22 additions & 0 deletions basalt/objects/experiment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from datetime import datetime
from ..ressources.monitor.experiment_types import Experiment as IExperiment

class Experiment:
def __init__(self, experiment: IExperiment):
self._experiment = experiment

@property
def id(self) -> str:
return self._experiment.id

@property
def name(self) -> str:
return self._experiment.name

@property
def feature_slug(self) -> str:
return self._experiment.feature_slug

@property
def created_at(self) -> datetime:
return self._experiment.created_at
Loading
Loading