From a5bb37fe348b3940e83aa1cb2de3e356b6cd4e72 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Sat, 4 Oct 2025 15:26:32 +0200 Subject: [PATCH 1/7] review async support --- .gitignore | 4 +- basalt/objects/base_log.py | 33 +-- basalt/objects/generation.py | 59 ++--- basalt/objects/log.py | 68 +++-- basalt/objects/trace.py | 92 ++++--- basalt/ressources/monitor/base_log_types.py | 26 +- basalt/ressources/monitor/generation_types.py | 53 ++-- basalt/ressources/monitor/log_type.py | 17 -- basalt/ressources/monitor/log_types.py | 83 ++++--- basalt/ressources/monitor/monitorsdk_types.py | 28 ++- basalt/ressources/monitor/trace_types.py | 133 +++++----- basalt/sdk/datasetsdk.py | 54 ++-- basalt/sdk/monitorsdk.py | 233 ++---------------- basalt/sdk/promptsdk.py | 173 +++++-------- basalt/utils/api.py | 18 +- basalt/utils/flusher.py | 72 ++++-- basalt/utils/networker.py | 120 ++++----- basalt/utils/protocols.py | 33 ++- 18 files changed, 590 insertions(+), 709 deletions(-) delete mode 100644 basalt/ressources/monitor/log_type.py diff --git a/.gitignore b/.gitignore index 378acfc..9f0ebfb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ build/ dist/ __pycache__/ test.py -.DS_Store \ No newline at end of file +.DS_Store +.idea/ +venv/ \ No newline at end of file diff --git a/basalt/objects/base_log.py b/basalt/objects/base_log.py index d6f8878..1d65941 100644 --- a/basalt/objects/base_log.py +++ b/basalt/objects/base_log.py @@ -5,7 +5,8 @@ from ..ressources.monitor.base_log_types import BaseLogParams from ..ressources.monitor.evaluator_types import Evaluator from ..ressources.monitor.trace_types import Trace -from ..ressources.monitor.log_types import Log +from ..ressources.monitor.log_types import Log, LogType + class BaseLog: """ @@ -13,15 +14,15 @@ class BaseLog: """ def __init__(self, params: BaseLogParams): self._id = f"log-{uuid.uuid4().hex[:8]}" - self._type = params.get("type") - self._name = params.get("name") - self._start_time = params.get("start_time", 
datetime.now()) - self._end_time = params.get("end_time") - self._metadata = params.get("metadata") - self._trace = params.get("trace") - self._parent = params.get("parent") - self._evaluators = params.get("evaluators") - self._ideal_output = params.get("ideal_output") + self._type = params.type + self._name = params.name + self._start_time = params.start_time if params.start_time is not None else datetime.now() + self._end_time = params.end_time + self._metadata = params.metadata + self._trace = params.trace + self._parent = params.parent + self._evaluators = params.evaluators + self._ideal_output = params.ideal_output # Add to trace's logs list if trace exists if self._trace: @@ -43,7 +44,7 @@ def parent(self, parent: 'Log'): self._parent = parent @property - def type(self) -> str: + def type(self) -> LogType: """Get the log type.""" return self._type @@ -76,7 +77,7 @@ def metadata(self) -> Optional[Dict[str, Any]]: def trace(self) -> 'Trace': """Get the trace.""" return self._trace - + @property def evaluators(self) -> List[Evaluator]: """Get the evaluators.""" @@ -113,13 +114,13 @@ def update(self, params: Dict[str, Any]) -> 'BaseLog': """Update the log.""" self._name = params.get("name", self._name) self._metadata = params.get("metadata", self._metadata) - + if params.get("start_time"): self._start_time = params.get("start_time") - + if params.get("end_time"): self._end_time = params.get("end_time") - + return self def end(self) -> 'BaseLog': @@ -138,4 +139,4 @@ def to_dict(self) -> Dict[str, Any]: "end_time": self._end_time, "metadata": self._metadata, "parent": {"id": self._parent.id} if self._parent else None, - } \ No newline at end of file + } diff --git a/basalt/objects/generation.py b/basalt/objects/generation.py index 545718f..76b15b4 100644 --- a/basalt/objects/generation.py +++ b/basalt/objects/generation.py @@ -2,27 +2,28 @@ from .base_log import BaseLog from ..ressources.monitor.generation_types import GenerationParams +from 
..ressources.monitor.base_log_types import BaseLogParams +from ..ressources.monitor.log_types import LogType + class Generation(BaseLog): """ Class representing a generation in the monitoring system. """ def __init__(self, params: GenerationParams): - params_with_type = { - "type": "generation", - **params - } + params_with_type = BaseLogParams(**params.__dict__, type=LogType.GENERATION) + super().__init__(params_with_type) - - self._prompt = params.get("prompt") - self._input = params.get("input") - self._output = params.get("output") - self._input_tokens = params.get("input_tokens") - self._output_tokens = params.get("output_tokens") - self._cost = params.get("cost") - + + self._prompt = params.prompt + self._input = params.input + self._output = params.output + self._input_tokens = params.input_tokens + self._output_tokens = params.output_tokens + self._cost = params.cost + # Convert variables to array format if needed - variables = params.get("variables") + variables = params.variables if variables is not None: if isinstance(variables, dict): self._variables = [{"label": str(k), "value": str(v)} for k, v in variables.items()] @@ -32,8 +33,8 @@ def __init__(self, params: GenerationParams): self._variables = [] else: self._variables = [] - - self._options = params.get("options") + + self._options = params.options @property def prompt(self) -> Optional[Dict[str, Any]]: @@ -83,50 +84,50 @@ def options(self, options: Dict[str, Any]): def start(self, input: Optional[str] = None) -> 'Generation': """ Start the generation with an optional input. - + Args: input (Optional[str]): The input to the generation. - + Returns: Generation: The generation instance. """ if input: self._input = input - + super().start() return self def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> 'Generation': """ End the generation with an optional output or update parameters. 
- + Args: output (Optional[Union[str, Dict[str, Any]]]): The output of the generation or a dictionary of parameters to update. - + Returns: Generation: The generation instance. """ super().end() - + if isinstance(output, dict): self.update(output) elif isinstance(output, str): self._output = output - + # If this is a single generation, end the trace as well if self._options and self._options.get("type") == "single": - self.trace.end(self._output) - + self.trace.end_sync(self._output) + return self def update(self, params: Dict[str, Any]) -> 'Generation': """ Update the generation. - + Args: params (Dict[str, Any]): Parameters to update. - + Returns: Generation: The generation instance. """ @@ -136,7 +137,7 @@ def update(self, params: Dict[str, Any]) -> 'Generation': self._input_tokens = params.get("input_tokens", self._input_tokens) self._output_tokens = params.get("output_tokens", self._output_tokens) self._cost = params.get("cost", self._cost) - + # Update variables if provided variables = params.get("variables") if variables is not None: @@ -146,6 +147,6 @@ def update(self, params: Dict[str, Any]) -> 'Generation': self._variables = [{"label": str(v.get("label")), "value": str(v.get("value"))} for v in variables if v.get("label")] else: self._variables = [] - + super().update(params) - return self \ No newline at end of file + return self diff --git a/basalt/objects/log.py b/basalt/objects/log.py index 5102cea..e414be5 100644 --- a/basalt/objects/log.py +++ b/basalt/objects/log.py @@ -1,5 +1,6 @@ -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, cast from ..ressources.monitor.log_types import LogParams +from ..ressources.monitor.generation_types import GenerationParams from .base_log import BaseLog from .generation import Generation @@ -9,7 +10,7 @@ class Log(BaseLog): """ def __init__(self, params: LogParams): super().__init__(params) - self._input = params.get("input") + self._input = params.input self._output = None @property @@ 
-25,43 +26,43 @@ def output(self) -> Optional[str]: def start(self, input: Optional[str] = None) -> 'Log': """ Start the log with an optional input. - + Args: input (Optional[str]): The input to the log. - + Returns: Log: The log instance. """ if input: self._input = input - + super().start() return self def end(self, output: Optional[str] = None) -> 'Log': """ End the log with an optional output. - + Args: output (Optional[str]): The output of the log. - + Returns: Log: The log instance. """ super().end() - + if output: self._output = output - + return self def append(self, generation: 'Generation') -> 'Log': """ Append a generation to this log. - + Args: generation (Generation): The generation to append. - + Returns: Log: The log instance. """ @@ -69,7 +70,7 @@ def append(self, generation: 'Generation') -> 'Log': generation.trace.logs = [log for log in generation.trace.logs if log.id != generation.id] # Add child to the new trace list - self.trace.logs.append(generation) + self.trace.logs.append(cast(BaseLog, generation)) # Set the trace of the generation to the current log generation.trace = self.trace @@ -77,69 +78,58 @@ def append(self, generation: 'Generation') -> 'Log': # Set the parent of the generation to the current log generation.parent = self - + return self def update(self, params: Dict[str, Any]) -> 'Log': """ Update the log with new parameters. - + Args: params (Dict[str, Any]): Parameters to update. - + Returns: Log: The log instance. """ super().update(params) - + if "output" in params: self._output = params["output"] - + if "input" in params: self._input = params["input"] - + return self def create_generation(self, params: Dict[str, Any]) -> 'Generation': """ Create a new generation as a child of this log. - + Args: params (Dict[str, Any]): Parameters for the generation. - + Returns: Generation: The new generation instance. 
""" - from .generation import Generation - # Set the name to the prompt slug only if name is not provided name = params.get("name") if not name and params.get("prompt") and params["prompt"].get("slug"): name = params["prompt"]["slug"] - - generation = Generation({ - **params, - "name": name, - "trace": self.trace, - "parent": self - }) - + + generation = Generation(GenerationParams(**params, name=name, trace=self.trace, parent=self)) + return generation def create_log(self, params: Dict[str, Any]) -> 'Log': """ Create a new log as a child of this log. - + Args: params (Dict[str, Any]): Parameters for the log. - + Returns: Log: The new log instance. """ - log = Log({ - **params, - "trace": self.trace, - "parent": self - }) - - return log \ No newline at end of file + log = Log(LogParams(**params, trace=self.trace, parent=self)) + + return log diff --git a/basalt/objects/trace.py b/basalt/objects/trace.py index cab5bef..38ae882 100644 --- a/basalt/objects/trace.py +++ b/basalt/objects/trace.py @@ -5,41 +5,44 @@ from ..ressources.monitor.trace_types import TraceParams from .base_log import BaseLog from .generation import Generation +from .log import Log from ..utils.flusher import Flusher from .experiment import Experiment from ..ressources.monitor.evaluator_types import Evaluator -from ..utils.logger import Logger +from ..ressources.monitor.generation_types import GenerationParams +from ..ressources.monitor.log_types import LogParams +from ..utils.protocols import ILogger class Trace: """ Class representing a trace in the monitoring system. 
""" - def __init__(self, feature_slug: str, params: TraceParams, flusher: 'Flusher', logger: 'Logger'): + def __init__(self, feature_slug: str, params: TraceParams, flusher: 'Flusher', logger: 'ILogger'): self._feature_slug = feature_slug - self._input = params.get("input") - self._output = params.get("output") - self._ideal_output = params.get("ideal_output") - self._name = params.get("name") - self._start_time = params.get("start_time", datetime.now()) - self._end_time = params.get("end_time") - self._user = params.get("user") - self._organization = params.get("organization") - self._metadata = params.get("metadata") + self._input = params.input + self._output = params.output + self._ideal_output = params.ideal_output + self._name = params.name + self._start_time = params.start_time if params.start_time else datetime.now() + self._end_time = params.end_time + self._user = params.user + self._organization = params.organization + self._metadata = params.metadata self._logs: List['BaseLog'] = [] self._flusher = flusher self._is_ended = False - self._evaluators = params.get("evaluators") - self._evaluation_config = params.get("evaluationConfig") + self._evaluators = params.evaluators + self._evaluation_config = params.evaluation_config self._logger = logger self._experiment = None if "experiment" in params: - experiment = params["experiment"] + experiment = params.experiment if experiment is None: self._logger.warn("Warning: Experiment is None. This experiment will be ignored.") elif experiment.feature_slug != self._feature_slug: @@ -243,68 +246,57 @@ def update(self, params: Dict[str, Any]) -> 'Trace': def append(self, generation: 'Generation') -> 'Trace': """ Append a generation to this trace. - + Args: generation (Generation): The generation to append. - + Returns: Trace: The trace instance. 
""" # Remove child log from the list of its previous trace if generation.trace: generation.trace.logs = [log for log in generation.trace.logs if log.id != generation.id] - + # Add child to the new trace list self._logs.append(generation) generation.trace = self - + return self def create_generation(self, params: Dict[str, Any]) -> 'Generation': """ Create a new generation in this trace. - + Args: params (Dict[str, Any]): Parameters for the generation. - + Returns: Generation: The new generation instance. """ - from .generation import Generation - # Set the name to the prompt slug if available name = params.get("name") if params.get("prompt") and params["prompt"].get("slug"): name = params["prompt"]["slug"] - - generation = Generation({ - **params, - "name": name, - "trace": self - }) - + + generation = Generation(GenerationParams(**params, name=name, trace=self)) + return generation def create_log(self, params: Dict[str, Any]) -> 'BaseLog': """ Create a new log in this trace. - + Args: params (Dict[str, Any]): Parameters for the log. - + Returns: Log: The new log instance. """ - from .log import Log - - log = Log({ - **params, - "trace": self - }) + log = Log(LogParams(**params, trace=self)) return log - def end(self, output: Optional[str] = None) -> 'Trace': + async def end(self, output: Optional[str] = None) -> 'Trace': """ End the trace with an optional output. @@ -320,7 +312,27 @@ def end(self, output: Optional[str] = None) -> 'Trace': if self._can_flush(): self._end_time = datetime.now() self._is_ended = True - self._flusher.flush_trace(self) + await self._flusher.flush_trace(self) + + return self + + def end_sync(self, output: Optional[str] = None) -> 'Trace': + """ + End the trace with an optional output synchronously. + + Args: + output (Optional[str]): The output of the trace. + + Returns: + Trace: The trace instance. 
+ """ + self._output = output if output is not None else self._output + + # Send to the API using the flusher + if self._can_flush(): + self._end_time = datetime.now() + self._is_ended = True + self._flusher.flush_trace_sync(self) return self @@ -353,4 +365,4 @@ def _can_flush(self) -> bool: if self._is_ended: self._logger.warn('Trace already ended. This operation will be ignored.') - return not self._is_ended \ No newline at end of file + return not self._is_ended diff --git a/basalt/ressources/monitor/base_log_types.py b/basalt/ressources/monitor/base_log_types.py index 024302d..e088fb1 100644 --- a/basalt/ressources/monitor/base_log_types.py +++ b/basalt/ressources/monitor/base_log_types.py @@ -4,16 +4,15 @@ from uuid import uuid4 from .evaluator_types import Evaluator -from .log_type import LogType if TYPE_CHECKING: from .trace_types import Trace - from .log_types import Log + from .log_types import Log, LogType @dataclass class BaseLogParams: """Base parameters for creating a log entry. - + Attributes: name: Name of the log entry, describing what it represents. start_time: When the log entry started, can be a datetime object or ISO string. @@ -36,11 +35,12 @@ class BaseLogParams: parent: Optional['Log'] = None trace: 'Trace' = None evaluators: Optional[List[Evaluator]] = None + type: Optional['LogType'] = None @dataclass class BaseLog: """Base class for all log entries. - + Attributes: id: Unique identifier for this log entry. Automatically generated when the log is created. @@ -68,10 +68,10 @@ class BaseLog: parent: Optional['Log'] = None trace: 'Trace' = None evaluators: List[Evaluator] = field(default_factory=list) - + def start(self) -> 'BaseLog': """Marks the log as started and sets the start time if not already set. - + Returns: The log instance for method chaining. """ @@ -79,10 +79,10 @@ def start(self) -> 'BaseLog': def set_metadata(self, metadata: Optional[Dict[str, Any]] = None) -> 'BaseLog': """Sets the metadata for the log. 
- + Args: metadata: The metadata to set for the log. - + Returns: The log instance for method chaining. """ @@ -90,10 +90,10 @@ def set_metadata(self, metadata: Optional[Dict[str, Any]] = None) -> 'BaseLog': def add_evaluator(self, evaluator: Evaluator) -> 'BaseLog': """Adds an evaluator to the log. - + Args: evaluator: The evaluator to add to the log. - + Returns: The log instance for method chaining. """ @@ -101,10 +101,10 @@ def add_evaluator(self, evaluator: Evaluator) -> 'BaseLog': def update(self, **params) -> 'BaseLog': """Updates the log with new parameters. - + Args: **params: The parameters to update. - + Returns: The log instance for method chaining. """ @@ -112,7 +112,7 @@ def update(self, **params) -> 'BaseLog': def end(self) -> 'BaseLog': """Marks the log as ended. - + Returns: The log instance for method chaining. """ diff --git a/basalt/ressources/monitor/generation_types.py b/basalt/ressources/monitor/generation_types.py index 761aabd..7dbc48d 100644 --- a/basalt/ressources/monitor/generation_types.py +++ b/basalt/ressources/monitor/generation_types.py @@ -1,4 +1,5 @@ -from typing import Dict, Optional, Union, Any +from numbers import Number +from typing import Dict, Optional, Union, Any, List from dataclasses import dataclass, field from .base_log_types import BaseLog, BaseLogParams, LogType @@ -6,13 +7,13 @@ @dataclass class PromptReference: """Reference to a prompt template. - + This class represents a reference to a prompt template used in AI model generations. - + Attributes: slug (str): Unique identifier for the prompt template. version (str): Version of the prompt template. - + Example: ```python # Basic prompt reference @@ -20,22 +21,23 @@ class PromptReference: ``` """ slug: str - version: str + version: Optional[str] = None + tag: Optional[str] = None @dataclass class GenerationParams(BaseLogParams): """Parameters for creating a new generation. 
- + This class defines the parameters that can be used to create a new generation, either with or without a prompt reference. - + Attributes: prompt (Optional[PromptReference]): Reference to the prompt template used. input (Optional[str]): The input provided to the model. output (Optional[str]): The output generated by the model. variables (Optional[Dict[str, Any]]): Variables used in the prompt template. options (Optional[Dict[str, Any]]): Additional options for the generation. - + Example: ```python # Create generation parameters with a prompt reference @@ -45,7 +47,7 @@ class GenerationParams(BaseLogParams): input="What is the capital of France?", variables={"style": "concise", "language": "en"} ) - + # Create generation parameters without a prompt reference params = GenerationParams( name="text-completion", @@ -59,14 +61,17 @@ class GenerationParams(BaseLogParams): output: Optional[str] = None variables: Optional[Dict[str, Any]] = None options: Optional[Dict[str, Any]] = None + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + cost: Optional[float] = None @dataclass class Generation(BaseLog): """Generation class representing an AI model generation within a trace. - + This class tracks interactions with AI models, including inputs, outputs, and prompt information used for the generation. - + Attributes: prompt (Optional[PromptReference]): Reference to the prompt template used. input (Optional[str]): The input provided to the model. @@ -76,7 +81,7 @@ class Generation(BaseLog): input_tokens (Optional[int]): Number of tokens used for the input. output_tokens (Optional[int]): Number of tokens used for the output. cost (Optional[float]): Cost of the generation. - + Example: ```python # Create a generation with a prompt reference @@ -85,13 +90,13 @@ class Generation(BaseLog): prompt=PromptReference(slug="qa-prompt", version="2.1.0"), input="What is the capital of France?" 
) - + # Start the generation generation.start() - + # End the generation with output generation.end("The capital of France is Paris.") - + # Update generation metadata generation.update(metadata={ "model_version": "gpt-4", @@ -110,18 +115,18 @@ class Generation(BaseLog): def start(self, input: Optional[str] = None) -> 'Generation': """Marks the generation as started and sets the input if provided. - + Args: input (Optional[str]): Optional input data to associate with the generation. - + Returns: Generation: The generation instance for method chaining. - + Example: ```python # Start a generation without input generation.start() - + # Start a generation with input generation.start("What is the capital of France?") ``` @@ -130,22 +135,22 @@ def start(self, input: Optional[str] = None) -> 'Generation': def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> 'Generation': """Marks the generation as ended and sets the output if provided. - + Args: output (Optional[Union[str, Dict[str, Any]]]): Optional output data from the model. Can be either a string or a dictionary containing output parameters. - + Returns: Generation: The generation instance for method chaining. - + Example: ```python # End a generation without output generation.end() - + # End a generation with output as string generation.end("The capital of France is Paris.") - + # End a generation with output params generation.end({ "output": "The capital of France is Paris.", diff --git a/basalt/ressources/monitor/log_type.py b/basalt/ressources/monitor/log_type.py deleted file mode 100644 index cd4c1f2..0000000 --- a/basalt/ressources/monitor/log_type.py +++ /dev/null @@ -1,17 +0,0 @@ -class LogType: - """Enum-like class for log types. 
- - Attributes: - SPAN: Represents a span log type - GENERATION: Represents a generation log type - FUNCTION: Represents a function log type - TOOL: Represents a tool log type - RETRIEVAL: Represents a retrieval log type - EVENT: Represents an event log type - """ - SPAN = 'span' - GENERATION = 'generation' - FUNCTION = 'function' - TOOL = 'tool' - RETRIEVAL = 'retrieval' - EVENT = 'event' \ No newline at end of file diff --git a/basalt/ressources/monitor/log_types.py b/basalt/ressources/monitor/log_types.py index 90caf8c..5266ccd 100644 --- a/basalt/ressources/monitor/log_types.py +++ b/basalt/ressources/monitor/log_types.py @@ -1,19 +1,38 @@ +from enum import Enum from typing import Optional, TYPE_CHECKING from dataclasses import dataclass from .base_log_types import BaseLog, BaseLogParams -from .log_type import LogType if TYPE_CHECKING: from .generation_types import Generation, GenerationParams + +class LogType(Enum): + """Enum-like class for log types. + + Attributes: + SPAN: Represents a span log type + GENERATION: Represents a generation log type + FUNCTION: Represents a function log type + TOOL: Represents a tool log type + RETRIEVAL: Represents a retrieval log type + EVENT: Represents an event log type + """ + SPAN = 'span' + GENERATION = 'generation' + FUNCTION = 'function' + TOOL = 'tool' + RETRIEVAL = 'retrieval' + EVENT = 'event' + @dataclass class LogParams(BaseLogParams): """Parameters for creating or updating a log. - + This class defines the parameters needed to create or update a log entry, including its type, input, and output data. - + Attributes: type: The type of log entry (e.g., 'span', 'generation'). Used to distinguish between different kinds of logs. @@ -27,40 +46,40 @@ class LogParams(BaseLogParams): @dataclass class Log(BaseLog): """Log interface representing a specific operation or step within a trace. - + Logs are used to track discrete operations within a process flow, such as data fetching, validation, or any other logical step. 
Logs can contain generations and can be nested within other logs to create a hierarchical structure of operations. - + Example: ```python # Create a log within a trace log = trace.create_log({ 'name': 'data-processing' }) - + # Start the log with input log.start('Raw user data') - + # Create a nested log for a sub-operation validation_log = log.create_log({ 'name': 'data-validation' }) - + # Create a generation within the validation log generation = validation_log.create_generation({ 'name': 'validation-check', 'prompt': {'slug': 'data-validator', 'version': '1.0.0'}, 'input': 'Raw user data' }) - + # End the generation with output generation.end('Data is valid') - + # End the validation log validation_log.end('Validation complete') - + # End the main log with processed output log.end('Processed user data') ``` @@ -71,18 +90,18 @@ class Log(BaseLog): def start(self, input: Optional[str] = None) -> 'Log': """Marks the log as started and sets the input if provided. - + Args: input: Optional input data to associate with the log. - + Returns: The log instance for method chaining. - + Example: ```python # Start a log without input log.start() - + # Start a log with input log.start('Raw user data to be processed') ``` @@ -91,18 +110,18 @@ def start(self, input: Optional[str] = None) -> 'Log': def end(self, output: Optional[str] = None) -> 'Log': """Marks the log as ended and sets the output if provided. - + Args: output: Optional output data to associate with the log. - + Returns: The log instance for method chaining. - + Example: ```python # End a log without output log.end() - + # End a log with output log.end('Processed data: {"success": true, "items": 42}') ``` @@ -111,13 +130,13 @@ def end(self, output: Optional[str] = None) -> 'Log': def append(self, generation: 'Generation') -> 'Log': """Adds a generation to this log. - + Args: generation: The generation to add to this log. - + Returns: The log instance for method chaining. 
- + Example: ```python # Create a generation separately @@ -125,7 +144,7 @@ def append(self, generation: 'Generation') -> 'Log': 'name': 'external-generation', 'trace': trace }) - + # Append the generation to this log log.append(generation) ``` @@ -134,13 +153,13 @@ def append(self, generation: 'Generation') -> 'Log': def create_generation(self, params: 'GenerationParams') -> 'Generation': """Creates a new generation within this log. - + Args: params: Parameters for the generation. - + Returns: A new Generation instance associated with this log. - + Example: ```python # Create a generation with a prompt reference @@ -151,7 +170,7 @@ def create_generation(self, params: 'GenerationParams') -> 'Generation': 'variables': {'language': 'en', 'mode': 'detailed'}, 'metadata': {'priority': 'high'} }) - + # Create a simple generation without a prompt reference simple_generation = log.create_generation({ 'name': 'quick-check', @@ -164,20 +183,20 @@ def create_generation(self, params: 'GenerationParams') -> 'Generation': def create_log(self, params: LogParams) -> 'Log': """Creates a new nested log within this log. - + Args: params: Parameters for the nested log. - + Returns: A new Log instance associated with this log as its parent. - + Example: ```python # Create a basic nested log nested_log = log.create_log({ 'name': 'sub-operation' }) - + # Create a detailed nested log detailed_nested_log = log.create_log({ 'name': 'data-transformation', @@ -186,4 +205,4 @@ def create_log(self, params: LogParams) -> 'Log': }) ``` """ - ... \ No newline at end of file + ... diff --git a/basalt/ressources/monitor/monitorsdk_types.py b/basalt/ressources/monitor/monitorsdk_types.py index dfc6fa6..ed9ef65 100644 --- a/basalt/ressources/monitor/monitorsdk_types.py +++ b/basalt/ressources/monitor/monitorsdk_types.py @@ -151,7 +151,7 @@ def create_log(self, params: LogParams) -> Log: """ ... 
- def create_experiment(self, feature_slug: str, params: ExperimentParams) -> Tuple[Optional[Exception], Optional[Experiment]]: + async def create_experiment(self, feature_slug: str, params: ExperimentParams) -> Tuple[Optional[Exception], Optional[Experiment]]: """Creates a new experiment to bundle multiple traces together in. You can pass this experiment to the create_trace method to add the generated traces to the experiment. @@ -164,7 +164,31 @@ def create_experiment(self, feature_slug: str, params: ExperimentParams) -> Tupl Examples: ```python - experiment = basalt.monitor.create_experiment('user-query', {'name': 'my-experiment'}) + experiment = await basalt.monitor.create_experiment('user-query', {'name': 'my-experiment'}) + + # Create a trace and add it to the experiment + trace = basalt.monitor.create_trace('user-query', {'experiment': experiment}) + ``` + + Returns: + A tuple containing (Optional[Exception], Optional[Experiment]). The Experiment object can be used to track the AI generation. + """ + ... + + def create_experiment_sync(self, feature_slug: str, params: ExperimentParams) -> Tuple[Optional[Exception], Optional[Experiment]]: + """Synchronously creates a new experiment to bundle multiple traces together in. + + You can pass this experiment to the create_trace method to add the generated traces to the experiment. + It's used mostly for local experimentations, to compare the performance between different versions of a workflow. + + Args: + feature_slug: The unique identifier of the feature to which the experiment belongs. + params: Parameters for the experiment. + - name: Name of the experiment (required). 
+ + Examples: + ```python + experiment = basalt.monitor.create_experiment_sync('user-query', {'name': 'my-experiment'}) # Create a trace and add it to the experiment trace = basalt.monitor.create_trace('user-query', {'experiment': experiment}) diff --git a/basalt/ressources/monitor/trace_types.py b/basalt/ressources/monitor/trace_types.py index 6d659e0..b57a9af 100644 --- a/basalt/ressources/monitor/trace_types.py +++ b/basalt/ressources/monitor/trace_types.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, field from .experiment_types import Experiment from .evaluator_types import Evaluator, EvaluationConfig -from .log_type import LogType if TYPE_CHECKING: from .log_types import Log, LogParams @@ -41,58 +40,58 @@ class TraceParams: @dataclass class Trace(TraceParams): """A trace represents a complete user interaction or process flow and serves as the top-level container for all monitoring activities. - + A trace provides methods to create and manage spans and generations within the process flow. - + Example: ```python # Create a basic trace trace = monitor_sdk.create_trace('user-query') - + # Start the trace with input trace.start('What is the capital of France?') - + # Create a span within the trace processing_log = trace.create_log( name='query-processing', type='process' ) - + # Create a generation within the span generation = processing_log.create_generation( name='answer-generation', prompt={'slug': 'qa-prompt', 'version': '1.0.0'}, input='What is the capital of France?' ) - + # End the generation with output generation.end('The capital of France is Paris.') - + # End the span processing_log.end() - + # End the trace with final output trace.end('Paris is the capital of France.') ``` """ - + start_time: datetime logs: List['BaseLog'] = field(default_factory=list) - + def start(self, input: Optional[str] = None) -> 'Trace': """Marks the trace as started and sets the input if provided. - + Args: input: Optional input data to associate with the trace. 
- + Returns: The trace instance for method chaining. - + Example: ```python # Start a trace without input trace.start() - + # Start a trace with input trace.start('User query: What is the capital of France?') ``` @@ -102,16 +101,16 @@ def start(self, input: Optional[str] = None) -> 'Trace': def set_ideal_output(self, ideal_output: str) -> 'Trace': """Sets the ideal output for the trace.""" ... - + def set_metadata(self, metadata: Dict[str, Any]) -> 'Trace': """Sets or updates the metadata for this trace. - + Args: metadata: The metadata to associate with this trace. - + Returns: The trace instance for method chaining. - + Example: ```python # Add metadata to the trace @@ -123,39 +122,39 @@ def set_metadata(self, metadata: Dict[str, Any]) -> 'Trace': ``` """ ... - + def set_evaluation_config(self, config: EvaluationConfig) -> 'Trace': """Sets the evaluation configuration for the trace. - + Args: config: The evaluation configuration to set. - + Returns: The trace instance for method chaining. """ ... - + def set_experiment(self, experiment: Experiment) -> 'Trace': """Sets the experiment for the trace. - + Args: experiment: The experiment to set. - + Returns: The trace instance for method chaining. """ ... - + def update(self, params: TraceParams) -> 'Trace': """Updates the trace with new parameters. The new parameters given in this method will override the existing ones. - + Args: params: The parameters to update. - + Returns: The trace instance for method chaining. - + Example: ```python # Update trace parameters @@ -166,27 +165,27 @@ def update(self, params: TraceParams) -> 'Trace': ``` """ ... - + def add_evaluator(self, evaluator: Evaluator) -> 'Trace': """Adds an evaluator to the trace. - + Args: evaluator: The evaluator to add to the trace. - + Returns: The trace instance for method chaining. """ ... - + def append(self, log: 'BaseLog') -> 'Trace': """Adds a log (span or generation) to this trace. - + Args: log: The log to add to this trace. 
- + Returns: The trace instance for method chaining. - + Example: ```python # Create a generation separately and append it to the trace @@ -194,23 +193,23 @@ def append(self, log: 'BaseLog') -> 'Trace': name='external-generation', trace=another_trace ) - + # Append the generation to this trace trace.append(generation) ``` """ ... - + def identify(self, user: Optional[User] = None, organization: Optional[Organization] = None) -> 'Trace': """Associates user information with this trace. - + Args: user: The user information to associate with this trace. organization: The organization information to associate with this trace. - + Returns: The trace instance for method chaining. - + Example: ```python # Identify a user with user and organization information @@ -227,16 +226,16 @@ def identify(self, user: Optional[User] = None, organization: Optional[Organizat ``` """ ... - + def create_generation(self, params: 'GenerationParams') -> 'Generation': """Creates a new generation within this trace. - + Args: params: Parameters for the generation. - + Returns: A new Generation instance associated with this trace. - + Example: ```python # Create a generation with a prompt reference @@ -247,7 +246,7 @@ def create_generation(self, params: 'GenerationParams') -> 'Generation': 'variables': {'style': 'concise', 'language': 'en'}, 'metadata': {'model_version': 'gpt-4'} }) - + # Create a generation without a prompt reference simple_generation = trace.create_generation({ 'name': 'text-completion', @@ -257,16 +256,16 @@ def create_generation(self, params: 'GenerationParams') -> 'Generation': ``` """ ... - + def create_log(self, params: 'LogParams') -> 'Log': """Creates a new span within this trace. - + Args: params: Parameters for the span. - + Returns: A new Log instance associated with this trace. 
- + Example: ```python # Create a basic span @@ -274,7 +273,7 @@ def create_log(self, params: 'LogParams') -> 'Log': 'name': 'data-fetching', 'type': 'io' }) - + # Create a detailed span detailed_log = trace.create_log({ 'name': 'user-validation', @@ -284,21 +283,41 @@ def create_log(self, params: 'LogParams') -> 'Log': ``` """ ... - - def end(self, output: Optional[str] = None) -> 'Trace': + + async def end(self, output: Optional[str] = None) -> 'Trace': """Marks the trace as ended and sets the output if provided. - + Args: output: Optional output data to associate with the trace. - + Returns: The trace instance for method chaining. - + Example: ```python # End a trace without output trace.end() - + + # End a trace with output + trace.end('The capital of France is Paris.') + ``` + """ + ... + + def end_sync(self, output: Optional[str] = None) -> 'Trace': + """Marks the trace as ended and sets the output if provided. + + Args: + output: Optional output data to associate with the trace. + + Returns: + The trace instance for method chaining. 
+ + Example: + ```python + # End a trace without output + trace.end() + # End a trace with output trace.end('The capital of France is Paris.') ``` diff --git a/basalt/sdk/datasetsdk.py b/basalt/sdk/datasetsdk.py index 05b5acc..0f33eeb 100644 --- a/basalt/sdk/datasetsdk.py +++ b/basalt/sdk/datasetsdk.py @@ -2,7 +2,6 @@ SDK for interacting with Basalt datasets """ from typing import Dict, List, Optional, Tuple, Any -import asyncio from ..utils.dtos import ( ListDatasetsDTO, GetDatasetDTO, CreateDatasetItemDTO, @@ -13,7 +12,6 @@ from ..endpoints.list_datasets import ListDatasetsEndpoint from ..endpoints.get_dataset import GetDatasetEndpoint from ..endpoints.create_dataset_item import CreateDatasetItemEndpoint -from ..objects.dataset import Dataset, DatasetRow class DatasetSDK: @@ -28,16 +26,16 @@ def __init__( self._api = api self._logger = logger - def list(self) -> ListDatasetsResult: + async def list(self) -> ListDatasetsResult: """ List all datasets available in the workspace. Returns: - Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional + Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional exception and an optional list of DatasetDTO objects. """ dto = ListDatasetsDTO() - err, result = self._api.invoke(ListDatasetsEndpoint, dto) + err, result = await self._api.invoke(ListDatasetsEndpoint, dto) if err is not None: return err, None @@ -47,17 +45,17 @@ def list(self) -> ListDatasetsResult: name=dataset.name, columns=dataset.columns ) for dataset in result.datasets] - - async def async_list(self) -> ListDatasetsResult: + + def list_sync(self) -> ListDatasetsResult: """ - Asynchronously list all datasets available in the workspace. + Synchronously list all datasets available in the workspace. 
Returns: - Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional + Tuple[Optional[Exception], Optional[List[DatasetDTO]]]: A tuple containing an optional exception and an optional list of DatasetDTO objects. """ dto = ListDatasetsDTO() - err, result = await self._api.async_invoke(ListDatasetsEndpoint, dto) + err, result = self._api.invoke_sync(ListDatasetsEndpoint, dto) if err is not None: return err, None @@ -68,7 +66,7 @@ async def async_list(self) -> ListDatasetsResult: columns=dataset.columns ) for dataset in result.datasets] - def get(self, slug: str) -> GetDatasetResult: + async def get(self, slug: str) -> GetDatasetResult: """ Get a dataset by its slug. @@ -80,19 +78,19 @@ def get(self, slug: str) -> GetDatasetResult: exception and an optional DatasetDTO. """ dto = GetDatasetDTO(slug=slug) - err, result = self._api.invoke(GetDatasetEndpoint, dto) + err, result = await self._api.invoke(GetDatasetEndpoint, dto) if err is not None: return err, None - + if result.error: return Exception(result.error), None return None, result.dataset - - async def async_get(self, slug: str) -> GetDatasetResult: + + def get_sync(self, slug: str) -> GetDatasetResult: """ - Asynchronously get a dataset by its slug. + Synchronously get a dataset by its slug. Args: slug (str): The slug identifier for the dataset. @@ -102,17 +100,17 @@ async def async_get(self, slug: str) -> GetDatasetResult: exception and an optional DatasetDTO. 
""" dto = GetDatasetDTO(slug=slug) - err, result = await self._api.async_invoke(GetDatasetEndpoint, dto) + err, result = self._api.invoke_sync(GetDatasetEndpoint, dto) if err is not None: return err, None - + if result.error: return Exception(result.error), None return None, result.dataset - def addRow( + async def add_row( self, slug: str, values: Dict[str, str], @@ -142,17 +140,17 @@ def addRow( metadata=metadata ) - err, result = self._api.invoke(CreateDatasetItemEndpoint, dto) + err, result = await self._api.invoke(CreateDatasetItemEndpoint, dto) if err is not None: return err, None, None - + if result.error: return Exception(result.error), None, None - + return None, result.datasetRow, result.warning - - async def async_addRow( + + def add_row_sync( self, slug: str, values: Dict[str, str], @@ -161,7 +159,7 @@ async def async_addRow( metadata: Optional[Dict[str, Any]] = None ) -> CreateDatasetItemResult: """ - Asynchronously create a new item in a dataset. + Synchronously create a new item in a dataset. Args: slug (str): The slug identifier for the dataset. 
@@ -182,12 +180,12 @@ async def async_addRow( metadata=metadata ) - err, result = await self._api.async_invoke(CreateDatasetItemEndpoint, dto) + err, result = self._api.invoke_sync(CreateDatasetItemEndpoint, dto) if err is not None: return err, None, None - + if result.error: return Exception(result.error), None, None - + return None, result.datasetRow, result.warning diff --git a/basalt/sdk/monitorsdk.py b/basalt/sdk/monitorsdk.py index fd0f3e9..6ea80a6 100644 --- a/basalt/sdk/monitorsdk.py +++ b/basalt/sdk/monitorsdk.py @@ -1,5 +1,4 @@ from typing import Dict, Optional, Any, Tuple -import asyncio from ..utils.protocols import IApi, ILogger from ..ressources.monitor.trace_types import TraceParams @@ -25,13 +24,13 @@ def __init__( self._api = api self._logger = logger - def create_experiment( + async def create_experiment( self, feature_slug: str, params: ExperimentParams ) -> Tuple[Optional[Exception], Optional[Experiment]]: """ - Creates a new experiment for monitoring. + Asynchronously creates a new experiment for monitoring. Args: feature_slug (str): The feature slug for the experiment. @@ -40,15 +39,15 @@ def create_experiment( Returns: Experiment: A new Experiment instance. """ - return self._create_experiment(feature_slug, params) - - async def async_create_experiment( + return await self._create_experiment(feature_slug, params) + + def create_experiment_sync( self, feature_slug: str, params: ExperimentParams ) -> Tuple[Optional[Exception], Optional[Experiment]]: """ - Asynchronously creates a new experiment for monitoring. + Synchronously creates a new experiment for monitoring. Args: feature_slug (str): The feature slug for the experiment. @@ -57,8 +56,7 @@ async def async_create_experiment( Returns: Experiment: A new Experiment instance. 
""" - return await self._async_create_experiment(feature_slug, params) - + return self._create_experiment_sync(feature_slug, params) def create_trace( self, @@ -81,28 +79,6 @@ def create_trace( trace_params = TraceParams(**params) return self._create_trace(slug, trace_params) - - async def async_create_trace( - self, - slug: str, - params: Optional[TraceParams] = None - ) -> Trace: - """ - Asynchronously creates a new trace for monitoring. - - Args: - slug (str): The unique identifier for the trace. - params (TraceParams): Parameters for the trace. - - Returns: - Trace: A new Trace instance. - """ - if params is None: - params = {} - - trace_params = TraceParams(**params) - - return await self._async_create_trace(slug, trace_params) def create_generation( self, @@ -119,22 +95,6 @@ def create_generation( """ generation_params = GenerationParams(**params) return self._create_generation(generation_params) - - async def async_create_generation( - self, - params: Dict[str, Any] - ) -> Generation: - """ - Asynchronously creates a new generation for monitoring. - - Args: - params (Dict[str, Any]): Parameters for the generation. - - Returns: - Generation: A new Generation instance. - """ - generation_params = GenerationParams(**params) - return await self._async_create_generation(generation_params) def create_log( self, @@ -151,30 +111,14 @@ def create_log( """ log_params = LogParams(**params) return self._create_log(log_params) - - async def async_create_log( - self, - params: Dict[str, Any] - ) -> Log: - """ - Asynchronously creates a new log for monitoring. - - Args: - params (Dict[str, Any]): Parameters for the log. - - Returns: - Log: A new Log instance. - """ - log_params = LogParams(**params) - return await self._async_create_log(log_params) - def _create_experiment( + async def _create_experiment( self, feature_slug: str, params: ExperimentParams ) -> Tuple[Optional[Exception], Optional[Experiment]]: """ - Internal implementation for creating an experiment. 
+ Internal async implementation for creating an experiment. Args: feature_slug (str): The feature slug for the experiment. @@ -185,24 +129,24 @@ def _create_experiment( """ dto = CreateExperimentDTO( feature_slug=feature_slug, - name=params.get("name"), + name=params.name, ) # Call the API endpoint - err, result = self._api.invoke(CreateExperimentEndpoint, dto) + err, result = await self._api.invoke(CreateExperimentEndpoint, dto) if err is None: return None, Experiment(result.experiment) return err, None - - async def _async_create_experiment( + + def _create_experiment_sync( self, feature_slug: str, params: ExperimentParams ) -> Tuple[Optional[Exception], Optional[Experiment]]: """ - Internal implementation for asynchronously creating an experiment. + Internal sync implementation for creating an experiment. Args: feature_slug (str): The feature slug for the experiment. @@ -213,11 +157,11 @@ async def _async_create_experiment( """ dto = CreateExperimentDTO( feature_slug=feature_slug, - name=params.get("name"), + name=params.name, ) # Call the API endpoint - err, result = await self._api.async_invoke(CreateExperimentEndpoint, dto) + err, result = self._api.invoke_sync(CreateExperimentEndpoint, dto) if err is None: return None, Experiment(result.experiment) @@ -241,61 +185,12 @@ def _create_trace( Trace: A new Trace instance. 
""" flusher = Flusher(self._api, self._logger) - # Convert TraceParams to a dictionary before passing to Trace - params_dict = { - "input": params.input, - "output": params.output, - "name": params.name, - "ideal_output": params.ideal_output, - "start_time": params.start_time, - "end_time": params.end_time, - "user": params.user, - "organization": params.organization, - "metadata": params.metadata, - "experiment": params.experiment, - "evaluators": params.evaluators, - "evaluationConfig": params.evaluation_config - } - trace = Trace(slug, params_dict, flusher, self._logger) - return trace - - async def _async_create_trace( - self, - slug: str, - params: TraceParams - ) -> Trace: - """ - Internal implementation for asynchronously creating a trace. - Args: - slug (str): The unique identifier for the trace. - params (TraceParams): Parameters for the trace. - - Returns: - Trace: A new Trace instance. - """ - flusher = Flusher(self._api, self._logger) - # Convert TraceParams to a dictionary before passing to Trace - params_dict = { - "input": params.input, - "output": params.output, - "name": params.name, - "start_time": params.start_time, - "end_time": params.end_time, - "user": params.user, - "organization": params.organization, - "metadata": params.metadata, - "experiment": params.experiment, - "evaluators": params.evaluators, - "evaluationConfig": params.evaluation_config - } - trace = Trace(slug, params_dict, flusher, self._logger) + trace = Trace(slug, params, flusher, self._logger) return trace - def _create_generation( - self, - params: GenerationParams - ) -> Generation: + @staticmethod + def _create_generation(params: GenerationParams) -> Generation: """ Internal implementation for creating a generation. @@ -305,55 +200,10 @@ def _create_generation( Returns: Generation: A new Generation instance. 
""" - # Convert GenerationParams to a dictionary before passing to Generation - params_dict = { - "name": params.name, - "trace": params.trace, - "prompt": params.prompt, - "input": params.input, - "output": params.output, - "variables": params.variables, - "parent": params.parent, - "metadata": params.metadata, - "start_time": params.start_time, - "end_time": params.end_time, - "options": params.options - } - return Generation(params_dict) - - async def _async_create_generation( - self, - params: GenerationParams - ) -> Generation: - """ - Internal implementation for asynchronously creating a generation. - - Args: - params (GenerationParams): Parameters for the generation. + return Generation(params) - Returns: - Generation: A new Generation instance. - """ - # Convert GenerationParams to a dictionary before passing to Generation - params_dict = { - "name": params.name, - "trace": params.trace, - "prompt": params.prompt, - "input": params.input, - "output": params.output, - "variables": params.variables, - "parent": params.parent, - "metadata": params.metadata, - "start_time": params.start_time, - "end_time": params.end_time, - "options": params.options - } - return Generation(params_dict) - - def _create_log( - self, - params: LogParams - ) -> Log: + @staticmethod + def _create_log(params: LogParams) -> Log: """ Internal implementation for creating a log. @@ -363,41 +213,4 @@ def _create_log( Returns: Log: A new Log instance. """ - # Convert LogParams to a dictionary before passing to Log - params_dict = { - "name": params.name, - "trace": params.trace, - "input": params.input, - "output": params.output, - "parent": params.parent, - "metadata": params.metadata, - "start_time": params.start_time, - "end_time": params.end_time - } - return Log(params_dict) - - async def _async_create_log( - self, - params: LogParams - ) -> Log: - """ - Internal implementation for asynchronously creating a log. - - Args: - params (LogParams): Parameters for the log. 
- - Returns: - Log: A new Log instance. - """ - # Convert LogParams to a dictionary before passing to Log - params_dict = { - "name": params.name, - "trace": params.trace, - "input": params.input, - "output": params.output, - "parent": params.parent, - "metadata": params.metadata, - "start_time": params.start_time, - "end_time": params.end_time - } - return Log(params_dict) \ No newline at end of file + return Log(params) diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index f909c4e..abbc27f 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -1,6 +1,8 @@ -from typing import Optional, Dict, Tuple, Any, Awaitable +from typing import Optional, Dict, Tuple, Any -from ..utils.dtos import GetPromptDTO, PromptResponse, DescribePromptResponse, DescribePromptDTO, GetResult, DescribeResult, ListResult, PromptListResponse, PromptListDTO +from ..ressources.monitor.generation_types import GenerationParams, PromptReference +from ..ressources.monitor.trace_types import TraceParams +from ..utils.dtos import GetPromptDTO, PromptResponse, DescribePromptResponse, DescribePromptDTO, DescribeResult, ListResult, PromptListResponse, PromptListDTO from ..utils.protocols import ICache, IApi, ILogger from ..endpoints.get_prompt import GetPromptEndpoint @@ -11,7 +13,6 @@ from ..objects.generation import Generation from ..utils.flusher import Flusher from datetime import datetime -import asyncio class PromptSDK: """ @@ -32,7 +33,7 @@ def __init__( self._cache_duration = 5 * 60 self._logger = logger - def get( + async def get( self, slug: str, version: Optional[str] = None, @@ -47,11 +48,11 @@ def get( slug (str): The slug identifier for the prompt. version (Optional[str]): The version of the prompt. tag (Optional[str]): The tag associated with the prompt. - variables (dict): A dictionnary of variables to replace in the prompt text. + variables (dict): A dictionary of variables to replace in the prompt text. 
cache_enabled (bool): Enable or disable cache for this request. Returns: - Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: A tuple containing an optional exception, an optional PromptResponse, and an optional Generation object. """ dto = GetPromptDTO( @@ -68,11 +69,11 @@ def get( generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) return err, prompt_response, generation - err, result = self._api.invoke(GetPromptEndpoint, dto) + err, result = await self._api.invoke(GetPromptEndpoint, dto) if err is None: original_prompt_text = result.prompt.text - self._cache.put(dto, result.prompt, ttl=self._cache_duration) + self._cache.put(dto, result.prompt, self._cache_duration) self._fallback_cache.put(dto, result.prompt) err, prompt_response = self._replace_vars(result.prompt, variables) @@ -88,8 +89,8 @@ def get( return err, prompt_response, generation return err, None, None - - async def async_get( + + def get_sync( self, slug: str, version: Optional[str] = None, @@ -98,17 +99,17 @@ async def async_get( cache_enabled: bool = True ) -> Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: """ - Asynchronously retrieve a prompt by slug, optionally specifying version and tag. + Synchronously retrieve a prompt by slug, optionally specifying version and tag. Args: slug (str): The slug identifier for the prompt. version (Optional[str]): The version of the prompt. tag (Optional[str]): The tag associated with the prompt. - variables (dict): A dictionnary of variables to replace in the prompt text. + variables (dict): A dictionary of variables to replace in the prompt text. cache_enabled (bool): Enable or disable cache for this request. 
Returns: - Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: A tuple containing an optional exception, an optional PromptResponse, and an optional Generation object. """ dto = GetPromptDTO( @@ -122,10 +123,10 @@ async def async_get( if cached: original_prompt_text = cached.text err, prompt_response = self._replace_vars(cached, variables) - generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) return err, prompt_response, generation - err, result = await self._api.async_invoke(GetPromptEndpoint, dto) + err, result = self._api.invoke_sync(GetPromptEndpoint, dto) if err is None: original_prompt_text = result.prompt.text @@ -133,7 +134,7 @@ async def async_get( self._fallback_cache.put(dto, result.prompt) err, prompt_response = self._replace_vars(result.prompt, variables) - generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) return err, prompt_response, generation fallback = self._fallback_cache.get(dto) if cache_enabled else None @@ -141,71 +142,23 @@ async def async_get( if fallback: original_prompt_text = fallback.text err, prompt_response = self._replace_vars(fallback, variables) - generation = await self._async_prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) return err, prompt_response, generation return err, None, None def _prepare_monitoring( - self, - prompt: PromptResponse, - slug: str, - version: Optional[str] = None, + self, + prompt: PromptResponse, + slug: 
str, + version: Optional[str] = None, tag: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, original_prompt_text: Optional[str] = None ) -> Generation: """ Prepare monitoring by creating a trace and generation object. - - Args: - prompt (PromptResponse): The prompt response. - slug (str): The slug identifier for the prompt. - version (Optional[str]): The version of the prompt. - tag (Optional[str]): The tag associated with the prompt. - variables (Optional[Dict[str, Any]]): Variables used in the prompt. - original_prompt_text (Optional[str]): The original prompt text. - - Returns: - Generation: The generation object. - """ - # Create a flusher - flusher = Flusher(self._api, self._logger) - - # Create a trace - trace = Trace(slug, { - "input": original_prompt_text or prompt.text, - "start_time": datetime.now() - }, flusher, self._logger) - - # Create a generation - generation = Generation({ - "name": slug, - "trace": trace, - "prompt": { - "slug": slug, - "version": version, - "tag": tag - }, - "input": original_prompt_text or prompt.text, - "variables": variables, - "options": {"type": "single"} - }) - - return generation - - async def _async_prepare_monitoring( - self, - prompt: PromptResponse, - slug: str, - version: Optional[str] = None, - tag: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - original_prompt_text: Optional[str] = None - ) -> Generation: - """ - Asynchronously prepare monitoring by creating a trace and generation object. - + Args: prompt (PromptResponse): The prompt response. slug (str): The slug identifier for the prompt. @@ -213,36 +166,36 @@ async def _async_prepare_monitoring( tag (Optional[str]): The tag associated with the prompt. variables (Optional[Dict[str, Any]]): Variables used in the prompt. original_prompt_text (Optional[str]): The original prompt text. - + Returns: Generation: The generation object. 
""" # Create a flusher flusher = Flusher(self._api, self._logger) - + # Create a trace - trace = Trace(slug, { - "input": original_prompt_text or prompt.text, - "start_time": datetime.now() - }, flusher, self._logger) - + trace = Trace(slug, TraceParams( + input=original_prompt_text or prompt.text, + start_time=datetime.now() + ), flusher, self._logger) + # Create a generation - generation = Generation({ - "name": slug, - "trace": trace, - "prompt": { - "slug": slug, - "version": version, - "tag": tag - }, - "input": original_prompt_text or prompt.text, - "variables": variables, - "options": {"type": "single"} - }) - + generation = Generation(GenerationParams( + name=slug, + trace=trace, + prompt=PromptReference( + slug=slug, + version=version, + tag=tag + ), + input=original_prompt_text or prompt.text, + variables=variables, + options={"type": "single"} + )) + return generation - def describe( + async def describe( self, slug: str, version: Optional[str] = None, @@ -255,7 +208,6 @@ def describe( slug (str): The slug identifier for the prompt. version (Optional[str]): The version of the prompt. tag (Optional[str]): The tag associated with the prompt. - cache_enabled (bool): Enable or disable cache for this request. Returns: Tuple[Optional[Exception], Optional[DescribePromptResponse]]: A tuple containing an optional exception and an optional DescribePromptResponse. @@ -266,7 +218,7 @@ def describe( tag=tag ) - err, result = self._api.invoke(DescribePromptEndpoint, dto) + err, result = await self._api.invoke(DescribePromptEndpoint, dto) if err is None: prompt = result.prompt @@ -282,15 +234,15 @@ def describe( ) return err, None - - async def async_describe( + + def describe_sync( self, slug: str, version: Optional[str] = None, tag: Optional[str] = None, ) -> DescribeResult: """ - Asynchronously get details about a prompt by slug, optionally specifying version and tag. + Synchronously get details about a prompt by slug, optionally specifying version and tag. 
Args: slug (str): The slug identifier for the prompt. @@ -306,7 +258,7 @@ async def async_describe( tag=tag ) - err, result = await self._api.async_invoke(DescribePromptEndpoint, dto) + err, result = self._api.invoke_sync(DescribePromptEndpoint, dto) if err is None: prompt = result.prompt @@ -323,10 +275,19 @@ async def async_describe( return err, None - def list(self, feature_slug: Optional[str] = None) -> ListResult: + async def list(self, feature_slug: Optional[str] = None) -> ListResult: + """ + List prompts, optionally filtering by feature_slug. + + Args: + feature_slug (Optional[str]): Optional feature slug to filter prompts by. + + Returns: + Tuple[Optional[Exception], Optional[List[PromptListResponse]]]: A tuple containing an optional exception and an optional list of PromptListResponse objects. + """ dto = PromptListDTO(featureSlug=feature_slug) - err, result = self._api.invoke(ListPromptsEndpoint, dto) + err, result = await self._api.invoke(ListPromptsEndpoint, dto) if err is not None: return err, None @@ -339,20 +300,20 @@ def list(self, feature_slug: Optional[str] = None) -> ListResult: available_versions=prompt.available_versions, available_tags=prompt.available_tags ) for prompt in result.prompts] - - async def async_list(self, feature_slug: Optional[str] = None) -> ListResult: + + def list_sync(self, feature_slug: Optional[str] = None) -> ListResult: """ - Asynchronously list prompts, optionally filtering by feature_slug. - + Synchronously list prompts, optionally filtering by feature_slug. + Args: feature_slug (Optional[str]): Optional feature slug to filter prompts by. - + Returns: Tuple[Optional[Exception], Optional[List[PromptListResponse]]]: A tuple containing an optional exception and an optional list of PromptListResponse objects. 
""" dto = PromptListDTO(featureSlug=feature_slug) - err, result = await self._api.async_invoke(ListPromptsEndpoint, dto) + err, result = self._api.invoke_sync(ListPromptsEndpoint, dto) if err is not None: return err, None @@ -383,4 +344,4 @@ def _replace_vars(self, prompt: PromptResponse, variables: Dict[str, str] = {}): systemText=replaced_system, version=prompt.version, model=prompt.model - ) \ No newline at end of file + ) diff --git a/basalt/utils/api.py b/basalt/utils/api.py index 9fd4bf2..c900967 100644 --- a/basalt/utils/api.py +++ b/basalt/utils/api.py @@ -1,7 +1,7 @@ +"""Module for interacting with the Basalt API.""" from typing import Dict, TypeVar, Optional, Tuple from .protocols import IEndpoint, INetworker, ILogger -import asyncio from .networker import Networker Input = TypeVar('Input') @@ -10,7 +10,7 @@ class Api: """ A class to interact with the Basalt API. - + Attributes: root_url (str): The root URL of the API. api_key (str): The API key for authentication. @@ -39,13 +39,13 @@ def __init__(self, root_url: str, networker: INetworker, api_key: str, sdk_versi if isinstance(networker, Networker): networker._logger = logger - def invoke( + async def invoke( self, endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None ) -> Tuple[Optional[Exception], Optional[Output]]: """ - Invoke an API endpoint with the given data transfer object (DTO). + Asynchronously invoke an API endpoint with the given data transfer object (DTO). Args: endpoint: The endpoint to be invoked. 
@@ -61,7 +61,7 @@ def invoke( request_info = endpoint.prepare_request(dto) # Fetch the result from the network using the prepared request information - error, result = self._network.fetch( + error, result = await self._network.fetch( self._root + request_info['path'], request_info['method'], request_info.get('body'), @@ -84,14 +84,14 @@ def _headers(self) -> Dict[str, str]: 'X-BASALT-SDK-TYPE': self._sdk_type, 'Content-Type': 'application/json' } - - async def async_invoke( + + def invoke_sync( self, endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None ) -> Tuple[Optional[Exception], Optional[Output]]: """ - Asynchronously invoke an API endpoint with the given data transfer object (DTO). + Synchronously invoke an API endpoint with the given data transfer object (DTO). Args: endpoint: The endpoint to be invoked. @@ -107,7 +107,7 @@ async def async_invoke( request_info = endpoint.prepare_request(dto) # Fetch the result from the network using the prepared request information - error, result = await self._network.async_fetch( + error, result = self._network.fetch_sync( self._root + request_info['path'], request_info['method'], request_info.get('body'), diff --git a/basalt/utils/flusher.py b/basalt/utils/flusher.py index ddcd730..eace5e1 100644 --- a/basalt/utils/flusher.py +++ b/basalt/utils/flusher.py @@ -1,12 +1,14 @@ -from typing import TYPE_CHECKING, Dict, Any, List +"""Module for flushing traces to the API.""" import json -from datetime import datetime +from typing import TYPE_CHECKING, Dict, Any + +from ..endpoints.monitor.send_trace import SendTraceEndpoint + if TYPE_CHECKING: from ..objects.trace import Trace from .protocols import IApi, ILogger -from ..endpoints.monitor.send_trace import SendTraceEndpoint class Flusher: """ @@ -19,10 +21,10 @@ def __init__(self, api: 'IApi', logger: 'ILogger'): def _trace_to_dict(self, trace: 'Trace') -> Dict[str, Any]: """ Convert a trace to a dictionary. - + Args: trace (Trace): The trace to convert. 
- + Returns: Dict[str, Any]: The trace as a dictionary. """ @@ -34,7 +36,7 @@ def _trace_to_dict(self, trace: 'Trace') -> Dict[str, Any]: "input": trace.input, "output": output, "ideal_output": trace.ideal_output, - "name": trace._name, + "name": trace.name, "start_time": trace.start_time.isoformat() if trace.start_time else None, "end_time": trace.end_time.isoformat() if trace.end_time else None, "user": trace.user, @@ -46,13 +48,14 @@ def _trace_to_dict(self, trace: 'Trace') -> Dict[str, Any]: "evaluationConfig": trace.evaluation_config } - def _log_to_dict(self, log: Any) -> Dict[str, Any]: + @staticmethod + def _log_to_dict(log: Any) -> Dict[str, Any]: """ Convert a log to a dictionary. - + Args: log (Any): The log to convert. - + Returns: Dict[str, Any]: The log as a dictionary. """ @@ -65,7 +68,7 @@ def _log_to_dict(self, log: Any) -> Dict[str, Any]: "id": log.id, "type": log.type, "ideal_output": log.ideal_output, - "name": log._name, + "name": log.name, "input": log.input, "output": output, "start_time": log.start_time.isoformat() if hasattr(log, 'start_time') and log.start_time else None, @@ -84,17 +87,50 @@ def _log_to_dict(self, log: Any) -> Dict[str, Any]: return base_dict - def flush_trace(self, trace: 'Trace') -> None: + async def flush_trace(self, trace: 'Trace') -> None: + """ + Flush a trace to the API asynchronously. + + Args: + trace (Trace): The trace to flush. 
+ """ + try: + if not self._api: + self._logger.error("Cannot flush trace: no API instance available") + return None + + # Create an endpoint instance + endpoint = SendTraceEndpoint() + + # Convert trace to dictionary + trace_dict = self._trace_to_dict(trace) + + # Create the DTO with the trace dictionary + dto = {"trace": trace_dict} + + # Invoke the API with the endpoint and DTO + error, result = await self._api.invoke(endpoint, dto) + + if error: + self._logger.error(f"Failed to flush trace {trace.feature_slug}: {error}") + return None + + return result + + except Exception as e: + self._logger.error(f"Exception while flushing trace: {str(e)}") + + def flush_trace_sync(self, trace: 'Trace') -> None: """ - Flush a trace to the API. - + Flush a trace to the API synchronously. + Args: trace (Trace): The trace to flush. """ try: if not self._api: self._logger.error("Cannot flush trace: no API instance available") - return + return None # Create an endpoint instance endpoint = SendTraceEndpoint() @@ -106,11 +142,13 @@ def flush_trace(self, trace: 'Trace') -> None: dto = {"trace": trace_dict} # Invoke the API with the endpoint and DTO - error, result = self._api.invoke(endpoint, dto) + error, result = self._api.invoke_sync(endpoint, dto) if error: self._logger.error(f"Failed to flush trace {trace.feature_slug}: {error}") - return + return None + + return result except Exception as e: - self._logger.error(f"Exception while flushing trace: {str(e)}") \ No newline at end of file + self._logger.error(f"Exception while flushing trace: {str(e)}") diff --git a/basalt/utils/networker.py b/basalt/utils/networker.py index af87960..932f764 100644 --- a/basalt/utils/networker.py +++ b/basalt/utils/networker.py @@ -1,9 +1,9 @@ import requests import aiohttp -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Mapping from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized -from .protocols import 
INetworker, ILogger +from .protocols import INetworker class Networker(INetworker): """ @@ -13,16 +13,16 @@ class Networker(INetworker): def __init__(self): pass - def fetch( + async def fetch( self, url: str, method: str, - body = None, - headers = None, - params = None + body: Optional[Any] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[Mapping[str, str]] = None ) -> Tuple[Optional[FetchError], Optional[Dict[str, Any]]]: """ - Fetch data from a given URL using the specified HTTP method. This method should never throw. + Fetch data from a given URL using the specified HTTP method. Args: url (str): The URL to fetch data from. @@ -37,84 +37,84 @@ def fetch( - (FetchError, None) """ try: - response = requests.request( - method, - url, - params=params, - json=body, - headers=headers - ) - - json_response = response.json() + async with aiohttp.ClientSession() as session: + async with session.request( + method, + url, + params=params, + json=body, + headers=headers + ) as response: + json_response = await response.json() - if response.status_code == 400: - return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None + if response.status == 400: + return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None - if response.status_code == 401: - return Unauthorized(json_response.get('error', 'Unauthorized')), None + if response.status == 401: + return Unauthorized(json_response.get('error', 'Unauthorized')), None - if response.status_code == 403: - return Forbidden(json_response.get('error', 'Forbidden')), None + if response.status == 403: + return Forbidden(json_response.get('error', 'Forbidden')), None - if response.status_code == 404: - return NotFound(json_response.get('error', 'Not Found')), None + if response.status == 404: + return NotFound(json_response.get('error', 'Not Found')), None - response.raise_for_status() + response.raise_for_status() - return None, 
json_response + return None, json_response except Exception as e: return NetworkBaseError(str(e)), None - - async def async_fetch( + + def fetch_sync( self, url: str, method: str, - body = None, - headers = None, - params = None + body: Optional[Any] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[Mapping[str, str]] = None ) -> Tuple[Optional[FetchError], Optional[Dict[str, Any]]]: """ - Asynchronously fetch data from a given URL using the specified HTTP method. - + Synchronously fetch data from a given URL using the specified HTTP method. This method should never throw. + Args: url (str): The URL to fetch data from. method (str): The HTTP method to use (e.g., 'GET', 'POST'). body (Optional[Any]): The request payload to send (default is None). headers (Optional[Dict[str, str]]): The request headers to send (default is None). params (Optional[Dict[str, str]]): The query parameters to send (default is None). - + Returns: A result tuple (err, json_response), possible responses: - (None, json_response) - (FetchError, None) """ try: - async with aiohttp.ClientSession() as session: - async with session.request( - method, - url, - params=params, - json=body, - headers=headers - ) as response: - json_response = await response.json() - - if response.status == 400: - return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None - - if response.status == 401: - return Unauthorized(json_response.get('error', 'Unauthorized')), None - - if response.status == 403: - return Forbidden(json_response.get('error', 'Forbidden')), None - - if response.status == 404: - return NotFound(json_response.get('error', 'Not Found')), None - - response.raise_for_status() - - return None, json_response - + response = requests.request( + method, + url, + params=params, + json=body, + headers=headers + ) + + json_response = response.json() + + if response.status_code == 400: + return BadRequest(json_response.get('error', 
json_response.get('errors', 'Bad Request'))), None + + if response.status_code == 401: + return Unauthorized(json_response.get('error', 'Unauthorized')), None + + if response.status_code == 403: + return Forbidden(json_response.get('error', 'Forbidden')), None + + if response.status_code == 404: + return NotFound(json_response.get('error', 'Not Found')), None + + response.raise_for_status() + + return None, json_response + except Exception as e: return NetworkBaseError(str(e)), None diff --git a/basalt/utils/protocols.py b/basalt/utils/protocols.py index 58b16fb..8b094fa 100644 --- a/basalt/utils/protocols.py +++ b/basalt/utils/protocols.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal, List +from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal from .dtos import GetResult, DescribeResult, ListResult, ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult from ..ressources.monitor.monitorsdk_types import IMonitorSDK @@ -15,10 +15,18 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: ... def decode_response(self, response: Any) -> Tuple[Optional[Exception], Optional[Output]]: ... class IApi(Protocol): - def invoke(self,endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None) -> Tuple[Optional[Exception], Optional[Output]]: ... + async def invoke(self,endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None) -> Tuple[Optional[Exception], Optional[Output]]: ... + def invoke_sync(self,endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None) -> Tuple[Optional[Exception], Optional[Output]]: ... class INetworker(Protocol): - def fetch(self, + async def fetch(self, + url: str, + method: str, + body: Optional[Any] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[Mapping[str, str]] = None + ) -> Tuple[Optional[Exception], Optional[Output]]: ... 
+ def fetch_sync(self, url: str, method: str, body: Optional[Any] = None, @@ -27,14 +35,21 @@ def fetch(self, ) -> Tuple[Optional[Exception], Optional[Output]]: ... class IPromptSDK(Protocol): - def get(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Dict[str, str] = {}, cache_enabled: bool = True) -> GetResult: ... - def describe(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... - def list(self, feature_slug: Optional[str] = None) -> ListResult: ... + async def get(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Dict[str, str] = {}, cache_enabled: bool = True) -> GetResult: ... + def get_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Dict[str, str] = {}, cache_enabled: bool = True) -> GetResult: ... + async def describe(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... + def describe_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... + async def list(self, feature_slug: Optional[str] = None) -> ListResult: ... + def list_sync(self, feature_slug: Optional[str] = None) -> ListResult: ... class IDatasetSDK(Protocol): - def list(self) -> ListDatasetsResult: ... - def get(self, slug: str) -> GetDatasetResult: ... - def addRow(self, slug: str, values: Dict[str, str], name: Optional[str] = None, + async def list(self) -> ListDatasetsResult: ... + def list_sync(self) -> ListDatasetsResult: ... + async def get(self, slug: str) -> GetDatasetResult: ... + def get_sync(self, slug: str) -> GetDatasetResult: ... + async def addRow(self, slug: str, values: Dict[str, str], name: Optional[str] = None, + ideal_output: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> CreateDatasetItemResult: ... 
+ def addRow_sync(self, slug: str, values: Dict[str, str], name: Optional[str] = None, ideal_output: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> CreateDatasetItemResult: ... class IBasaltSDK(Protocol): From 771f069abc986e42fd430ef6c43b53e55bad3602 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Sun, 5 Oct 2025 15:55:27 +0200 Subject: [PATCH 2/7] Jinja and typing review --- basalt/objects/base_log.py | 4 +- basalt/objects/generation.py | 21 +- basalt/objects/prompt.py | 94 +++++ basalt/objects/trace.py | 29 +- basalt/ressources/monitor/base_log_types.py | 28 +- basalt/ressources/monitor/evaluator_types.py | 8 +- basalt/ressources/monitor/experiment_types.py | 3 +- basalt/ressources/monitor/generation_types.py | 3 +- basalt/ressources/monitor/log_types.py | 22 +- basalt/ressources/monitor/monitorsdk_types.py | 11 +- basalt/ressources/monitor/trace_types.py | 47 ++- basalt/ressources/prompts/__init__.py | 1 + basalt/ressources/prompts/prompt_types.py | 83 ++++ basalt/sdk/datasetsdk.py | 4 +- basalt/sdk/monitorsdk.py | 22 +- basalt/sdk/promptsdk.py | 140 +++---- basalt/utils/dtos.py | 59 +-- basalt/utils/flusher.py | 2 +- basalt/utils/networker.py | 46 +- basalt/utils/protocols.py | 10 +- basalt/utils/utils.py | 22 +- examples/dataset_sdk_async_demo.ipynb | 163 -------- examples/dataset_sdk_demo.ipynb | 393 ++++++++---------- examples/monitor_sdk_async_demo.ipynb | 255 ------------ examples/monitor_sdk_demo.ipynb | 302 +++++++------- examples/prompt_sdk_async_demo.ipynb | 32 +- requirements.txt | 3 + setup.py | 1 + tests/test_datasetsdk_async.py | 46 +- 29 files changed, 792 insertions(+), 1062 deletions(-) create mode 100644 basalt/objects/prompt.py create mode 100644 basalt/ressources/prompts/__init__.py create mode 100644 basalt/ressources/prompts/prompt_types.py delete mode 100644 examples/dataset_sdk_async_demo.ipynb delete mode 100644 examples/monitor_sdk_async_demo.ipynb diff --git a/basalt/objects/base_log.py b/basalt/objects/base_log.py 
index 1d65941..c5656e6 100644 --- a/basalt/objects/base_log.py +++ b/basalt/objects/base_log.py @@ -2,10 +2,10 @@ from typing import Dict, Optional, Any, List import uuid -from ..ressources.monitor.base_log_types import BaseLogParams +from ..ressources.monitor.base_log_types import BaseLogParams, LogType from ..ressources.monitor.evaluator_types import Evaluator from ..ressources.monitor.trace_types import Trace -from ..ressources.monitor.log_types import Log, LogType +from ..ressources.monitor.log_types import Log class BaseLog: diff --git a/basalt/objects/generation.py b/basalt/objects/generation.py index 76b15b4..3bd8ce2 100644 --- a/basalt/objects/generation.py +++ b/basalt/objects/generation.py @@ -1,9 +1,9 @@ +from datetime import datetime from typing import Dict, Optional, Any, List, Union from .base_log import BaseLog from ..ressources.monitor.generation_types import GenerationParams -from ..ressources.monitor.base_log_types import BaseLogParams -from ..ressources.monitor.log_types import LogType +from ..ressources.monitor.base_log_types import BaseLogParams, LogType class Generation(BaseLog): @@ -11,9 +11,19 @@ class Generation(BaseLog): Class representing a generation in the monitoring system. """ def __init__(self, params: GenerationParams): - params_with_type = BaseLogParams(**params.__dict__, type=LogType.GENERATION) - - super().__init__(params_with_type) + base_log_params = BaseLogParams( + name=params.name, + ideal_output=params.ideal_output, + start_time=params.start_time, + end_time=params.end_time, + metadata=params.metadata, + parent=params.parent, + trace=params.trace, + evaluators=params.evaluators, + type=LogType.GENERATION, + ) + + super().__init__(base_log_params) self._prompt = params.prompt self._input = params.input @@ -109,6 +119,7 @@ def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> 'Generatio Generation: The generation instance. 
""" super().end() + self._end_time = datetime.now() if isinstance(output, dict): self.update(output) diff --git a/basalt/objects/prompt.py b/basalt/objects/prompt.py new file mode 100644 index 0000000..2c39f50 --- /dev/null +++ b/basalt/objects/prompt.py @@ -0,0 +1,94 @@ +from typing import Dict, Optional, Any +from jinja2 import Template, Environment, meta + +from ..ressources.prompts.prompt_types import PromptParams, PromptModel + + +class Prompt: + """ + Class representing a prompt in the Basalt system. + """ + def __init__(self, params: PromptParams): + self._slug = params.slug + self._text = params.text + self._system_text = params.system_text + self._version = params.version + self._tag = params.tag + self._model = params.model + self._raw_text = params.text + self._raw_system_text = params.system_text + self._variables = params.variables + + if params.variables is not None: + self.compile_variables(params.variables) + + @property + def slug(self) -> str: + """Get the prompt slug.""" + return self._slug + + @property + def text(self) -> str: + """Get the prompt text.""" + return self._text + + @property + def system_text(self) -> Optional[str]: + """Get the prompt system text.""" + return self._system_text + + @property + def version(self) -> str: + """Get the prompt version.""" + return self._version + + @property + def tag(self) -> Optional[str]: + """Get the prompt tag.""" + return self._tag + + @property + def model(self) -> PromptModel: + """Get the prompt model configuration.""" + return self._model + + @property + def raw_text(self) -> str: + """Get the original prompt text before variable replacement.""" + return self._raw_text + + @property + def variables(self) -> Optional[Dict[str, str]]: + """Get the prompt variables.""" + return self._variables + + @property + def raw_system_text(self) -> Optional[str]: + """Get the original system text before variable replacement.""" + return self._raw_system_text + + def compile_variables(self, variables: 
Dict[str, Any]) -> 'Prompt': + """Compile the prompt variables.""" + self._variables = variables + + self._text = Template(self._raw_text).render(variables) + + undeclared_variable=self._find_undeclared_variables(self._text) + + if self._raw_system_text: + self._system_text = Template(self._raw_system_text).render(variables) + undeclared_variable = undeclared_variable | self._find_undeclared_variables(self._system_text) + + if undeclared_variable: + print("undeclared variables:", undeclared_variable) + + return self + + @staticmethod + def _find_undeclared_variables(template: str) -> set[str]: + env = Environment() + ast = env.parse(template) + variables = meta.find_undeclared_variables(ast) + + return variables + diff --git a/basalt/objects/trace.py b/basalt/objects/trace.py index 38ae882..62c80fc 100644 --- a/basalt/objects/trace.py +++ b/basalt/objects/trace.py @@ -1,7 +1,6 @@ from datetime import datetime from typing import Dict, Optional, Any, List - from ..ressources.monitor.trace_types import TraceParams from .base_log import BaseLog from .generation import Generation @@ -20,32 +19,30 @@ class Trace: def __init__(self, feature_slug: str, params: TraceParams, flusher: 'Flusher', logger: 'ILogger'): self._feature_slug = feature_slug - self._input = params.input - self._output = params.output - self._ideal_output = params.ideal_output - self._name = params.name - self._start_time = params.start_time if params.start_time else datetime.now() - self._end_time = params.end_time - self._user = params.user - self._organization = params.organization - self._metadata = params.metadata + self._input = params.get('input') + self._output = params.get("output") + self._ideal_output = params.get("ideal_output") + self._name = params.get("name") + self._start_time = params.get("start_time", datetime.now()) + self._end_time = params.get("end_time") + self._user = params.get("user") + self._organization = params.get("organization") + self._metadata = params.get("metadata") 
self._logs: List['BaseLog'] = [] self._flusher = flusher self._is_ended = False - self._evaluators = params.evaluators - self._evaluation_config = params.evaluation_config + self._evaluators = params.get("evaluators") + self._evaluation_config = params.get("evaluation_config") self._logger = logger self._experiment = None if "experiment" in params: - experiment = params.experiment - if experiment is None: - self._logger.warn("Warning: Experiment is None. This experiment will be ignored.") - elif experiment.feature_slug != self._feature_slug: + experiment = params["experiment"] + if experiment.feature_slug != self._feature_slug: self._logger.warn("Warning: Experiment feature slug does not match trace feature slug. This experiment will be ignored.") else: self._experiment = experiment diff --git a/basalt/ressources/monitor/base_log_types.py b/basalt/ressources/monitor/base_log_types.py index e088fb1..7d63b06 100644 --- a/basalt/ressources/monitor/base_log_types.py +++ b/basalt/ressources/monitor/base_log_types.py @@ -2,12 +2,32 @@ from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING from dataclasses import dataclass, field from uuid import uuid4 +from enum import Enum + +if TYPE_CHECKING: + from .log_types import Log from .evaluator_types import Evaluator +from .trace_types import Trace -if TYPE_CHECKING: - from .trace_types import Trace - from .log_types import Log, LogType + +class LogType(Enum): + """Enum-like class for log types. + + Attributes: + SPAN: Represents a span log type + GENERATION: Represents a generation log type + FUNCTION: Represents a function log type + TOOL: Represents a tool log type + RETRIEVAL: Represents a retrieval log type + EVENT: Represents an event log type + """ + SPAN = 'span' + GENERATION = 'generation' + FUNCTION = 'function' + TOOL = 'tool' + RETRIEVAL = 'retrieval' + EVENT = 'event' @dataclass class BaseLogParams: @@ -99,7 +119,7 @@ def add_evaluator(self, evaluator: Evaluator) -> 'BaseLog': """ ... 
- def update(self, **params) -> 'BaseLog': + def update(self, params: Dict[str, Any]) -> 'BaseLog': """Updates the log with new parameters. Args: diff --git a/basalt/ressources/monitor/evaluator_types.py b/basalt/ressources/monitor/evaluator_types.py index 237d376..ed93912 100644 --- a/basalt/ressources/monitor/evaluator_types.py +++ b/basalt/ressources/monitor/evaluator_types.py @@ -1,16 +1,16 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, TypedDict @dataclass -class Evaluator: +class Evaluator(TypedDict): """ Represents an evaluator configuration. """ slug: str @dataclass -class EvaluationConfig: +class EvaluationConfig(TypedDict, total=False): """ Configuration for the evaluation of the trace and its logs. """ - sample_rate: Optional[float] = None + sample_rate: float diff --git a/basalt/ressources/monitor/experiment_types.py b/basalt/ressources/monitor/experiment_types.py index 5aafc19..6fd084d 100644 --- a/basalt/ressources/monitor/experiment_types.py +++ b/basalt/ressources/monitor/experiment_types.py @@ -1,8 +1,9 @@ +from typing import TypedDict from dataclasses import dataclass from datetime import datetime @dataclass -class ExperimentParams: +class ExperimentParams(TypedDict): """Parameters for creating an experiment.""" name: str diff --git a/basalt/ressources/monitor/generation_types.py b/basalt/ressources/monitor/generation_types.py index 7dbc48d..ae3f175 100644 --- a/basalt/ressources/monitor/generation_types.py +++ b/basalt/ressources/monitor/generation_types.py @@ -1,5 +1,4 @@ -from numbers import Number -from typing import Dict, Optional, Union, Any, List +from typing import Dict, Optional, Union, Any from dataclasses import dataclass, field from .base_log_types import BaseLog, BaseLogParams, LogType diff --git a/basalt/ressources/monitor/log_types.py b/basalt/ressources/monitor/log_types.py index 5266ccd..7b51dac 100644 --- a/basalt/ressources/monitor/log_types.py +++ 
b/basalt/ressources/monitor/log_types.py @@ -1,31 +1,11 @@ -from enum import Enum from typing import Optional, TYPE_CHECKING from dataclasses import dataclass -from .base_log_types import BaseLog, BaseLogParams +from .base_log_types import BaseLog, BaseLogParams, LogType if TYPE_CHECKING: from .generation_types import Generation, GenerationParams - -class LogType(Enum): - """Enum-like class for log types. - - Attributes: - SPAN: Represents a span log type - GENERATION: Represents a generation log type - FUNCTION: Represents a function log type - TOOL: Represents a tool log type - RETRIEVAL: Represents a retrieval log type - EVENT: Represents an event log type - """ - SPAN = 'span' - GENERATION = 'generation' - FUNCTION = 'function' - TOOL = 'tool' - RETRIEVAL = 'retrieval' - EVENT = 'event' - @dataclass class LogParams(BaseLogParams): """Parameters for creating or updating a log. diff --git a/basalt/ressources/monitor/monitorsdk_types.py b/basalt/ressources/monitor/monitorsdk_types.py index ed9ef65..8d2ff74 100644 --- a/basalt/ressources/monitor/monitorsdk_types.py +++ b/basalt/ressources/monitor/monitorsdk_types.py @@ -1,4 +1,4 @@ -from typing import Protocol, Optional, Tuple +from typing import Protocol, Optional, Tuple, TYPE_CHECKING from .trace_types import TraceParams from .experiment_types import ExperimentParams from .experiment_types import Experiment @@ -6,6 +6,9 @@ from .generation_types import GenerationParams, Generation from .log_types import LogParams, Log +if TYPE_CHECKING: + from ...utils.dtos import CreateExperimentResult + class IMonitorSDK(Protocol): """Interface for interacting with Basalt monitoring. @@ -38,7 +41,7 @@ class IMonitorSDK(Protocol): ``` """ - def create_trace(self, slug: str, params: Optional[TraceParams] = None) -> Trace: + def create_trace(self, slug: str, params: TraceParams = {}) -> Trace: """Creates a new trace to monitor a complete user interaction or process flow. 
Args: @@ -151,7 +154,7 @@ def create_log(self, params: LogParams) -> Log: """ ... - async def create_experiment(self, feature_slug: str, params: ExperimentParams) -> Tuple[Optional[Exception], Optional[Experiment]]: + async def create_experiment(self, feature_slug: str, params: ExperimentParams) -> 'CreateExperimentResult': """Creates a new experiment to bundle multiple traces together in. You can pass this experiment to the create_trace method to add the generated traces to the experiment. @@ -175,7 +178,7 @@ async def create_experiment(self, feature_slug: str, params: ExperimentParams) - """ ... - def create_experiment_sync(self, feature_slug: str, params: ExperimentParams) -> Tuple[Optional[Exception], Optional[Experiment]]: + def create_experiment_sync(self, feature_slug: str, params: ExperimentParams) -> 'CreateExperimentResult': """Synchronously creates a new experiment to bundle multiple traces together in. You can pass this experiment to the create_trace method to add the generated traces to the experiment. 
diff --git a/basalt/ressources/monitor/trace_types.py b/basalt/ressources/monitor/trace_types.py index b57a9af..85d9ad4 100644 --- a/basalt/ressources/monitor/trace_types.py +++ b/basalt/ressources/monitor/trace_types.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional, Any, List, TYPE_CHECKING +from typing import Dict, Optional, Any, List, TYPE_CHECKING, TypedDict from dataclasses import dataclass, field from .experiment_types import Experiment from .evaluator_types import Evaluator, EvaluationConfig @@ -10,35 +10,36 @@ from .base_log_types import BaseLog @dataclass -class User: +class User(TypedDict): """User information associated with a trace.""" id: str name: str @dataclass -class Organization: +class Organization(TypedDict): """Organization information associated with a trace.""" id: str name: str @dataclass -class TraceParams: +class TraceParams(TypedDict, total=False): """Parameters for creating or updating a trace.""" - name: Optional[str] = None - input: Optional[str] = None - output: Optional[str] = None - ideal_output: Optional[str] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None - user: Optional[User] = None - organization: Optional[Organization] = None - metadata: Optional[Dict[str, Any]] = None - experiment: Optional['Experiment'] = None - evaluators: Optional[List[Evaluator]] = None - evaluation_config: Optional[EvaluationConfig] = None + name: str + input: str + output: str + ideal_output: str + start_time: datetime + end_time: datetime + user: User + organization: Organization + metadata: Dict[str, Any] + experiment: Experiment + evaluators: List[Evaluator] + evaluation_config: EvaluationConfig + @dataclass -class Trace(TraceParams): +class Trace: """A trace represents a complete user interaction or process flow and serves as the top-level container for all monitoring activities. A trace provides methods to create and manage spans and generations within the process flow. 
@@ -74,8 +75,18 @@ class Trace(TraceParams): trace.end('Paris is the capital of France.') ``` """ - + name: Optional[str] + input: Optional[str] + output: Optional[str] + ideal_output: Optional[str] start_time: datetime + end_time: Optional[datetime] + user: Optional[User] + organization: Optional[Organization] + metadata: Optional[Dict[str, Any]] + experiment: Optional['Experiment'] + evaluators: Optional[List[Evaluator]] + evaluation_config: Optional[EvaluationConfig] logs: List['BaseLog'] = field(default_factory=list) def start(self, input: Optional[str] = None) -> 'Trace': diff --git a/basalt/ressources/prompts/__init__.py b/basalt/ressources/prompts/__init__.py new file mode 100644 index 0000000..6e39eeb --- /dev/null +++ b/basalt/ressources/prompts/__init__.py @@ -0,0 +1 @@ +"""Prompt types module for Basalt SDK""" diff --git a/basalt/ressources/prompts/prompt_types.py b/basalt/ressources/prompts/prompt_types.py new file mode 100644 index 0000000..2a39a7d --- /dev/null +++ b/basalt/ressources/prompts/prompt_types.py @@ -0,0 +1,83 @@ +""" +Prompt types module for Basalt SDK +""" +from dataclasses import dataclass +from typing import Optional, Dict, Any + + +@dataclass +class PromptModelParameters: + """Model parameters for a prompt""" + temperature: float + max_length: int + response_format: str + top_k: Optional[float] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + json_object: Optional[dict] = None + + +@dataclass +class PromptModel: + """Model configuration for a prompt""" + provider: str + model: str + version: str + parameters: PromptModelParameters + + +@dataclass +class PromptParams: + """Parameters for creating a new prompt""" + slug: str + text: str + model: PromptModel + version: str + system_text: Optional[str] = None + tag: Optional[str] = None + variables: Optional[Dict[str, Any]] = None + + +@dataclass +class Prompt: + """ + Prompt class representing a prompt 
template in the Basalt system. + + This class represents a prompt template that can be used for AI model generations. + + Example: + ```python + # Get a prompt + error, prompt = basalt.prompts.get( + slug="qa-prompt", + version="2.1.0", + variables={"context": "Paris is the capital of France"} + ) + + # Access prompt properties + print(prompt.text) + print(prompt.model.provider) + ``` + """ + slug: str + text: str + raw_text: str + model: PromptModel + version: str + system_text: Optional[str] = None + raw_system_text: Optional[str] = None + variables: Optional[Dict[str, str]] = None + tag: Optional[str] = None + + def compile_variables(self, variables: Dict[str, Any]) -> 'Prompt': + """ + Compile the prompt variables and render the text and system_text templates. + + Args: + variables (Dict[str, Any]): A dictionary of variables to render into the prompt templates. + + Returns: + Prompt: The updated Prompt instance with rendered text and system_text. + """ + ... diff --git a/basalt/sdk/datasetsdk.py b/basalt/sdk/datasetsdk.py index 0f33eeb..5d85283 100644 --- a/basalt/sdk/datasetsdk.py +++ b/basalt/sdk/datasetsdk.py @@ -8,13 +8,13 @@ ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult, DatasetDTO, DatasetRowDTO ) -from ..utils.protocols import IApi, ILogger +from ..utils.protocols import IApi, ILogger, IDatasetSDK from ..endpoints.list_datasets import ListDatasetsEndpoint from ..endpoints.get_dataset import GetDatasetEndpoint from ..endpoints.create_dataset_item import CreateDatasetItemEndpoint -class DatasetSDK: +class DatasetSDK(IDatasetSDK): """ SDK for interacting with Basalt datasets. 
""" diff --git a/basalt/sdk/monitorsdk.py b/basalt/sdk/monitorsdk.py index 6ea80a6..49c8c44 100644 --- a/basalt/sdk/monitorsdk.py +++ b/basalt/sdk/monitorsdk.py @@ -1,6 +1,7 @@ -from typing import Dict, Optional, Any, Tuple +from typing import Dict, Optional, Any from ..utils.protocols import IApi, ILogger +from ..utils.dtos import CreateExperimentResult from ..ressources.monitor.trace_types import TraceParams from ..ressources.monitor.experiment_types import ExperimentParams from ..ressources.monitor.generation_types import GenerationParams @@ -28,7 +29,7 @@ async def create_experiment( self, feature_slug: str, params: ExperimentParams - ) -> Tuple[Optional[Exception], Optional[Experiment]]: + ) -> CreateExperimentResult: """ Asynchronously creates a new experiment for monitoring. @@ -45,7 +46,7 @@ def create_experiment_sync( self, feature_slug: str, params: ExperimentParams - ) -> Tuple[Optional[Exception], Optional[Experiment]]: + ) -> CreateExperimentResult: """ Synchronously creates a new experiment for monitoring. @@ -73,12 +74,7 @@ def create_trace( Returns: Trace: A new Trace instance. """ - if params is None: - params = {} - - trace_params = TraceParams(**params) - - return self._create_trace(slug, trace_params) + return self._create_trace(slug, params if params else {}) def create_generation( self, @@ -116,7 +112,7 @@ async def _create_experiment( self, feature_slug: str, params: ExperimentParams - ) -> Tuple[Optional[Exception], Optional[Experiment]]: + ) -> CreateExperimentResult: """ Internal async implementation for creating an experiment. 
@@ -129,7 +125,7 @@ async def _create_experiment( """ dto = CreateExperimentDTO( feature_slug=feature_slug, - name=params.name, + name=params['name'], ) # Call the API endpoint @@ -144,7 +140,7 @@ def _create_experiment_sync( self, feature_slug: str, params: ExperimentParams - ) -> Tuple[Optional[Exception], Optional[Experiment]]: + ) -> CreateExperimentResult: """ Internal sync implementation for creating an experiment. @@ -157,7 +153,7 @@ def _create_experiment_sync( """ dto = CreateExperimentDTO( feature_slug=feature_slug, - name=params.name, + name=params['name'], ) # Call the API endpoint diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index abbc27f..7cd5a88 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -1,16 +1,17 @@ -from typing import Optional, Dict, Tuple, Any +from typing import Optional, Dict, Tuple, cast from ..ressources.monitor.generation_types import GenerationParams, PromptReference from ..ressources.monitor.trace_types import TraceParams -from ..utils.dtos import GetPromptDTO, PromptResponse, DescribePromptResponse, DescribePromptDTO, DescribeResult, ListResult, PromptListResponse, PromptListDTO +from ..ressources.prompts.prompt_types import Prompt as IPrompt, PromptParams +from ..utils.dtos import GetPromptDTO, GetPromptResult, PromptResponse, DescribePromptResponse, DescribePromptDTO, DescribeResult, ListResult, PromptListResponse, PromptListDTO from ..utils.protocols import ICache, IApi, ILogger from ..endpoints.get_prompt import GetPromptEndpoint from ..endpoints.describe_prompt import DescribePromptEndpoint from ..endpoints.list_prompts import ListPromptsEndpoint -from ..utils.utils import replace_variables from ..objects.trace import Trace from ..objects.generation import Generation +from ..objects.prompt import Prompt from ..utils.flusher import Flusher from datetime import datetime @@ -38,9 +39,9 @@ async def get( slug: str, version: Optional[str] = None, tag: Optional[str] = None, - variables: 
Dict[str, str] = {}, + variables: Optional[Dict[str, str]] = None, cache_enabled: bool = True - ) -> Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + ) -> GetPromptResult: """ Retrieve a prompt by slug, optionally specifying version and tag. @@ -55,6 +56,7 @@ async def get( Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: A tuple containing an optional exception, an optional PromptResponse, and an optional Generation object. """ + dto = GetPromptDTO( slug=slug, version=version, @@ -64,29 +66,32 @@ async def get( cached = self._cache.get(dto) if cache_enabled else None if cached: - original_prompt_text = cached.text - err, prompt_response = self._replace_vars(cached, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) - return err, prompt_response, generation + prompt_response = cast(PromptResponse, cached) + prompt = self._create_prompt_instance(prompt_response, variables) + generation = self._prepare_monitoring(prompt) + + return None, prompt, generation err, result = await self._api.invoke(GetPromptEndpoint, dto) if err is None: - original_prompt_text = result.prompt.text + prompt = self._create_prompt_instance(result.prompt, variables) + self._cache.put(dto, result.prompt, self._cache_duration) - self._fallback_cache.put(dto, result.prompt) + self._fallback_cache.put(dto, result.prompt, self._cache_duration) - err, prompt_response = self._replace_vars(result.prompt, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) - return err, prompt_response, generation + generation = self._prepare_monitoring(prompt) + + return err, prompt, generation fallback = self._fallback_cache.get(dto) if cache_enabled else None if fallback: - original_prompt_text = fallback.text - err, prompt_response = self._replace_vars(fallback, variables) - generation = 
self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) - return err, prompt_response, generation + prompt_response = cast(PromptResponse, fallback) + prompt = self._create_prompt_instance(prompt_response, variables) + generation = self._prepare_monitoring(prompt) + + return None, prompt, generation return err, None, None @@ -95,11 +100,11 @@ def get_sync( slug: str, version: Optional[str] = None, tag: Optional[str] = None, - variables: Dict[str, str] = {}, + variables: Optional[Dict[str, str]] = None, cache_enabled: bool = True - ) -> Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + ) -> GetPromptResult: """ - Synchronously retrieve a prompt by slug, optionally specifying version and tag. + Retrieve a prompt by slug, optionally specifying version and tag. Args: slug (str): The slug identifier for the prompt. @@ -121,51 +126,43 @@ def get_sync( cached = self._cache.get(dto) if cache_enabled else None if cached: - original_prompt_text = cached.text - err, prompt_response = self._replace_vars(cached, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) - return err, prompt_response, generation + prompt_response = cast(PromptResponse, cached) + prompt = self._create_prompt_instance(prompt_response, variables) + generation = self._prepare_monitoring(prompt) + + return None, prompt, generation err, result = self._api.invoke_sync(GetPromptEndpoint, dto) if err is None: - original_prompt_text = result.prompt.text + prompt = self._create_prompt_instance(result.prompt, variables) + self._cache.put(dto, result.prompt, self._cache_duration) - self._fallback_cache.put(dto, result.prompt) + self._fallback_cache.put(dto, result.prompt, self._cache_duration) + + generation = self._prepare_monitoring(prompt) - err, prompt_response = self._replace_vars(result.prompt, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, 
tag, variables, original_prompt_text) - return err, prompt_response, generation + return err, prompt, generation fallback = self._fallback_cache.get(dto) if cache_enabled else None if fallback: - original_prompt_text = fallback.text - err, prompt_response = self._replace_vars(fallback, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) - return err, prompt_response, generation + prompt_response = cast(PromptResponse, fallback) + prompt = self._create_prompt_instance(prompt_response, variables) + generation = self._prepare_monitoring(prompt) + + return None, prompt, generation return err, None, None - def _prepare_monitoring( - self, - prompt: PromptResponse, - slug: str, - version: Optional[str] = None, - tag: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - original_prompt_text: Optional[str] = None + def _prepare_monitoring(self,prompt: IPrompt, ) -> Generation: """ Prepare monitoring by creating a trace and generation object. Args: prompt (PromptResponse): The prompt response. - slug (str): The slug identifier for the prompt. - version (Optional[str]): The version of the prompt. - tag (Optional[str]): The tag associated with the prompt. - variables (Optional[Dict[str, Any]]): Variables used in the prompt. - original_prompt_text (Optional[str]): The original prompt text. + prompt (Prompt): The prompt Returns: Generation: The generation object. 
@@ -174,23 +171,22 @@ def _prepare_monitoring( flusher = Flusher(self._api, self._logger) # Create a trace - trace = Trace(slug, TraceParams( - input=original_prompt_text or prompt.text, + trace = Trace(prompt.slug, TraceParams( + input=prompt.text, start_time=datetime.now() ), flusher, self._logger) # Create a generation generation = Generation(GenerationParams( - name=slug, + name=prompt.slug, trace=trace, prompt=PromptReference( - slug=slug, - version=version, - tag=tag + slug=prompt.slug, + version=prompt.version, + tag=prompt.tag ), - input=original_prompt_text or prompt.text, - variables=variables, - options={"type": "single"} + input=prompt.text, + variables=prompt.variables )) return generation @@ -327,21 +323,17 @@ def list_sync(self, feature_slug: Optional[str] = None) -> ListResult: available_tags=prompt.available_tags ) for prompt in result.prompts] - def _replace_vars(self, prompt: PromptResponse, variables: Dict[str, str] = {}): - missing_vars, replaced = replace_variables(prompt.text, variables) - missing_system_vars, replaced_system = replace_variables(prompt.systemText or "", variables) - - if missing_vars: - self._logger.warn(f"""Basalt Warning: Some variables are missing in the prompt text: - {", ".join(map(str, missing_vars))}""") - - if missing_system_vars: - self._logger.warn(f"""Basalt Warning: Some variables are missing in the prompt systemText: - {", ".join(map(str, missing_system_vars))}""") - - return None, PromptResponse( - text=replaced, - systemText=replaced_system, - version=prompt.version, - model=prompt.model - ) + @staticmethod + def _create_prompt_instance( + prompt_response: PromptResponse, + variables: Optional[dict] = None + ) -> Prompt: + return Prompt(PromptParams( + slug=prompt_response.slug, + text=prompt_response.text, + tag=prompt_response.tag, + model=prompt_response.model, + version=prompt_response.version, + system_text=prompt_response.systemText, + variables=variables + )) diff --git a/basalt/utils/dtos.py 
b/basalt/utils/dtos.py index 59a9326..45b83fc 100644 --- a/basalt/utils/dtos.py +++ b/basalt/utils/dtos.py @@ -2,6 +2,8 @@ from typing import Optional, Dict, Any, List, Tuple from ..ressources.monitor.generation_types import Generation +from ..ressources.monitor.experiment_types import Experiment +from ..ressources.prompts.prompt_types import Prompt from .utils import pick_typed, pick_number @@ -51,13 +53,17 @@ def from_dict(cls, data: Dict[str, Any]): @dataclass(frozen=True) class PromptResponse: text: str + slug: str + version: str + tag: str model: PromptModel systemText: str - version: str @classmethod def from_dict(cls, data: Dict[str, Any]): return cls( + slug=pick_typed(data, "slug", str), + tag=pick_typed(data, "tag", str), text=pick_typed(data, "text", str), model=PromptModel.from_dict(data.get("model")), systemText=pick_typed(data, "systemText", str), @@ -70,30 +76,30 @@ class GetPromptDTO: tag: Optional[str] = None version: Optional[str] = None -GetResult = Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]] +GetPromptResult = Tuple[Optional[Exception], Optional[Prompt], Optional[Generation]] # ------------------------------ Describe Prompt ----------------------------- # @dataclass(frozen=True) class DescribePromptResponse: - slug: str - status: str - name: str - description: str - available_versions: List[str] - available_tags: List[str] - variables: List[Dict[str, str]] - - @classmethod - def from_dict(cls, data: Dict[str, Any]): - return cls( - slug=pick_typed(data, "slug", str) if data.get("slug") else None, - status=pick_typed(data, "status", str), - name=pick_typed(data, "name", str), - description=pick_typed(data, "description", str) if data.get("description") else None, - available_versions=pick_typed(data, "availableVersions", list), - available_tags=pick_typed(data, "availableTags", list), - variables=pick_typed(data, "variables", list), - ) + slug: str + status: str + name: str + description: str + available_versions: 
List[str] + available_tags: List[str] + variables: List[Dict[str, str]] + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + return cls( + slug=pick_typed(data, "slug", str) if data.get("slug") else None, + status=pick_typed(data, "status", str), + name=pick_typed(data, "name", str), + description=pick_typed(data, "description", str) if data.get("description") else None, + available_versions=pick_typed(data, "availableVersions", list), + available_tags=pick_typed(data, "availableTags", list), + variables=pick_typed(data, "variables", list), + ) @dataclass(frozen=True) class DescribePromptDTO: @@ -140,7 +146,7 @@ class DatasetDTO: name: str columns: List[str] rows: List['DatasetRowDTO'] = None - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "DatasetDTO": return cls( @@ -158,7 +164,7 @@ class DatasetRowDTO: name: Optional[str] = None idealOutput: Optional[str] = None metadata: Dict[str, Any] = None - + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "DatasetRowDTO": return cls( @@ -167,7 +173,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "DatasetRowDTO": idealOutput=data.get("idealOutput", None), metadata=data.get("metadata", {}) ) - + @dataclass class ListDatasetsDTO: @@ -194,4 +200,7 @@ class CreateDatasetItemDTO: # Result types for dataset operations ListDatasetsResult = Tuple[Optional[Exception], Optional[List[DatasetDTO]]] GetDatasetResult = Tuple[Optional[Exception], Optional[DatasetDTO]] -CreateDatasetItemResult = Tuple[Optional[Exception], Optional[DatasetRowDTO], Optional[str]] \ No newline at end of file +CreateDatasetItemResult = Tuple[Optional[Exception], Optional[DatasetRowDTO], Optional[str]] + +# Result types for monitor operations +CreateExperimentResult = Tuple[Optional[Exception], Optional[Experiment]] diff --git a/basalt/utils/flusher.py b/basalt/utils/flusher.py index eace5e1..fcd44a6 100644 --- a/basalt/utils/flusher.py +++ b/basalt/utils/flusher.py @@ -66,7 +66,7 @@ def _log_to_dict(log: Any) -> Dict[str, Any]: 
base_dict = { "id": log.id, - "type": log.type, + "type": log.type.value if hasattr(log.type, 'value') else log.type, "ideal_output": log.ideal_output, "name": log.name, "input": log.input, diff --git a/basalt/utils/networker.py b/basalt/utils/networker.py index 932f764..f8c4aad 100644 --- a/basalt/utils/networker.py +++ b/basalt/utils/networker.py @@ -37,15 +37,35 @@ async def fetch( - (FetchError, None) """ try: - async with aiohttp.ClientSession() as session: + # Filter out None values from params and headers + filtered_params = {k: v for k, v in params.items() if v is not None} if params else None + filtered_headers = {k: v for k, v in headers.items() if v is not None} if headers else None + + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False)) as session: async with session.request( method, url, - params=params, + params=filtered_params, json=body, - headers=headers + headers=filtered_headers ) as response: - json_response = await response.json() + # Try to parse JSON response, but handle cases where there's no JSON body + json_response = None + content_type = response.headers.get('Content-Type', '') + if content_type and 'application/json' in content_type: + try: + json_response = await response.json() + except Exception: + json_response = {} + elif response.status not in [202, 204]: + # For non-202/204 responses without JSON content-type, try to parse anyway + try: + json_response = await response.json() + except Exception: + json_response = {} + else: + # For 202 (Accepted) or 204 (No Content), an empty body is expected + json_response = {} if response.status == 400: return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None @@ -98,7 +118,23 @@ def fetch_sync( headers=headers ) - json_response = response.json() + # Try to parse JSON response, but handle cases where there's no JSON body + json_response = None + content_type = response.headers.get('Content-Type', '') + if content_type and 
'application/json' in content_type: + try: + json_response = response.json() + except Exception: + json_response = {} + elif response.status_code not in [202, 204]: + # For non-202/204 responses without JSON content-type, try to parse anyway + try: + json_response = response.json() + except Exception: + json_response = {} + else: + # For 202 (Accepted) or 204 (No Content), an empty body is expected + json_response = {} if response.status_code == 400: return BadRequest(json_response.get('error', json_response.get('errors', 'Bad Request'))), None diff --git a/basalt/utils/protocols.py b/basalt/utils/protocols.py index 8b094fa..652e74a 100644 --- a/basalt/utils/protocols.py +++ b/basalt/utils/protocols.py @@ -1,5 +1,5 @@ from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping, Literal -from .dtos import GetResult, DescribeResult, ListResult, ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult +from .dtos import GetPromptResult, DescribeResult, ListResult, ListDatasetsResult, GetDatasetResult, CreateDatasetItemResult, CreateExperimentResult from ..ressources.monitor.monitorsdk_types import IMonitorSDK @@ -35,8 +35,8 @@ def fetch_sync(self, ) -> Tuple[Optional[Exception], Optional[Output]]: ... class IPromptSDK(Protocol): - async def get(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Dict[str, str] = {}, cache_enabled: bool = True) -> GetResult: ... - def get_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Dict[str, str] = {}, cache_enabled: bool = True) -> GetResult: ... + async def get(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Optional[Dict[str, str]] = None, cache_enabled: bool = True) -> GetPromptResult: ... + def get_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Optional[Dict[str, str]] = None, cache_enabled: bool = True) -> GetPromptResult: ... 
async def describe(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... def describe_sync(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... async def list(self, feature_slug: Optional[str] = None) -> ListResult: ... @@ -47,9 +47,9 @@ async def list(self) -> ListDatasetsResult: ... def list_sync(self) -> ListDatasetsResult: ... async def get(self, slug: str) -> GetDatasetResult: ... def get_sync(self, slug: str) -> GetDatasetResult: ... - async def addRow(self, slug: str, values: Dict[str, str], name: Optional[str] = None, + async def add_row(self, slug: str, values: Dict[str, str], name: Optional[str] = None, ideal_output: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> CreateDatasetItemResult: ... - def addRow_sync(self, slug: str, values: Dict[str, str], name: Optional[str] = None, + def add_row_sync(self, slug: str, values: Dict[str, str], name: Optional[str] = None, ideal_output: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> CreateDatasetItemResult: ... 
class IBasaltSDK(Protocol): diff --git a/basalt/utils/utils.py b/basalt/utils/utils.py index 51113ca..74fbcc0 100644 --- a/basalt/utils/utils.py +++ b/basalt/utils/utils.py @@ -1,22 +1,6 @@ from re import sub as reg_replace from typing import Tuple, Set, Dict, Any -def replace_variables(template: str, replacements: dict) -> Tuple[Set[str], str]: - missing_keys = set([]) - - count = lambda key: missing_keys.add(key) - get = lambda key: replacements.get(key) if replacements.get(key) else count(key) - replacer = lambda match: str(get(match.group(1))) if get(match.group(1)) else match.group(0) - - replaced = reg_replace( - r'{{(.*?)}}', - replacer, - template - ) - - return missing_keys, replaced - - def pick_typed(dict: Dict[str, Any], field_name: str, expected_type: Any) -> Any: value = dict.get(field_name) @@ -31,12 +15,12 @@ def pick_typed(dict: Dict[str, Any], field_name: str, expected_type: Any) -> Any def pick_number(dict: Dict[str, Any], field_name: str) -> float: value = dict.get(field_name) - + if isinstance(value, float): return float(value) # Additional check for int, because isinstance(True, int) == True if isinstance(value, bool) == False and isinstance(value, int): return int(value) - - raise TypeError(f"Field '{field_name}' must be a number (int or float), got {type(value).__name__}.") \ No newline at end of file + + raise TypeError(f"Field '{field_name}' must be a number (int or float), got {type(value).__name__}.") diff --git a/examples/dataset_sdk_async_demo.ipynb b/examples/dataset_sdk_async_demo.ipynb deleted file mode 100644 index cc0d9ec..0000000 --- a/examples/dataset_sdk_async_demo.ipynb +++ /dev/null @@ -1,163 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Basalt DatasetSDK Async Demo\n", - "\n", - "This notebook demonstrates the asynchronous functionality of the DatasetSDK in the Basalt Python SDK." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", - "\n", - "os.environ[\"BASALT_BUILD\"] = \"development\"\n", - "\n", - "from basalt import Basalt\n", - "\n", - "# Initialize the SDK\n", - "basalt = Basalt(\n", - " api_key=\"sk-f50...\", # Replace with your API key\n", - " log_level=\"debug\" # Optional: Set log level\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 1: Asynchronously List All Datasets\n", - "\n", - "This example demonstrates how to list all datasets asynchronously." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def list_datasets():\n", - " print(\"Listing all datasets asynchronously...\")\n", - " err, datasets = await basalt.datasets.async_list()\n", - " if err:\n", - " print(f\"Error listing datasets: {err}\")\n", - " else:\n", - " print(f\"Found {len(datasets)} datasets\")\n", - " for dataset in datasets:\n", - " print(f\"- {dataset.name} (slug: {dataset.slug})\")\n", - " return datasets\n", - "\n", - "# Run the async function\n", - "datasets = await list_datasets()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 2: Asynchronously Get a Specific Dataset\n", - "\n", - "This example demonstrates how to retrieve a specific dataset by its slug." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def get_dataset(datasets):\n", - " print(\"\\nGetting a specific dataset asynchronously...\")\n", - " if len(datasets) > 0:\n", - " sample_dataset = datasets[0]\n", - " err, dataset = await basalt.datasets.async_get(sample_dataset.slug)\n", - " if err:\n", - " print(f\"Error getting dataset: {err}\")\n", - " else:\n", - " print(f\"Retrieved dataset: {dataset.name}\")\n", - " print(f\"Columns: {dataset.columns}\")\n", - " print(f\"Number of rows: {len(dataset.rows) if dataset.rows else 0}\")\n", - " return sample_dataset, dataset\n", - " else:\n", - " print(\"No datasets available\")\n", - " return None, None\n", - "\n", - "# Run the async function\n", - "sample_dataset, dataset = await get_dataset(datasets)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 3: Asynchronously Add a Row to a Dataset\n", - "\n", - "This example demonstrates how to add a new row to an existing dataset." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def add_row(sample_dataset):\n", - " print(\"\\nAdding a row to a dataset asynchronously...\")\n", - " if sample_dataset:\n", - " # Create some sample values for the dataset row\n", - " values = {column: f\"Sample {column} value\" for column in sample_dataset.columns}\n", - " \n", - " err, row, warning = await basalt.datasets.async_addRow(\n", - " slug=sample_dataset.slug,\n", - " values=values,\n", - " name=\"Async Sample Row\",\n", - " ideal_output=\"Expected output for this row\",\n", - " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error adding row to dataset: {err}\")\n", - " elif warning:\n", - " print(f\"Row added with warning: {warning}\")\n", - " print(f\"Row values: {row.values}\")\n", - " else:\n", - " print(f\"Row added successfully\")\n", - " print(f\"Row values: {row.values}\")\n", - " print(f\"Row name: {row.name}\")\n", - "\n", - "# Run the async function\n", - "await add_row(sample_dataset)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/dataset_sdk_demo.ipynb b/examples/dataset_sdk_demo.ipynb index e732583..6a5bc49 100644 --- a/examples/dataset_sdk_demo.ipynb +++ b/examples/dataset_sdk_demo.ipynb @@ -1,232 +1,163 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Basalt Dataset SDK Demo\n", - "\n", - "This notebook demonstrates how to use the Basalt Dataset SDK to interact with your Basalt datasets." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", - "\n", - "os.environ[\"BASALT_BUILD\"] = \"development\"\n", - "\n", - "from basalt import Basalt\n", - "\n", - "# Initialize the SDK\n", - "basalt = Basalt(\n", - " api_key=\"sk-...\", # Replace with your API key\n", - " log_level=\"debug\" # Optional: Set log level\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Listing Available Datasets\n", - "\n", - "Retrieve all datasets available in your workspace." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# List all datasets in the workspace\n", - "err, datasets = basalt.datasets.list()\n", - "\n", - "if err:\n", - " print(f\"Error listing datasets: {err}\")\n", - "else:\n", - " print(f\"Found {len(datasets)} datasets:\")\n", - " for i, dataset in enumerate(datasets):\n", - " print(f\"{i+1}. {dataset.name} (slug: {dataset.slug})\")\n", - " print(f\" - Columns: {', '.join(dataset.columns)}\")\n", - " \n", - " # Store the first dataset slug for later use (if available)\n", - " first_dataset_slug = datasets[0].slug if datasets else None" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Getting a Specific Dataset\n", - "\n", - "Retrieve details for a specific dataset using its slug." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use the first dataset from the list or replace with a specific slug\n", - "dataset_slug = first_dataset_slug if 'first_dataset_slug' in locals() and first_dataset_slug else \"your-dataset-slug\"\n", - "\n", - "err, dataset = basalt.datasets.get(dataset_slug)\n", - "\n", - "if err:\n", - " print(f\"Error getting dataset: {err}\")\n", - "else:\n", - " print(f\"Dataset details for '{dataset.name}'\")\n", - " print(f\"Slug: {dataset.slug}\")\n", - " print(f\"Columns: {', '.join(dataset.columns)}\")\n", - " print(f\"Number of rows: {len(dataset.rows)}\")\n", - " \n", - " if dataset.rows:\n", - " print(\"\\nSample rows:\")\n", - " for i, row in enumerate(dataset.rows[:3]): # Show up to 3 rows\n", - " print(f\"Row {i+1}:\")\n", - " print(f\" Values: {row.get('values')}\")\n", - " if 'name' in row:\n", - " print(f\" Name: {row['name']}\")\n", - " if 'idealOutput' in row:\n", - " print(f\" Ideal output: {row['idealOutput']}\")\n", - " if 'metadata' in row:\n", - " print(f\" Metadata: {row['metadata']}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Adding a Row to a Dataset\n", - "\n", - "Create a new row (item) in an existing dataset." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use the dataset from the previous example\n", - "if 'dataset' in locals() and dataset:\n", - " # Build values for all columns in the dataset\n", - " values = {}\n", - " for column in dataset.columns:\n", - " values[column] = f\"Example value for {column} - {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\n", - " \n", - " # Create the row\n", - " err, row, warning = basalt.datasets.addRow(\n", - " slug=dataset.slug,\n", - " values=values,\n", - " name=\"Notebook Example Row\",\n", - " ideal_output=\"This is an ideal output for this row\",\n", - " metadata={\"source\": \"Jupyter notebook example\", \"timestamp\": __import__('datetime').datetime.now().isoformat()}\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error creating dataset row: {err}\")\n", - " else:\n", - " print(\"Successfully created new dataset row:\")\n", - " print(f\"Values: {row.values}\")\n", - " print(f\"Name: {row.name}\")\n", - " print(f\"Ideal output: {row.idealOutput}\")\n", - " print(f\"Metadata: {row.metadata}\")\n", - " \n", - " if warning:\n", - " print(f\"Warning: {warning}\")\n", - "else:\n", - " print(\"Please run the previous cell to get a dataset first\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Error Handling with Dataset SDK\n", - "\n", - "Demonstrate proper error handling when working with datasets." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def safely_add_dataset_row(slug, values, name=None, ideal_output=None, metadata=None):\n", - " \"\"\"Safely add a row to a dataset with robust error handling\"\"\"\n", - " try:\n", - " err, row, warning = basalt.datasets.addRow(\n", - " slug=slug,\n", - " values=values,\n", - " name=name,\n", - " ideal_output=ideal_output,\n", - " metadata=metadata\n", - " )\n", - " \n", - " if err:\n", - " print(f\"Error creating dataset row: {err}\")\n", - " return None\n", - " \n", - " if warning:\n", - " print(f\"Warning: {warning}\")\n", - " \n", - " return row\n", - " except Exception as e:\n", - " print(f\"Unexpected error: {str(e)}\")\n", - " return None\n", - "\n", - "# Test with a valid dataset\n", - "if 'dataset_slug' in locals() and dataset_slug:\n", - " values = {\"input\": \"Test input\", \"output\": \"Test output\"}\n", - " row = safely_add_dataset_row(dataset_slug, values, name=\"Error Handling Test\")\n", - " \n", - " if row:\n", - " print(f\"Successfully created row: {row.name}\")\n", - "\n", - "# Test with an invalid dataset slug\n", - "print(\"\\nTesting with invalid dataset slug:\")\n", - "invalid_row = safely_add_dataset_row(\"non-existent-dataset\", {\"input\": \"Test input\"})\n", - "print(f\"Result with invalid slug: {invalid_row}\")\n", - "\n", - "# Test with missing required values\n", - "if 'dataset' in locals() and dataset and len(dataset.columns) > 0:\n", - " print(\"\\nTesting with missing required values:\")\n", - " # Deliberately create incomplete values dict\n", - " incomplete_values = {column: \"value\" for column in list(dataset.columns)[1:]} if len(dataset.columns) > 1 else {}\n", - " incomplete_row = safely_add_dataset_row(dataset.slug, incomplete_values)\n", - " print(f\"Result with incomplete values: {incomplete_row}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": 
"python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basalt DatasetSDK Demo\n", + "\n", + "This notebook demonstrates the asynchronous functionality of the DatasetSDK in the Basalt Python SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "\n", + "os.environ[\"BASALT_BUILD\"] = \"development\"\n", + "\n", + "from basalt import Basalt\n", + "\n", + "# Initialize the SDK\n", + "basalt = Basalt(\n", + " api_key=\"sk-f50...\", # Replace with your API key\n", + " log_level=\"debug\" # Optional: Set log level\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1: List All Datasets\n", + "\n", + "This example demonstrates how to list all datasets asynchronously." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def list_datasets():\n", + " print(\"Listing all datasets asynchronously...\")\n", + " err, datasets = await basalt.datasets.list()\n", + " if err:\n", + " print(f\"Error listing datasets: {err}\")\n", + " else:\n", + " print(f\"Found {len(datasets)} datasets\")\n", + " for dataset in datasets:\n", + " print(f\"- {dataset.name} (slug: {dataset.slug})\")\n", + " return datasets\n", + "\n", + "# Run the async function\n", + "datasets = await list_datasets()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2: Get a Specific Dataset\n", + "\n", + "This example demonstrates how to retrieve a specific dataset by its slug." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def get_dataset(datasets):\n", + " print(\"\\nGetting a specific dataset asynchronously...\")\n", + " if len(datasets) > 0:\n", + " sample_dataset = datasets[0]\n", + " err, dataset = await basalt.datasets.get(sample_dataset.slug)\n", + " if err:\n", + " print(f\"Error getting dataset: {err}\")\n", + " else:\n", + " print(f\"Retrieved dataset: {dataset.name}\")\n", + " print(f\"Columns: {dataset.columns}\")\n", + " print(f\"Number of rows: {len(dataset.rows) if dataset.rows else 0}\")\n", + " return sample_dataset, dataset\n", + " else:\n", + " print(\"No datasets available\")\n", + " return None, None\n", + "\n", + "# Run the async function\n", + "sample_dataset, dataset = await get_dataset(datasets)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3: Add a Row to a Dataset\n", + "\n", + "This example demonstrates how to add a new row to an existing dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def add_row(sample_dataset):\n", + " print(\"\\nAdding a row to a dataset asynchronously...\")\n", + " if sample_dataset:\n", + " # Create some sample values for the dataset row\n", + " values = {column: f\"Sample {column} value\" for column in sample_dataset.columns}\n", + "\n", + " err, row, warning = await basalt.datasets.add_row(\n", + " slug=sample_dataset.slug,\n", + " values=values,\n", + " name=\"Async Sample Row\",\n", + " ideal_output=\"Expected output for this row\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + "\n", + " if err:\n", + " print(f\"Error adding row to dataset: {err}\")\n", + " elif warning:\n", + " print(f\"Row added with warning: {warning}\")\n", + " print(f\"Row values: {row.values}\")\n", + " else:\n", + " print(f\"Row added successfully\")\n", + " print(f\"Row values: {row.values}\")\n", + " print(f\"Row name: {row.name}\")\n", + "\n", + "# Run the async function\n", + "await add_row(sample_dataset)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/monitor_sdk_async_demo.ipynb b/examples/monitor_sdk_async_demo.ipynb deleted file mode 100644 index d764660..0000000 --- a/examples/monitor_sdk_async_demo.ipynb +++ /dev/null @@ -1,255 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Basalt MonitorSDK Async Demo\n", - "\n", - "This notebook demonstrates the asynchronous functionality of the MonitorSDK in the Basalt Python SDK." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", - "\n", - "os.environ[\"BASALT_BUILD\"] = \"development\"\n", - "\n", - "from basalt import Basalt\n", - "from basalt.ressources.monitor.monitorsdk_types import (\n", - " ExperimentParams, TraceParams, GenerationParams, LogParams\n", - ")\n", - "\n", - "# Initialize the SDK\n", - "basalt = Basalt(\n", - " api_key=\"sk-f50...\", # Replace with your API key\n", - " log_level=\"debug\" # Optional: Set log level\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 1: Asynchronously Create a Trace\n", - "\n", - "This example demonstrates how to create a trace asynchronously." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def create_trace():\n", - " print(\"Creating a trace asynchronously...\")\n", - " trace_params = TraceParams(\n", - " name=\"Async Test Trace\",\n", - " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", - " )\n", - " \n", - " trace = await basalt.monitor.async_create_trace(\n", - " slug=\"async-test-trace\",\n", - " params=trace_params\n", - " )\n", - " \n", - " print(f\"Created trace: {trace.id}\")\n", - " print(f\"Trace name: {trace.name}\")\n", - " print(f\"Trace slug: {trace.slug}\")\n", - " \n", - " return trace\n", - "\n", - "# Run the async function\n", - "trace = await create_trace()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 2: Asynchronously Create a Generation\n", - "\n", - "This example demonstrates how to create a generation associated with a trace." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def create_generation(trace):\n", - " print(\"\\nCreating a generation asynchronously...\")\n", - " gen_params = GenerationParams(\n", - " trace_id=trace.id,\n", - " text=\"This is an async test generation\",\n", - " model_id=\"gpt-4\",\n", - " prompt=\"Generate a response asynchronously\",\n", - " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", - " )\n", - " \n", - " generation = await basalt.monitor.async_create_generation(gen_params)\n", - " \n", - " print(f\"Created generation: {generation.id}\")\n", - " print(f\"Generation text: {generation.text}\")\n", - " print(f\"Generation model: {generation.model_id}\")\n", - " \n", - " return generation\n", - "\n", - "# Run the async function\n", - "generation = await create_generation(trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 3: Asynchronously Create a Log\n", - "\n", - "This example demonstrates how to create a log entry associated with a trace." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def create_log(trace):\n", - " print(\"\\nCreating a log asynchronously...\")\n", - " log_params = LogParams(\n", - " trace_id=trace.id,\n", - " type=\"info\",\n", - " message=\"This is an async test log message\",\n", - " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", - " )\n", - " \n", - " log = await basalt.monitor.async_create_log(log_params)\n", - " \n", - " print(f\"Created log: {log['id']}\")\n", - " print(f\"Log message: {log['message']}\")\n", - " print(f\"Log type: {log['type']}\")\n", - " \n", - " return log\n", - "\n", - "# Run the async function\n", - "log = await create_log(trace)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 4: Asynchronously Create an Experiment\n", - "\n", - "This example demonstrates how to create an experiment." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def create_experiment():\n", - " print(\"\\nCreating an experiment asynchronously...\")\n", - " exp_params = ExperimentParams(\n", - " type=\"A/B Test\",\n", - " name=\"Async Test Experiment\",\n", - " setup={\n", - " \"control_id\": \"control-prompt\",\n", - " \"variation_id\": \"test-prompt\"\n", - " }\n", - " )\n", - " \n", - " experiment = await basalt.monitor.async_create_experiment(\n", - " \"async-test-feature\",\n", - " exp_params\n", - " )\n", - " \n", - " print(f\"Created experiment: {experiment.id}\")\n", - " print(f\"Experiment name: {experiment.name}\")\n", - " print(f\"Experiment type: {experiment.type}\")\n", - " \n", - " return experiment\n", - "\n", - "# Run the async function\n", - "experiment = await create_experiment()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example 5: Execute Multiple Monitoring Operations Concurrently\n", - "\n", - "This example demonstrates how to 
execute multiple asynchronous monitoring operations concurrently for better performance." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "async def execute_concurrent_operations():\n", - " print(\"\\nExecuting multiple monitoring operations concurrently...\")\n", - " \n", - " # Create trace parameters for concurrent operations\n", - " trace_params1 = TraceParams(\n", - " name=\"Concurrent Trace 1\",\n", - " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 1}\n", - " )\n", - " \n", - " trace_params2 = TraceParams(\n", - " name=\"Concurrent Trace 2\",\n", - " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 2}\n", - " )\n", - " \n", - " # Create multiple async tasks\n", - " tasks = [\n", - " basalt.monitor.async_create_trace(\"concurrent-trace-1\", trace_params1),\n", - " basalt.monitor.async_create_trace(\"concurrent-trace-2\", trace_params2)\n", - " ]\n", - " \n", - " # Execute all tasks concurrently\n", - " results = await asyncio.gather(*tasks)\n", - " \n", - " print(f\"Completed {len(tasks)} operations concurrently\")\n", - " print(f\"First trace: {results[0].name} (id: {results[0].id})\")\n", - " print(f\"Second trace: {results[1].name} (id: {results[1].id})\")\n", - "\n", - "# Run the async function\n", - "await execute_concurrent_operations()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/monitor_sdk_demo.ipynb b/examples/monitor_sdk_demo.ipynb index a5d7233..18d7ce6 100644 --- a/examples/monitor_sdk_demo.ipynb +++ b/examples/monitor_sdk_demo.ipynb @@ -4,39 
+4,42 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Basalt Monitor SDK Demo\n", + "# Basalt MonitorSDK Demo\n", "\n", - "This notebook demonstrates how to use the Basalt Monitor SDK to track and monitor your AI application's execution." + "This notebook demonstrates the asynchronous functionality of the MonitorSDK in the Basalt Python SDK." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", "\n", "os.environ[\"BASALT_BUILD\"] = \"development\"\n", "\n", "from basalt import Basalt\n", + "from basalt.ressources.monitor.monitorsdk_types import (\n", + " ExperimentParams, TraceParams, GenerationParams, LogParams\n", + ")\n", "\n", "# Initialize the SDK\n", "basalt = Basalt(\n", - "\tapi_key=\"sk-d4ef...\", # Replace with your API key\n", - "\tlog_level=\"debug\" # Optional: Set log level\n", - ") " + " api_key=\"sk-f50...\", # Replace with your API key\n", + " log_level=\"debug\" # Optional: Set log level\n", + ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. Creating a Basic Trace\n", + "## Example 1: Create a Trace\n", "\n", - "A trace represents a complete execution flow in your application. Let's create a simple trace:" + "This example demonstrates how to create a trace." 
] }, { @@ -45,27 +48,35 @@ "metadata": {}, "outputs": [], "source": [ - "# Create a trace\n", - "trace = basalt.monitor.create_trace(\n", - " \"slug\", # Chain slug - identifies this type of workflow\n", - " {\n", - " \"input\": \"What are the benefits of AI in healthcare?\",\n", - " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"metadata\": {\"source\": \"web\", \"priority\": \"high\"}\n", - " }\n", - ")\n", + "def create_trace():\n", + " print(\"Creating a trace asynchronously...\")\n", + " trace_params = TraceParams(\n", + " name=\"Async Test Trace\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + "\n", + " trace = basalt.monitor.create_trace(\n", + " slug=\"async-test-trace\",\n", + " params=trace_params\n", + " )\n", + "\n", + " print(f\"Created trace: {trace.id}\")\n", + " print(f\"Trace name: {trace.name}\")\n", + " print(f\"Trace slug: {trace.slug}\")\n", "\n", - "print(f\"Created trace with input: {trace.input}\")" + " return trace\n", + "\n", + "# Run the function\n", + "trace = create_trace()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. Adding Logs to a Trace\n", + "## Example 2: Create a Generation\n", "\n", - "Logs represent individual steps or operations within a trace:" + "This example demonstrates how to create a generation associated with a trace." 
] }, { @@ -74,34 +85,35 @@ "metadata": {}, "outputs": [], "source": [ - "# Create a log for content moderation\n", - "moderation_log = trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"content-moderation\",\n", - " \"input\": trace.input,\n", - " \"metadata\": {\"model\": \"text-moderation-latest\"},\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - "\t\t\"metadata\": {\"source\": \"web\", \"priority\": \"high\"}\n", - "})\n", - "\n", - "# Simulate moderation check\n", - "moderation_result = {\"flagged\": False, \"categories\": [], \"scores\": {}}\n", - "\n", - "# Update and end the log\n", - "moderation_log.update({\"metadata\": {\"completed\": True}})\n", - "moderation_log.end(moderation_result)\n", - "\n", - "print(f\"Completed moderation check: {moderation_log.output}\")" + "def create_generation(trace):\n", + " print(\"\\nCreating a generation asynchronously...\")\n", + " gen_params = GenerationParams(\n", + " trace_id=trace.id,\n", + " text=\"This is an async test generation\",\n", + " model_id=\"gpt-4\",\n", + " prompt=\"Generate a response asynchronously\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + "\n", + " generation = basalt.monitor.create_generation(gen_params)\n", + "\n", + " print(f\"Created generation: {generation.id}\")\n", + " print(f\"Generation text: {generation.text}\")\n", + " print(f\"Generation model: {generation.model_id}\")\n", + "\n", + " return generation\n", + "\n", + "# Run the function\n", + "generation = create_generation(trace)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. 
Creating and Managing Generations\n", + "## Example 3: Create a Log\n", "\n", - "Generations are special types of logs specifically for AI model interactions:" + "This example demonstrates how to create a log entry associated with a trace." 
] }, { @@ -110,61 +122,34 @@ "metadata": {}, "outputs": [], "source": [ - "# Create a log for the main processing\n", - "main_log = trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"main-processing\",\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"input\": trace.input\n", - "})\n", - "\n", - "# Create a generation within the main log found in Basalt\n", - "generation = main_log.create_generation({\n", - " \"name\": \"healthcare-benefits-generation\",\n", - " \"prompt\": {\n", - " \"slug\": \"prompt-slug\", # This tells the SDK to fetch the prompt from Basalt\n", - " \"version\": \"0.1\" # This specifies the version to use\n", - " },\n", - "\t\t\"variables\": {\"variable_example\": \"test variable\"}\n", - "})\n", - "\n", - "# Create a generation within the main log not managed in Basalt\n", - "generation = main_log.create_generation({\n", - " \"name\": \"healthcare-benefits-generation\",\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"input\": trace.input\n", - "})\n", - "\n", - "# Simulate AI response\n", - "ai_response = \"\"\"\n", - "AI in healthcare offers numerous benefits:\n", - "1. Early disease detection through advanced imaging analysis\n", - "2. Personalized treatment recommendations\n", - "3. Automated administrative tasks\n", - "4. Enhanced drug discovery process\n", - "5. 
Improved patient monitoring\n", - "\"\"\"\n", - "\n", - "# End the generation with the response\n", - "generation.end(ai_response)\n", - "\n", - "# End the main log\n", - "main_log.end(ai_response)\n", - "\n", - "trace.end(\"End of trace\")\n", - "\n", - "print(f\"Generated response:\\n{generation.output}\")" + "def create_log(trace):\n", + " print(\"\\nCreating a log asynchronously...\")\n", + " log_params = LogParams(\n", + " trace_id=trace.id,\n", + " type=\"info\",\n", + " message=\"This is an async test log message\",\n", + " metadata={\"source\": \"async_example\", \"type\": \"demo\"}\n", + " )\n", + "\n", + " log = basalt.monitor.create_log(log_params)\n", + "\n", + " print(f\"Created log: {log['id']}\")\n", + " print(f\"Log message: {log['message']}\")\n", + " print(f\"Log type: {log['type']}\")\n", + "\n", + " return log\n", + "\n", + "# Run the function\n", + "log = create_log(trace)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Complex Workflow Example\n", + "## Example 4: Create an Experiment\n", "\n", - "Here's a more complex example showing nested logs and multiple generations:" + "This example demonstrates how to create an experiment." 
] }, { @@ -173,74 +158,83 @@ "metadata": {}, "outputs": [], "source": [ - "# Create a new trace for a complex workflow\n", - "complex_trace = basalt.monitor.create_trace(\n", - " \"theo-slug\",\n", - " {\n", - " \"input\": \"Patient presents with frequent headaches and fatigue.\",\n", - "\t\t\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"}\n", - " }\n", - ")\n", + "async def create_experiment():\n", + " print(\"\\nCreating an experiment asynchronously...\")\n", + " exp_params = ExperimentParams(\n", + " type=\"A/B Test\",\n", + " name=\"Async Test Experiment\",\n", + " setup={\n", + " \"control_id\": \"control-prompt\",\n", + " \"variation_id\": \"test-prompt\"\n", + " }\n", + " )\n", + "\n", + " err, experiment = await basalt.monitor.create_experiment(\n", + " \"async-test-feature\",\n", + " exp_params\n", + " )\n", + "\n", + " print(f\"Created experiment: {experiment.id}\")\n", + " print(f\"Experiment name: {experiment.name}\")\n", + "\n", + " return experiment\n", + "\n", + "# Run the async function\n", + "experiment = await create_experiment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 5: Execute Multiple Monitoring Operations Concurrently\n", + "\n", + "This example demonstrates how to execute multiple asynchronous monitoring operations concurrently for better performance." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "\n", + "async def execute_concurrent_operations():\n", + " print(\"\\nExecuting multiple monitoring operations concurrently...\")\n", + "\n", + " # Create trace parameters for concurrent operations\n", + " trace_params1 = TraceParams(\n", + " name=\"Concurrent Trace 1\",\n", + " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 1}\n", + " )\n", + "\n", + " trace_params2 = TraceParams(\n", + " name=\"Concurrent Trace 2\",\n", + " metadata={\"source\": \"async_concurrent_example\", \"trace_number\": 2}\n", + " )\n", + "\n", + " # Create multiple async tasks\n", + " tasks = [\n", + " basalt.monitor.create_trace(\"concurrent-trace-1\", trace_params1),\n", + " basalt.monitor.create_trace(\"concurrent-trace-2\", trace_params2)\n", + " ]\n", + "\n", + " # Execute all tasks concurrently\n", + " results = await asyncio.gather(*tasks)\n", + "\n", + " print(f\"Completed {len(tasks)} operations concurrently\")\n", + " print(f\"First trace: {results[0].name} (id: {results[0].id})\")\n", + " print(f\"Second trace: {results[1].name} (id: {results[1].id})\")\n", "\n", - "# Initial analysis log\n", - "analysis_log = complex_trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"symptom-analysis\",\n", - "\t\t\"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", - " \"input\": complex_trace.input\n", - "})\n", - "\n", - "# Generate initial analysis\n", - "analysis_gen = analysis_log.create_generation({\n", - " \"name\": \"symptom-classification\",\n", - "\t\t\"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", - " \"prompt\": {\"slug\": \"generate-test-cases\", \"version\": \"0.1\"},\n", - "\t\t\"variables\": {\"variable_example\": \"test variable\"}\n", - "})\n", - "analysis_gen.end(\"Primary symptoms suggest possible migraine or chronic fatigue syndrome\")\n", - "\n", - "# 
Create a nested log for recommendations\n", - "recommendations_log = analysis_log.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"treatment-recommendations\",\n", - "\t\t\"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"input\": analysis_gen.output\n", - "})\n", - "\n", - "# Generate treatment recommendations\n", - "treatment_gen = recommendations_log.create_generation({\n", - " \"name\": \"treatment-suggestions\",\n", - " \"prompt\": {\"slug\": \"generate-test-cases\", \"version\": \"0.1\"},\n", - "\t\t\"variables\": {\"variable_example\": \"test variable\"}\n", - "})\n", - "\n", - "treatment_response = \"\"\"\n", - "Recommended treatments:\n", - "1. Schedule neurological examination\n", - "2. Keep headache diary for pattern recognition\n", - "3. Consider sleep study for fatigue assessment\n", - "4. 
Initial blood work to rule out underlying conditions\n", - "\"\"\"\n", - "treatment_gen.end(treatment_response)\n", - "\n", - "# End all logs\n", - "recommendations_log.end(treatment_response)\n", - "analysis_log.end(analysis_gen.output)\n", - "complex_trace.end(\"End of main trace\")\n", - "\n", - "print(\"Completed medical report analysis workflow\")\n", - "print(f\"Analysis: {analysis_gen.output}\")\n", - "print(f\"Recommendations: {treatment_gen.output}\")" + "# Run the async function\n", + "await execute_concurrent_operations()" ] } ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -254,7 +248,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.2" + "version": "3.9.0" } }, "nbformat": 4, diff --git a/examples/prompt_sdk_async_demo.ipynb b/examples/prompt_sdk_async_demo.ipynb index eaaad5a..24e6bbf 100644 --- a/examples/prompt_sdk_async_demo.ipynb +++ b/examples/prompt_sdk_async_demo.ipynb @@ -34,7 +34,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Example 1: Asynchronously List All Prompts\n", + "## Example 1: List All Prompts\n", "\n", "This example demonstrates how to list all prompts asynchronously." ] @@ -47,7 +47,7 @@ "source": [ "async def list_prompts():\n", " print(\"Listing all prompts asynchronously...\")\n", - " err, prompts = await basalt.prompt.async_list()\n", + " err, prompts = await basalt.prompt.list()\n", " if err:\n", " print(f\"Error listing prompts: {err}\")\n", " else:\n", @@ -64,7 +64,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Example 2: Asynchronously Get a Specific Prompt\n", + "## Example 2: Get a Specific Prompt\n", "\n", "This example demonstrates how to retrieve a specific prompt by its slug." 
] @@ -79,7 +79,7 @@ " print(\"\\nGetting a specific prompt asynchronously...\")\n", " if len(prompts) > 0:\n", " sample_prompt = prompts[0]\n", - " err, prompt_response, generation = await basalt.prompt.async_get(sample_prompt.slug)\n", + " err, prompt_response, generation = await basalt.prompt.get(sample_prompt.slug)\n", " if err:\n", " print(f\"Error getting prompt: {err}\")\n", " else:\n", @@ -112,13 +112,13 @@ "async def describe_prompt(sample_prompt):\n", " print(\"\\nDescribing a prompt asynchronously...\")\n", " if sample_prompt:\n", - " err, description = await basalt.prompt.async_describe(sample_prompt.slug)\n", + " err, description = await basalt.prompt.describe(sample_prompt.slug)\n", " if err:\n", " print(f\"Error describing prompt: {err}\")\n", " else:\n", - " print(f\"Prompt: {description.prompt.name}\")\n", - " print(f\"Versions available: {len(description.versions)}\")\n", - " for version in description.versions:\n", + " print(f\"Prompt: {description.name}\")\n", + " print(f\"Versions available: {len(description.available_versions)}\")\n", + " for version in description.available_versions:\n", " print(f\"- Version {version.version} created at {version.created_at}\")\n", " return description\n", " else:\n", @@ -133,7 +133,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Example 4: Asynchronously Get a Prompt with Variable Substitution\n", + "## Example 4: Get a Prompt with Variable Substitution\n", "\n", "This example demonstrates how to retrieve a prompt with variables substituted." 
] @@ -147,7 +147,7 @@ "async def get_prompt_with_variables(sample_prompt):\n", " print(\"\\nGetting a prompt with variable substitution asynchronously...\")\n", " if sample_prompt:\n", - " err, prompt_response, generation = await basalt.prompt.async_get(\n", + " err, prompt_response, generation = await basalt.prompt.get(\n", " sample_prompt.slug,\n", " variables={\"name\": \"John\", \"company\": \"Acme Inc\"}\n", " )\n", @@ -180,19 +180,21 @@ "metadata": {}, "outputs": [], "source": [ + "import asyncio\n", + "\n", "async def execute_concurrent_operations(prompts):\n", " print(\"\\nExecuting multiple prompt operations concurrently...\")\n", " if len(prompts) >= 2:\n", " # Create multiple async tasks\n", " tasks = [\n", - " basalt.prompt.async_get(prompts[0].slug),\n", - " basalt.prompt.async_get(prompts[1].slug),\n", - " basalt.prompt.async_list()\n", + " basalt.prompt.get(prompts[0].slug),\n", + " basalt.prompt.get(prompts[1].slug),\n", + " basalt.prompt.list()\n", " ]\n", - " \n", + "\n", " # Execute all tasks concurrently\n", " results = await asyncio.gather(*tasks)\n", - " \n", + "\n", " print(f\"Completed {len(tasks)} operations concurrently\")\n", " print(f\"First prompt: {results[0][1].slug if results[0][1] else 'Error'}\")\n", " print(f\"Second prompt: {results[1][1].slug if results[1][1] else 'Error'}\")\n", diff --git a/requirements.txt b/requirements.txt index e69de29..e1e4ff4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,3 @@ +requests +aiohttp +jinja2 diff --git a/setup.py b/setup.py index eb0a52a..496765a 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ def get_version(): install_requires=[ "requests>=2.32", "aiohttp>=3.8.0", + "jinja2>=3.1.0", ], python_requires=">=3.6" ) diff --git a/tests/test_datasetsdk_async.py b/tests/test_datasetsdk_async.py index bd1fc2d..05b9473 100644 --- a/tests/test_datasetsdk_async.py +++ b/tests/test_datasetsdk_async.py @@ -70,109 +70,109 @@ def setUp(self): ) # Reset mock calls before each test 
mocked_api.async_invoke.reset_mock() - + async def test_async_list_datasets(self): """Test asynchronously listing all datasets""" # Configure mock mocked_api.async_invoke.return_value = (None, dataset_list_response) - + # Call the method err, datasets = await self.dataset_sdk.async_list() - + # Assertions self.assertIsNone(err) self.assertEqual(len(datasets), 2) self.assertEqual(datasets[0].slug, "test-dataset") self.assertEqual(datasets[0].name, "Test Dataset") self.assertEqual(datasets[1].slug, "another-dataset") - + # Verify correct endpoint was used endpoint = mocked_api.async_invoke.call_args[0][0] self.assertEqual(endpoint, ListDatasetsEndpoint) - + async def test_async_get_dataset(self): """Test asynchronously getting a dataset by slug""" # Configure mock mocked_api.async_invoke.return_value = (None, dataset_get_response) - + # Call the method err, dataset = await self.dataset_sdk.async_get("test-dataset") - + # Assertions self.assertIsNone(err) self.assertEqual(dataset.slug, "test-dataset") self.assertEqual(dataset.name, "Test Dataset") self.assertEqual(len(dataset.columns), 2) self.assertEqual(len(dataset.rows), 1) - + # Verify correct endpoint was used endpoint = mocked_api.async_invoke.call_args[0][0] self.assertEqual(endpoint, GetDatasetEndpoint) - + # Verify DTO was created correctly dto = mocked_api.async_invoke.call_args[0][1] self.assertEqual(dto.slug, "test-dataset") - + async def test_async_create_dataset_item(self): """Test asynchronously creating a dataset item""" # Configure mock mocked_api.async_invoke.return_value = (None, dataset_add_row_response) - + # Call the method values = {"input": "New input", "output": "New output"} - err, row, warning = await self.dataset_sdk.async_addRow( + err, row, warning = await self.dataset_sdk.async_add_row( slug="test-dataset", values=values, name="New Row", ideal_output="New ideal output", metadata={"source": "test"} ) - + # Assertions self.assertIsNone(err) self.assertIsNone(warning) 
self.assertEqual(row.values, values) self.assertEqual(row.name, "New Row") self.assertEqual(row.idealOutput, "New ideal output") - + # Verify correct endpoint was used endpoint = mocked_api.async_invoke.call_args[0][0] self.assertEqual(endpoint, CreateDatasetItemEndpoint) - + # Verify DTO was created correctly dto = mocked_api.async_invoke.call_args[0][1] self.assertEqual(dto.slug, "test-dataset") self.assertEqual(dto.values, values) self.assertEqual(dto.name, "New Row") self.assertEqual(dto.idealOutput, "New ideal output") - + async def test_async_error_handling_get_dataset(self): """Test error handling when asynchronously getting a dataset""" # Configure mock to return an error error = Exception("API Error") mocked_api.async_invoke.return_value = (error, None) - + # Call the method err, dataset = await self.dataset_sdk.async_get("non-existent") - + # Assertions self.assertIsNotNone(err) self.assertIsNone(dataset) self.assertEqual(str(err), "API Error") - + class AsyncTestRunner: """Helper class to run async tests properly""" - + def __init__(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - + def run_test_case(self, test_case_class): """Run all async test methods in a test case""" suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class) - + for test in suite: test_method = getattr(test, test._testMethodName) if asyncio.iscoroutinefunction(test_method): @@ -190,7 +190,7 @@ def run_test_case(self, test_case_class): print(f"✓ {test._testMethodName}") except Exception as e: print(f"✗ {test._testMethodName}: {e}") - + def close(self): self.loop.close() From 07da568c86e74ef9d58f468ab303e541060a12ce Mon Sep 17 00:00:00 2001 From: Guillaume Date: Sun, 5 Oct 2025 22:43:27 +0200 Subject: [PATCH 3/7] Fix wrong types --- basalt/endpoints/monitor/send_trace.py | 2 +- basalt/objects/base_log.py | 18 +++--- basalt/objects/generation.py | 39 +++++++------ basalt/objects/log.py | 19 ++++--- basalt/objects/trace.py | 33 +++++------ 
basalt/ressources/monitor/base_log_types.py | 49 +++++++++-------- basalt/ressources/monitor/evaluator_types.py | 2 +- basalt/ressources/monitor/generation_types.py | 55 ++++++++++++------- basalt/ressources/monitor/log_types.py | 35 +++++++++--- basalt/ressources/monitor/trace_types.py | 30 +++++----- basalt/ressources/prompts/prompt_types.py | 2 +- basalt/sdk/monitorsdk.py | 14 ++--- basalt/sdk/promptsdk.py | 22 ++++---- 13 files changed, 181 insertions(+), 139 deletions(-) diff --git a/basalt/endpoints/monitor/send_trace.py b/basalt/endpoints/monitor/send_trace.py index ab0999b..fe58696 100644 --- a/basalt/endpoints/monitor/send_trace.py +++ b/basalt/endpoints/monitor/send_trace.py @@ -78,7 +78,7 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: processed_log["startTime"] = processed_log["start_time"].isoformat() if isinstance(processed_log["start_time"], datetime) else processed_log["start_time"] del processed_log["start_time"] if "end_time" in processed_log: - processed_log["endTime"] = processed_log["end_time"].isoformat() if isinstance(processed_log["end_time"], datetime) and processed_log["end_time"] else None + processed_log["endTime"] = processed_log["end_time"].isoformat() if isinstance(processed_log["end_time"], datetime) else processed_log["end_time"] del processed_log["end_time"] # Extract parent ID diff --git a/basalt/objects/base_log.py b/basalt/objects/base_log.py index c5656e6..19be5e6 100644 --- a/basalt/objects/base_log.py +++ b/basalt/objects/base_log.py @@ -14,15 +14,15 @@ class BaseLog: """ def __init__(self, params: BaseLogParams): self._id = f"log-{uuid.uuid4().hex[:8]}" - self._type = params.type - self._name = params.name - self._start_time = params.start_time if params.start_time is not None else datetime.now() - self._end_time = params.end_time - self._metadata = params.metadata - self._trace = params.trace - self._parent = params.parent - self._evaluators = params.evaluators - self._ideal_output = 
params.ideal_output + self._type = params.get("type") + self._name = params.get("name") + self._start_time = params.get("start_time") if params.get("start_time") is not None else datetime.now() + self._end_time = params.get("end_time") + self._metadata = params.get("metadata") + self._trace = params.get("trace") + self._parent = params.get("parent") + self._evaluators = params.get("evaluators") + self._ideal_output = params.get("ideal_output") # Add to trace's logs list if trace exists if self._trace: diff --git a/basalt/objects/generation.py b/basalt/objects/generation.py index 3bd8ce2..b6e5ccb 100644 --- a/basalt/objects/generation.py +++ b/basalt/objects/generation.py @@ -11,29 +11,29 @@ class Generation(BaseLog): Class representing a generation in the monitoring system. """ def __init__(self, params: GenerationParams): - base_log_params = BaseLogParams( - name=params.name, - ideal_output=params.ideal_output, - start_time=params.start_time, - end_time=params.end_time, - metadata=params.metadata, - parent=params.parent, - trace=params.trace, - evaluators=params.evaluators, - type=LogType.GENERATION, - ) + base_log_params = { + "name": params.get("name"), + "ideal_output": params.get("ideal_output"), + "start_time": params.get("start_time"), + "end_time": params.get("end_time"), + "metadata": params.get("metadata"), + "parent": params.get("parent"), + "trace": params.get("trace"), + "evaluators": params.get("evaluators"), + "type": LogType.GENERATION, + } super().__init__(base_log_params) - self._prompt = params.prompt - self._input = params.input - self._output = params.output - self._input_tokens = params.input_tokens - self._output_tokens = params.output_tokens - self._cost = params.cost + self._prompt = params.get("prompt") + self._input = params.get("input") + self._output = params.get("output") + self._input_tokens = params.get("input_tokens") + self._output_tokens = params.get("output_tokens") + self._cost = params.get("cost") # Convert variables to array 
format if needed - variables = params.variables + variables = params.get("variables") if variables is not None: if isinstance(variables, dict): self._variables = [{"label": str(k), "value": str(v)} for k, v in variables.items()] @@ -44,7 +44,7 @@ def __init__(self, params: GenerationParams): else: self._variables = [] - self._options = params.options + self._options = params.get("options") @property def prompt(self) -> Optional[Dict[str, Any]]: @@ -119,7 +119,6 @@ def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> 'Generatio Generation: The generation instance. """ super().end() - self._end_time = datetime.now() if isinstance(output, dict): self.update(output) diff --git a/basalt/objects/log.py b/basalt/objects/log.py index e414be5..3033a28 100644 --- a/basalt/objects/log.py +++ b/basalt/objects/log.py @@ -10,7 +10,7 @@ class Log(BaseLog): """ def __init__(self, params: LogParams): super().__init__(params) - self._input = params.input + self._input = params.get("input") self._output = None @property @@ -67,7 +67,8 @@ def append(self, generation: 'Generation') -> 'Log': Log: The log instance. """ # Remove child log from the list of its previous trace - generation.trace.logs = [log for log in generation.trace.logs if log.id != generation.id] + if generation.trace: + generation.trace.logs = [log for log in generation.trace.logs if log.id != generation.id] # Add child to the new trace list self.trace.logs.append(cast(BaseLog, generation)) @@ -101,12 +102,12 @@ def update(self, params: Dict[str, Any]) -> 'Log': return self - def create_generation(self, params: Dict[str, Any]) -> 'Generation': + def create_generation(self, params: GenerationParams) -> 'Generation': """ Create a new generation as a child of this log. Args: - params (Dict[str, Any]): Parameters for the generation. + params (GenerationParams): Parameters for the generation. Returns: Generation: The new generation instance. 
@@ -116,20 +117,22 @@ def create_generation(self, params: Dict[str, Any]) -> 'Generation': if not name and params.get("prompt") and params["prompt"].get("slug"): name = params["prompt"]["slug"] - generation = Generation(GenerationParams(**params, name=name, trace=self.trace, parent=self)) + generation_params = GenerationParams(**params, name=name, trace=self.trace, parent=self) + generation = Generation(generation_params) return generation - def create_log(self, params: Dict[str, Any]) -> 'Log': + def create_log(self, params: LogParams) -> 'Log': """ Create a new log as a child of this log. Args: - params (Dict[str, Any]): Parameters for the log. + params (LogParams): Parameters for the log. Returns: Log: The new log instance. """ - log = Log(LogParams(**params, trace=self.trace, parent=self)) + log_params = LogParams(**params, trace=self.trace, parent=self) + log = Log(log_params) return log diff --git a/basalt/objects/trace.py b/basalt/objects/trace.py index 62c80fc..85ad6a3 100644 --- a/basalt/objects/trace.py +++ b/basalt/objects/trace.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Dict, Optional, Any, List -from ..ressources.monitor.trace_types import TraceParams +from ..ressources.monitor.trace_types import TraceParams, User, Organization from .base_log import BaseLog from .generation import Generation from .log import Log @@ -143,18 +143,19 @@ def set_ideal_output(self, ideal_output: str) -> 'Trace': self._ideal_output = ideal_output return self - def identify(self, params: Dict[str, Any]) -> 'Trace': + def identify(self, user: User = {}, organization: Organization = {}) -> 'Trace': """ Set identification information for the trace. Args: - params (Dict[str, Any]): Identification parameters. + user: The user information to associate with this trace. + organization: The organization information to associate with this trace. Returns: Trace: The trace instance. 
""" - self._user = params.get("user") - self._organization = params.get("organization") + self._user = user + self._organization = organization return self def set_metadata(self, metadata: Dict[str, Any]) -> 'Trace': @@ -256,40 +257,40 @@ def append(self, generation: 'Generation') -> 'Trace': # Add child to the new trace list self._logs.append(generation) + + # Set the trace of the generation to the current log generation.trace = self + generation.options = {"type": "multi"} return self - def create_generation(self, params: Dict[str, Any]) -> 'Generation': + def create_generation(self, params: GenerationParams) -> 'Generation': """ Create a new generation in this trace. Args: - params (Dict[str, Any]): Parameters for the generation. + params (GenerationParams): Parameters for the generation. Returns: Generation: The new generation instance. """ - # Set the name to the prompt slug if available - name = params.get("name") - if params.get("prompt") and params["prompt"].get("slug"): - name = params["prompt"]["slug"] - - generation = Generation(GenerationParams(**params, name=name, trace=self)) + generation_params = GenerationParams(**params, trace=self) + generation = Generation(generation_params) return generation - def create_log(self, params: Dict[str, Any]) -> 'BaseLog': + def create_log(self, params: LogParams) -> 'BaseLog': """ Create a new log in this trace. Args: - params (Dict[str, Any]): Parameters for the log. + params (LogParams): Parameters for the log. Returns: Log: The new log instance. 
""" - log = Log(LogParams(**params, trace=self)) + log_params = LogParams(**params, trace=self) + log = Log(log_params) return log diff --git a/basalt/ressources/monitor/base_log_types.py b/basalt/ressources/monitor/base_log_types.py index 7d63b06..eb28b78 100644 --- a/basalt/ressources/monitor/base_log_types.py +++ b/basalt/ressources/monitor/base_log_types.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING +from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING, TypedDict, Literal, TypeVar from dataclasses import dataclass, field from uuid import uuid4 from enum import Enum @@ -10,8 +10,9 @@ from .evaluator_types import Evaluator from .trace_types import Trace +SelfType = TypeVar('SelfType', bound='BaseLog') -class LogType(Enum): +class LogType(str, Enum): """Enum-like class for log types. Attributes: @@ -29,12 +30,18 @@ class LogType(Enum): RETRIEVAL = 'retrieval' EVENT = 'event' -@dataclass -class BaseLogParams: +# Type alias for log type strings - use this in TypedDict parameters +LogTypeStr = Literal['span', 'generation', 'function', 'tool', 'retrieval', 'event'] + +class _BaseLogParamsRequired(TypedDict): + """Required fields for BaseLogParams.""" + name: str + +class BaseLogParams(_BaseLogParamsRequired, TypedDict, total=False): """Base parameters for creating a log entry. Attributes: - name: Name of the log entry, describing what it represents. + name: Name of the log entry, describing what it represents (required). start_time: When the log entry started, can be a datetime object or ISO string. If not provided, defaults to the current time when created. end_time: When the log entry ended, can be a datetime object or ISO string. @@ -47,15 +54,13 @@ class BaseLogParams: Every log must be associated with a trace. evaluators: The evaluators to attach to the log. 
""" - name: str - ideal_output: Optional[str] = None - start_time: Optional[Union[datetime, str]] = None - end_time: Optional[Union[datetime, str]] = None - metadata: Optional[Dict[str, Any]] = None - parent: Optional['Log'] = None - trace: 'Trace' = None - evaluators: Optional[List[Evaluator]] = None - type: Optional['LogType'] = None + ideal_output: Optional[str] + start_time: Optional[Union[datetime, str]] + end_time: Optional[Union[datetime, str]] + metadata: Optional[Dict[str, Any]] + parent: Optional['Log'] + trace: Optional['Trace'] + evaluators: Optional[List[Evaluator]] @dataclass class BaseLog: @@ -79,17 +84,17 @@ class BaseLog: Every log must be associated with a trace. evaluators: List of evaluators attached to the log. """ + name: str + type: LogType id: str = field(default_factory=lambda: str(f'log-{uuid4().hex[:8]}')) - type: LogType = None - name: str = None start_time: Optional[Union[datetime, str]] = None end_time: Optional[Union[datetime, str]] = None metadata: Optional[Dict[str, Any]] = None parent: Optional['Log'] = None - trace: 'Trace' = None + trace: Optional['Trace'] = None evaluators: List[Evaluator] = field(default_factory=list) - def start(self) -> 'BaseLog': + def start(self: SelfType) -> SelfType: """Marks the log as started and sets the start time if not already set. Returns: @@ -97,7 +102,7 @@ def start(self) -> 'BaseLog': """ ... - def set_metadata(self, metadata: Optional[Dict[str, Any]] = None) -> 'BaseLog': + def set_metadata(self: SelfType, metadata: Optional[Dict[str, Any]] = None) -> SelfType: """Sets the metadata for the log. Args: @@ -108,7 +113,7 @@ def set_metadata(self, metadata: Optional[Dict[str, Any]] = None) -> 'BaseLog': """ ... - def add_evaluator(self, evaluator: Evaluator) -> 'BaseLog': + def add_evaluator(self: SelfType, evaluator: Evaluator) -> SelfType: """Adds an evaluator to the log. Args: @@ -119,7 +124,7 @@ def add_evaluator(self, evaluator: Evaluator) -> 'BaseLog': """ ... 
- def update(self, params: Dict[str, Any]) -> 'BaseLog': + def update(self: SelfType, params: Dict[str, Any]) -> SelfType: """Updates the log with new parameters. Args: @@ -130,7 +135,7 @@ def update(self, params: Dict[str, Any]) -> 'BaseLog': """ ... - def end(self) -> 'BaseLog': + def end(self: SelfType) -> SelfType: """Marks the log as ended. Returns: diff --git a/basalt/ressources/monitor/evaluator_types.py b/basalt/ressources/monitor/evaluator_types.py index ed93912..accb53e 100644 --- a/basalt/ressources/monitor/evaluator_types.py +++ b/basalt/ressources/monitor/evaluator_types.py @@ -13,4 +13,4 @@ class EvaluationConfig(TypedDict, total=False): """ Configuration for the evaluation of the trace and its logs. """ - sample_rate: float + sample_rate: Optional[float] diff --git a/basalt/ressources/monitor/generation_types.py b/basalt/ressources/monitor/generation_types.py index ae3f175..caf6b5b 100644 --- a/basalt/ressources/monitor/generation_types.py +++ b/basalt/ressources/monitor/generation_types.py @@ -1,17 +1,21 @@ -from typing import Dict, Optional, Union, Any +from typing import Dict, Optional, Union, Any, TypedDict from dataclasses import dataclass, field from .base_log_types import BaseLog, BaseLogParams, LogType -@dataclass -class PromptReference: +class _PromptReferenceRequired(TypedDict): + """Required fields for PromptReference.""" + slug: str + +class PromptReference(_PromptReferenceRequired, TypedDict, total=False): """Reference to a prompt template. This class represents a reference to a prompt template used in AI model generations. Attributes: - slug (str): Unique identifier for the prompt template. - version (str): Version of the prompt template. + slug (str): Unique identifier for the prompt template (required). + version (str): Version of the prompt template (optional). + tag (str): Tag for the prompt template (optional). 
Example: ```python @@ -19,12 +23,10 @@ class PromptReference: prompt = PromptReference(slug="qa-prompt", version="2.1.0") ``` """ - slug: str - version: Optional[str] = None - tag: Optional[str] = None + version: Optional[str] + tag: Optional[str] -@dataclass -class GenerationParams(BaseLogParams): +class GenerationParams(BaseLogParams, total=False): """Parameters for creating a new generation. This class defines the parameters that can be used to create a new generation, @@ -55,14 +57,18 @@ class GenerationParams(BaseLogParams): ) ``` """ - prompt: Optional[PromptReference] = None - input: Optional[str] = None - output: Optional[str] = None - variables: Optional[Dict[str, Any]] = None - options: Optional[Dict[str, Any]] = None - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - cost: Optional[float] = None + prompt: Optional[PromptReference] + input: Optional[str] + output: Optional[str] + variables: Optional[Dict[str, Any]] + options: Optional[Dict[str, Any]] + input_tokens: Optional[int] + output_tokens: Optional[int] + cost: Optional[float] + +class UpdateGenerationParams(GenerationParams, total=False): + """Parameters for updating a generation.""" + name: Optional[str] @dataclass class Generation(BaseLog): @@ -107,7 +113,7 @@ class Generation(BaseLog): input: Optional[str] = None output: Optional[str] = None variables: Optional[Dict[str, Any]] = None - type: str = field(default=LogType.GENERATION) + type: LogType = field(default=LogType.GENERATION) input_tokens: Optional[int] = None output_tokens: Optional[int] = None cost: Optional[float] = None @@ -160,3 +166,14 @@ def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> 'Generatio ``` """ ... + + def update(self, params: 'UpdateGenerationParams') -> 'Generation': + """Updates the log with new parameters. + + Args: + **params: The parameters to update. + + Returns: + The log instance for method chaining. + """ + ... 
diff --git a/basalt/ressources/monitor/log_types.py b/basalt/ressources/monitor/log_types.py index 7b51dac..9e0c4bc 100644 --- a/basalt/ressources/monitor/log_types.py +++ b/basalt/ressources/monitor/log_types.py @@ -1,27 +1,35 @@ -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, TypedDict from dataclasses import dataclass -from .base_log_types import BaseLog, BaseLogParams, LogType +from .base_log_types import BaseLog, BaseLogParams, LogType, LogTypeStr if TYPE_CHECKING: from .generation_types import Generation, GenerationParams -@dataclass -class LogParams(BaseLogParams): +class _LogParamsRequired(TypedDict): + """Required fields for BaseLogParams.""" + type: LogTypeStr + +class LogParams(BaseLogParams, _LogParamsRequired, total=False): """Parameters for creating or updating a log. This class defines the parameters needed to create or update a log entry, including its type, input, and output data. Attributes: - type: The type of log entry (e.g., 'span', 'generation'). + type: The type of log entry (e.g., 'span', 'generation') (required). Used to distinguish between different kinds of logs. input: Optional input data for this operation. output: Optional output data generated by the operation. """ - type: LogType = None - input: Optional[str] = None - output: Optional[str] = None + input: Optional[str] + output: Optional[str] + +class UpdateLogParams(LogParams, total=False): + """Parameters for updating a log.""" + name: Optional[str] + type: Optional[LogTypeStr] + @dataclass class Log(BaseLog): @@ -186,3 +194,14 @@ def create_log(self, params: LogParams) -> 'Log': ``` """ ... + + def update(self, params: 'UpdateLogParams') -> 'Log': + """Updates the log with new parameters. + + Args: + **params: The parameters to update. + + Returns: + The log instance for method chaining. + """ + ... 
diff --git a/basalt/ressources/monitor/trace_types.py b/basalt/ressources/monitor/trace_types.py index 85d9ad4..c397079 100644 --- a/basalt/ressources/monitor/trace_types.py +++ b/basalt/ressources/monitor/trace_types.py @@ -4,9 +4,10 @@ from .experiment_types import Experiment from .evaluator_types import Evaluator, EvaluationConfig + if TYPE_CHECKING: - from .log_types import Log, LogParams from .generation_types import Generation, GenerationParams + from .log_types import Log, LogParams from .base_log_types import BaseLog @dataclass @@ -21,21 +22,20 @@ class Organization(TypedDict): id: str name: str -@dataclass class TraceParams(TypedDict, total=False): """Parameters for creating or updating a trace.""" - name: str - input: str - output: str - ideal_output: str - start_time: datetime - end_time: datetime - user: User - organization: Organization - metadata: Dict[str, Any] - experiment: Experiment - evaluators: List[Evaluator] - evaluation_config: EvaluationConfig + name: Optional[str] + input: Optional[str] + output: Optional[str] + ideal_output: Optional[str] + start_time: Optional[datetime] + end_time: Optional[datetime] + user: Optional[User] + organization: Optional['Organization'] + metadata: Optional[Dict[str, Any]] + experiment: Optional[Experiment] + evaluators: Optional[List[Evaluator]] + evaluation_config: Optional[EvaluationConfig] @dataclass @@ -211,7 +211,7 @@ def append(self, log: 'BaseLog') -> 'Trace': """ ... - def identify(self, user: Optional[User] = None, organization: Optional[Organization] = None) -> 'Trace': + def identify(self, user: User = {}, organization: Organization = {}) -> 'Trace': """Associates user information with this trace. 
Args: diff --git a/basalt/ressources/prompts/prompt_types.py b/basalt/ressources/prompts/prompt_types.py index 2a39a7d..509a1d4 100644 --- a/basalt/ressources/prompts/prompt_types.py +++ b/basalt/ressources/prompts/prompt_types.py @@ -67,7 +67,7 @@ class Prompt: version: str system_text: Optional[str] = None raw_system_text: Optional[str] = None - variables: Optional[Dict[str, str]] = None + variables: Optional[Dict[str, Any]] = None tag: Optional[str] = None def compile_variables(self, variables: Dict[str, Any]) -> 'Prompt': diff --git a/basalt/sdk/monitorsdk.py b/basalt/sdk/monitorsdk.py index 49c8c44..32af9e2 100644 --- a/basalt/sdk/monitorsdk.py +++ b/basalt/sdk/monitorsdk.py @@ -78,35 +78,33 @@ def create_trace( def create_generation( self, - params: Dict[str, Any] + params: GenerationParams ) -> Generation: """ Creates a new generation for monitoring. Args: - params (Dict[str, Any]): Parameters for the generation. + params (GenerationParams): Parameters for the generation. Returns: Generation: A new Generation instance. """ - generation_params = GenerationParams(**params) - return self._create_generation(generation_params) + return self._create_generation(params) def create_log( self, - params: Dict[str, Any] + params: LogParams ) -> Log: """ Creates a new log for monitoring. Args: - params (Dict[str, Any]): Parameters for the log. + params (LogParams): Parameters for the log. Returns: Log: A new Log instance. 
""" - log_params = LogParams(**params) - return self._create_log(log_params) + return self._create_log(params) async def _create_experiment( self, diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index 7cd5a88..4819b01 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -177,17 +177,17 @@ def _prepare_monitoring(self,prompt: IPrompt, ), flusher, self._logger) # Create a generation - generation = Generation(GenerationParams( - name=prompt.slug, - trace=trace, - prompt=PromptReference( - slug=prompt.slug, - version=prompt.version, - tag=prompt.tag - ), - input=prompt.text, - variables=prompt.variables - )) + generation = Generation({ + "name": prompt.slug, + "trace": trace, + "prompt": { + "slug": prompt.slug, + "version": prompt.version, + "tag": prompt.tag + }, + "input": prompt.text, + "variables": prompt.variables + }) return generation From c789bd3fa2fb813a2e40c77ede017528aa04b401 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Mon, 6 Oct 2025 19:56:45 +0200 Subject: [PATCH 4/7] Review unit tests --- basalt/objects/log.py | 4 ++- basalt/sdk/promptsdk.py | 5 ++- basalt/utils/utils.py | 3 +- tests/test_api.py | 42 +++++++++++------------ tests/test_datasetsdk.py | 26 +++++++------- tests/test_get_prompt_endpoint.py | 2 ++ tests/test_monitor_sdk.py | 39 ++++++++++++--------- tests/test_networker.py | 24 ++++++++----- tests/test_promptsdk.py | 56 +++++++++++++++++++------------ tests/test_promptsdk_async.py | 2 ++ tests/test_utils.py | 21 +----------- 11 files changed, 119 insertions(+), 105 deletions(-) diff --git a/basalt/objects/log.py b/basalt/objects/log.py index 3033a28..9e6a512 100644 --- a/basalt/objects/log.py +++ b/basalt/objects/log.py @@ -117,7 +117,9 @@ def create_generation(self, params: GenerationParams) -> 'Generation': if not name and params.get("prompt") and params["prompt"].get("slug"): name = params["prompt"]["slug"] - generation_params = GenerationParams(**params, name=name, trace=self.trace, parent=self) + # 
Create a new params dict to avoid modifying the original + generation_params_dict = {**params, "name": name, "trace": self.trace, "parent": self} + generation_params = GenerationParams(**generation_params_dict) generation = Generation(generation_params) return generation diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index 4819b01..21c4e64 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -186,7 +186,10 @@ def _prepare_monitoring(self,prompt: IPrompt, "tag": prompt.tag }, "input": prompt.text, - "variables": prompt.variables + "variables": prompt.variables, + "options": { + "type": "single" + } }) return generation diff --git a/basalt/utils/utils.py b/basalt/utils/utils.py index 74fbcc0..826d13e 100644 --- a/basalt/utils/utils.py +++ b/basalt/utils/utils.py @@ -1,5 +1,4 @@ -from re import sub as reg_replace -from typing import Tuple, Set, Dict, Any +from typing import Dict, Any def pick_typed(dict: Dict[str, Any], field_name: str, expected_type: Any) -> Any: value = dict.get(field_name) diff --git a/tests/test_api.py b/tests/test_api.py index 12b16f7..28eaf05 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,7 +8,7 @@ class TestApi(unittest.TestCase): def test_uses_endpoint_to_encode_request(self): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, {}) + mocked_network.fetch_sync.return_value = (None, {}) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value = { @@ -26,13 +26,13 @@ def test_uses_endpoint_to_encode_request(self): sdk_type="py-test" ) - api.invoke(mocked_endpoint, { "some": "dto" }) + api.invoke_sync(mocked_endpoint, { "some": "dto" }) mocked_endpoint.prepare_request.assert_called_once_with({ "some": "dto" }) def test_uses_endpoint_to_decode_response(self): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() 
mocked_endpoint.prepare_request.return_value = { @@ -50,7 +50,7 @@ def test_uses_endpoint_to_decode_response(self): sdk_type="py-test" ) - err, res = api.invoke(mocked_endpoint, { "some": "dto" }) + err, res = api.invoke_sync(mocked_endpoint, { "some": "dto" }) mocked_endpoint.decode_response.assert_called_once_with({ "some": "response" }) @@ -59,7 +59,7 @@ def test_uses_endpoint_to_decode_response(self): def test_forwards_decoder_error(self): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value = { @@ -77,7 +77,7 @@ def test_forwards_decoder_error(self): sdk_type="py-test" ) - err, res = api.invoke(mocked_endpoint, { "some": "dto" }) + err, res = api.invoke_sync(mocked_endpoint, { "some": "dto" }) mocked_endpoint.decode_response.assert_called_once_with({ "some": "response" }) @@ -88,7 +88,7 @@ def test_forwards_decoder_error(self): @parameterized.expand(["GET", "POST", "PUT", "DELETE"]) def test_uses_http_verb(self, http_verb): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value = { @@ -106,15 +106,15 @@ def test_uses_http_verb(self, http_verb): sdk_type="py-test" ) - api.invoke(mocked_endpoint, { "some": "dto" }) + api.invoke_sync(mocked_endpoint, { "some": "dto" }) - call_args = mocked_network.fetch.call_args[0] + call_args = mocked_network.fetch_sync.call_args[0] self.assertEqual(call_args[1], http_verb) def test_prefixes_api_root(self): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value 
= { @@ -132,15 +132,15 @@ def test_prefixes_api_root(self): sdk_type="py-test" ) - api.invoke(mocked_endpoint, { "some": "dto" }) + api.invoke_sync(mocked_endpoint, { "some": "dto" }) - call_args = mocked_network.fetch.call_args[0] + call_args = mocked_network.fetch_sync.call_args[0] self.assertTrue(call_args[0].startswith("https://basalt-test/")) def test_includes_path_in_url(self): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value = { @@ -158,9 +158,9 @@ def test_includes_path_in_url(self): sdk_type="py-test" ) - api.invoke(mocked_endpoint, { "some": "dto" }) + api.invoke_sync(mocked_endpoint, { "some": "dto" }) - call_args = mocked_network.fetch.call_args[0] + call_args = mocked_network.fetch_sync.call_args[0] self.assertIn("/test-path", call_args[0]) @@ -170,7 +170,7 @@ def test_includes_path_in_url(self): ]) def test_includes_path_in_url(self, params): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value = { @@ -188,15 +188,15 @@ def test_includes_path_in_url(self, params): sdk_type="py-test" ) - api.invoke(mocked_endpoint, { "some": "dto" }) + api.invoke_sync(mocked_endpoint, { "some": "dto" }) - call_args = mocked_network.fetch.call_args + call_args = mocked_network.fetch_sync.call_args self.assertEqual(call_args.kwargs["params"], params) def test_passes_headers_to_network(self): mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) + mocked_network.fetch_sync.return_value = (None, { "some": "response" }) mocked_endpoint = MagicMock() mocked_endpoint.prepare_request.return_value = { @@ -214,9 +214,9 @@ def 
test_passes_headers_to_network(self): sdk_type="test-sdk-type" ) - api.invoke(mocked_endpoint, { "some": "dto" }) + api.invoke_sync(mocked_endpoint, { "some": "dto" }) - headers = mocked_network.fetch.call_args.kwargs["headers"] + headers = mocked_network.fetch_sync.call_args.kwargs["headers"] self.assertIn("Authorization", headers) self.assertIn("my-api-key", headers["Authorization"]) diff --git a/tests/test_datasetsdk.py b/tests/test_datasetsdk.py index 9d6ea3a..fdd0d16 100644 --- a/tests/test_datasetsdk.py +++ b/tests/test_datasetsdk.py @@ -70,10 +70,10 @@ def setUp(self): def test_list_datasets(self): """Test listing all datasets""" # Configure mock - mocked_api.invoke.return_value = (None, dataset_list_response) + mocked_api.invoke_sync.return_value = (None, dataset_list_response) # Call the method - err, datasets = self.dataset_sdk.list() + err, datasets = self.dataset_sdk.list_sync() # Assertions self.assertIsNone(err) @@ -83,16 +83,16 @@ def test_list_datasets(self): self.assertEqual(datasets[1].slug, "another-dataset") # Verify correct endpoint was used - endpoint = mocked_api.invoke.call_args[0][0] + endpoint = mocked_api.invoke_sync.call_args[0][0] self.assertEqual(endpoint, ListDatasetsEndpoint) def test_get_dataset(self): """Test getting a dataset by slug""" # Configure mock - mocked_api.invoke.return_value = (None, dataset_get_response) + mocked_api.invoke_sync.return_value = (None, dataset_get_response) # Call the method - err, dataset = self.dataset_sdk.get("test-dataset") + err, dataset = self.dataset_sdk.get_sync("test-dataset") # Assertions self.assertIsNone(err) @@ -102,21 +102,21 @@ def test_get_dataset(self): self.assertEqual(len(dataset.rows), 1) # Verify correct endpoint was used - endpoint = mocked_api.invoke.call_args[0][0] + endpoint = mocked_api.invoke_sync.call_args[0][0] self.assertEqual(endpoint, GetDatasetEndpoint) # Verify DTO was created correctly - dto = mocked_api.invoke.call_args[0][1] + dto = 
mocked_api.invoke_sync.call_args[0][1] self.assertEqual(dto.slug, "test-dataset") def test_create_dataset_item(self): """Test creating a dataset item""" # Configure mock - mocked_api.invoke.return_value = (None, dataset_add_row_response) + mocked_api.invoke_sync.return_value = (None, dataset_add_row_response) # Call the method values = {"input": "New input", "output": "New output"} - err, row, warning = self.dataset_sdk.addRow( + err, row, warning = self.dataset_sdk.add_row_sync( slug="test-dataset", values=values, name="New Row", @@ -132,11 +132,11 @@ def test_create_dataset_item(self): self.assertEqual(row.idealOutput, "New ideal output") # Verify correct endpoint was used - endpoint = mocked_api.invoke.call_args[0][0] + endpoint = mocked_api.invoke_sync.call_args[0][0] self.assertEqual(endpoint, CreateDatasetItemEndpoint) # Verify DTO was created correctly - dto = mocked_api.invoke.call_args[0][1] + dto = mocked_api.invoke_sync.call_args[0][1] self.assertEqual(dto.slug, "test-dataset") self.assertEqual(dto.values, values) self.assertEqual(dto.name, "New Row") @@ -145,10 +145,10 @@ def test_create_dataset_item(self): def test_error_handling_get_dataset(self): """Test error handling when getting a dataset""" # Configure mock to return an error - mocked_api.invoke.return_value = (Exception("API Error"), None) + mocked_api.invoke_sync.return_value = (Exception("API Error"), None) # Call the method - err, dataset = self.dataset_sdk.get("non-existent") + err, dataset = self.dataset_sdk.get_sync("non-existent") # Assertions self.assertIsNotNone(err) diff --git a/tests/test_get_prompt_endpoint.py b/tests/test_get_prompt_endpoint.py index e0f1350..9394a55 100644 --- a/tests/test_get_prompt_endpoint.py +++ b/tests/test_get_prompt_endpoint.py @@ -25,6 +25,8 @@ def test_decodes_valid_response(self): "warning": "This is a warning", "prompt": { "text": "Valid prompt text", + "slug": "test-prompt", + "tag": "latest", "systemText": "Some system prompt", "version": "0.1", 
"model": { diff --git a/tests/test_monitor_sdk.py b/tests/test_monitor_sdk.py index 0a8bd31..ec621aa 100644 --- a/tests/test_monitor_sdk.py +++ b/tests/test_monitor_sdk.py @@ -208,7 +208,7 @@ def test_end_trace(self): trace = self.monitor.create_trace("test-slug", {"input": self.content}) # End the trace - trace.end("Trace output") + trace.end_sync("Trace output") # Assert trace was ended correctly self.assertEqual(trace.output, "Trace output") @@ -239,10 +239,10 @@ def test_trace_identify(self): trace = self.monitor.create_trace("test-slug", {"input": self.content}) # Identify the trace - trace.identify({ - "user": self.user, - "organization": {"id": "org-123", "name": "Basalt"} - }) + trace.identify( + user=self.user, + organization={"id": "org-123", "name": "Basalt"} + ) # Assert trace was identified correctly self.assertEqual(trace.user, self.user) @@ -288,8 +288,10 @@ def test_prompt_generation_integration(self): warning=None, prompt=PromptResponse( text="Answer the following question about {{topic}}: {{question}}", + slug="ml-best-practices", + tag="latest", systemText="Some system prompt", - version="0.1", + version="1.0", model=PromptModel( provider="open-ai", model="gpt-4o", @@ -306,7 +308,7 @@ def test_prompt_generation_integration(self): # Create a mock API for PromptSDK prompt_api = MagicMock() - prompt_api.invoke.return_value = (None, mock_prompt_response) + prompt_api.invoke_sync.return_value = (None, mock_prompt_response) # Create a PromptSDK instance from basalt.utils.memcache import MemoryCache @@ -320,7 +322,7 @@ def test_prompt_generation_integration(self): ) # Get prompt from Basalt - err, prompt_response, generation = prompt_sdk.get( + err, prompt_response, generation = prompt_sdk.get_sync( "ml-best-practices", variables={"topic": "machine learning", "question": self.query}, version="1.0" @@ -340,9 +342,10 @@ def test_prompt_generation_integration(self): # Verify generation object properties self.assertEqual(generation.prompt["slug"], 
"ml-best-practices") self.assertEqual(generation.prompt["version"], "1.0") - self.assertEqual(generation.input, "Answer the following question about {{topic}}: {{question}}") + # The input should be the compiled text (with variables replaced) + self.assertEqual(generation.input, "Answer the following question about machine learning: What are the best practices for machine learning model deployment?") self.assertEqual(generation.variables, [ - {"label": "topic", "value": "machine learning"}, + {"label": "topic", "value": "machine learning"}, {"label": "question", "value": self.query} ]) self.assertEqual(generation.options["type"], "single") @@ -368,7 +371,7 @@ def test_prompt_generation_integration(self): model_span.end(model_response) # End the main trace - main_trace.end("Completed prompt generation test") + main_trace.end_sync("Completed prompt generation test") # Verify trace structure # Filter logs to only include those of type "span" @@ -401,6 +404,8 @@ def test_complex_workflow(self): warning=None, prompt=PromptResponse( text="Generate content about: {{query}}", + slug="generate-content", + tag="latest", systemText="Some system prompt", version="0.1", model=PromptModel( @@ -419,7 +424,7 @@ def test_complex_workflow(self): # Create a mock API for PromptSDK prompt_api = MagicMock() - prompt_api.invoke.return_value = (None, mock_generate_prompt_response) + prompt_api.invoke_sync.return_value = (None, mock_generate_prompt_response) # Create a PromptSDK instance from basalt.utils.memcache import MemoryCache @@ -431,7 +436,7 @@ def test_complex_workflow(self): ) # Get prompt from Basalt - err, prompt_response, generation = prompt_sdk.get( + err, prompt_response, generation = prompt_sdk.get_sync( "generate-content", variables={"query": self.query}, version="1.0" @@ -468,6 +473,8 @@ def test_complex_workflow(self): warning=None, prompt=PromptResponse( text="Classify the following content: {{content}}", + slug="classify-content", + tag="latest", systemText="Some 
system prompt", version="0.1", model=PromptModel( @@ -485,10 +492,10 @@ def test_complex_workflow(self): ) # Update the mock API for the classification prompt - prompt_api.invoke.return_value = (None, mock_classify_prompt_response) + prompt_api.invoke_sync.return_value = (None, mock_classify_prompt_response) # Get prompt from Basalt - err, classify_prompt_response, classify_generation = prompt_sdk.get( + err, classify_prompt_response, classify_generation = prompt_sdk.get_sync( "classify-content", variables={"content": generated_text}, version="1.0" @@ -513,7 +520,7 @@ def test_complex_workflow(self): classification_span.end(categories) # End the main trace - main_trace.end("Workflow completed") + main_trace.end_sync("Workflow completed") # Verify trace structure # Filter logs to only include those of type "span" diff --git a/tests/test_networker.py b/tests/test_networker.py index 7d0098b..7d3c3d0 100644 --- a/tests/test_networker.py +++ b/tests/test_networker.py @@ -11,7 +11,7 @@ class TestNetworker(unittest.TestCase): def test_uses_requests_to_make_http_calls(self, request_mock): networker = Networker() - networker.fetch('http://test/abc', 'GET') + networker.fetch_sync('http://test/abc', 'GET') request_mock.assert_called_once_with('GET', 'http://test/abc', params=None, json=None, headers=None) @@ -20,7 +20,7 @@ def test_captures_requests_exceptions(self, request_mock): networker = Networker() request_mock.side_effect = Exception('Some unknown error') - err, res = networker.fetch('http://test/abc', 'GET') + err, res = networker.fetch_sync('http://test/abc', 'GET') self.assertIsNone(res) self.assertEqual(err.message, 'Some unknown error') @@ -32,7 +32,7 @@ def test_rejects_non_json_responses(self, request_mock): request_mock.return_value = Mock() request_mock.return_value.json.side_effect = Exception('No JSON object could be decoded') - err, res = networker.fetch('http://test/abc', 'GET') + err, res = networker.fetch_sync('http://test/abc', 'GET') 
self.assertIsNone(res) self.assertIsInstance(err, FetchError) @@ -40,10 +40,13 @@ def test_rejects_non_json_responses(self, request_mock): @patch('requests.request') def test_returns_valid_json_as_dict(self, request_mock): networker = Networker() - request_mock.return_value = Mock() - request_mock.return_value.json.return_value = { "some": "data" } + mock_response = Mock() + mock_response.json.return_value = { "some": "data" } + mock_response.headers = { 'Content-Type': 'application/json' } + mock_response.status_code = 200 + request_mock.return_value = mock_response - err, res = networker.fetch('http://test/abc', 'GET') + err, res = networker.fetch_sync('http://test/abc', 'GET') self.assertIsNone(err) self.assertEqual(res, { "some": "data" }) @@ -57,10 +60,13 @@ def test_returns_valid_json_as_dict(self, request_mock): @patch('requests.request') def test_uses_custom_errors(self, response_code, error_type, request_mock): networker = Networker() - request_mock.return_value = Mock() - request_mock.return_value.status_code = response_code + mock_response = Mock() + mock_response.status_code = response_code + mock_response.headers = { 'Content-Type': 'application/json' } + mock_response.json.return_value = {} + request_mock.return_value = mock_response - err, _ = networker.fetch('http://test/abc', 'GET') + err, _ = networker.fetch_sync('http://test/abc', 'GET') self.assertIsInstance(err, FetchError) self.assertEqual(type(err).__name__, error_type) diff --git a/tests/test_promptsdk.py b/tests/test_promptsdk.py index 6ce29bb..db47b0e 100644 --- a/tests/test_promptsdk.py +++ b/tests/test_promptsdk.py @@ -10,12 +10,14 @@ logger = Logger() mocked_api = MagicMock() -mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( +mocked_api.invoke_sync.return_value = (None, GetPromptEndpointResponse( warning=None, prompt=PromptResponse( text="Some prompt", + slug="test-slug", + tag="prod", systemText="Some system prompt", - version="0.1", + version="1.0", 
model=PromptModel( provider="open-ai", model="gpt-4o", @@ -45,8 +47,8 @@ def test_uses_correct_endpoint(self): logger=logger ) - prompt.get("slug") - endpoint = mocked_api.invoke.call_args[0][0] + prompt.get_sync("slug") + endpoint = mocked_api.invoke_sync.call_args[0][0] self.assertEqual(endpoint, GetPromptEndpoint) @@ -65,9 +67,9 @@ def test_passes_correct_dto(self, slug, version, tag): logger=logger ) - prompt.get(slug, version=version, tag=tag) + prompt.get_sync(slug, version=version, tag=tag) - dto = mocked_api.invoke.call_args[0][1] + dto = mocked_api.invoke_sync.call_args[0][1] self.assertEqual( dto, @@ -76,7 +78,7 @@ def test_passes_correct_dto(self, slug, version, tag): def test_forwards_api_error(self): mocked_api = MagicMock() - mocked_api.invoke.return_value = (Exception("Some error"), None) + mocked_api.invoke_sync.return_value = (Exception("Some error"), None) prompt = PromptSDK( mocked_api, @@ -85,7 +87,7 @@ def test_forwards_api_error(self): logger=logger ) - err, res, generation = prompt.get("slug") + err, res, generation = prompt.get_sync("slug") self.assertIsInstance(err, Exception) self.assertIsNone(res) @@ -93,10 +95,12 @@ def test_forwards_api_error(self): def test_replaces_variables(self): mocked_api = MagicMock() - mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( + mocked_api.invoke_sync.return_value = (None, GetPromptEndpointResponse( warning=None, prompt=PromptResponse( text="Say hello {{name}}", + slug="slug", + tag="latest", systemText="Some system prompt", version="0.1", model=PromptModel( @@ -120,19 +124,21 @@ def test_replaces_variables(self): logger=logger ) - _, prompt_response, generation = prompt.get("slug", variables={ "name": "Basalt" }) + _, prompt_response, generation = prompt.get_sync("slug", variables={ "name": "Basalt" }) self.assertEqual(prompt_response.text, "Say hello Basalt") self.assertIsInstance(generation, Generation) - self.assertEqual(generation.input, "Say hello {{name}}") + 
self.assertEqual(generation.input, "Say hello Basalt") self.assertEqual(generation.prompt["slug"], "slug") def test_saves_raw_prompt_to_cache(self): mocked_api = MagicMock() - mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( + mocked_api.invoke_sync.return_value = (None, GetPromptEndpointResponse( warning=None, prompt=PromptResponse( text="Say hello {{name}}", + slug="slug", + tag="latest", systemText="Some system prompt", version="0.1", model=PromptModel( @@ -159,7 +165,7 @@ def test_saves_raw_prompt_to_cache(self): logger=logger ) - prompt.get("slug", variables={ "name": "Basalt" }) + prompt.get_sync("slug", variables={ "name": "Basalt" }) mocked_cache.put.assert_called_once() @@ -173,6 +179,8 @@ def test_does_not_request_when_cache_hit(self): mocked_cache = MagicMock() mocked_cache.get.return_value = PromptResponse( text="Say hello {{name}}", + slug="slug", + tag="latest", systemText="Some system prompt", version="0.1", model=PromptModel( @@ -194,21 +202,23 @@ def test_does_not_request_when_cache_hit(self): fallback_cache=fallback_cache, logger=logger ) - err, res, generation = prompt.get("slug", variables={ "name": "Cached" }) + err, res, generation = prompt.get_sync("slug", variables={ "name": "Cached" }) - mocked_api.invoke.assert_not_called() + mocked_api.invoke_sync.assert_not_called() self.assertIsNone(err) self.assertEqual(res.text, "Say hello Cached") self.assertIsInstance(generation, Generation) - self.assertEqual(generation.input, "Say hello {{name}}") + self.assertEqual(generation.input, "Say hello Cached") def test_caches_in_fallback_forever(self): mocked_api = MagicMock() - mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( + mocked_api.invoke_sync.return_value = (None, GetPromptEndpointResponse( warning=None, prompt=PromptResponse( text="Say hello {{name}}", + slug="slug", + tag="latest", systemText="Some system prompt", version="0.1", model=PromptModel( @@ -235,17 +245,19 @@ def 
test_caches_in_fallback_forever(self): logger=logger ) - prompt.get("slug", variables={ "name": "Cached" }) + prompt.get_sync("slug", variables={ "name": "Cached" }) fallback_cache.put.assert_called_once() def test_uses_fallback_cache_on_api_failure(self): mocked_api = MagicMock() - mocked_api.invoke.return_value = (Exception("Some error"), None) - + mocked_api.invoke_sync.return_value = (Exception("Some error"), None) + fallback_cache = MagicMock() fallback_cache.get.return_value = PromptResponse( text="From fallback cache", + slug="slug", + tag="latest", systemText="Some system prompt", version="0.1", model=PromptModel( @@ -268,7 +280,7 @@ def test_uses_fallback_cache_on_api_failure(self): logger=logger ) - _, res, generation = prompt.get("slug", variables={ "name": "Cached" }) + _, res, generation = prompt.get_sync("slug", variables={ "name": "Cached" }) fallback_cache.get.assert_called_once() self.assertEqual(res.text, "From fallback cache") @@ -282,7 +294,7 @@ def test_returns_generation_object(self): logger=logger ) - _, _, generation = prompt.get("test-slug", version="1.0", tag="prod", variables={"key": "value"}) + _, _, generation = prompt.get_sync("test-slug", version="1.0", tag="prod", variables={"key": "value"}) self.assertIsInstance(generation, Generation) self.assertEqual(generation.prompt["slug"], "test-slug") diff --git a/tests/test_promptsdk_async.py b/tests/test_promptsdk_async.py index 4ecb96e..1ca3a44 100644 --- a/tests/test_promptsdk_async.py +++ b/tests/test_promptsdk_async.py @@ -36,6 +36,8 @@ warning=None, prompt=PromptResponse( text="This is a test prompt: {{variable}}", + slug="test-prompt", + tag="latest", systemText="You are a helpful assistant", version="1.0", model=mock_model diff --git a/tests/test_utils.py b/tests/test_utils.py index ba71c46..aa25813 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,28 +1,9 @@ import unittest from parameterized import parameterized -from basalt.utils.utils import replace_variables, 
pick_typed, pick_number +from basalt.utils.utils import pick_typed, pick_number class TestUtils(unittest.TestCase): - @parameterized.expand([ - # Replaces a variable - ('Hello {{name}}', { "name": 'Basalt' }, (set([]), 'Hello Basalt')), - - # Replaces a multiple occurences of a variable - ('Hello {{name}} {{name}}', { "name": 'Basalt' }, (set([]), 'Hello Basalt Basalt')), - - # Can replace multiple variables - ('{{a}} + {{b}} = {{c}}', { "a": 1, "b": 2, "c": 3 }, (set([]), '1 + 2 = 3')), - - # Can replace multiple variables - ('{{a}} + {{b}} = {{c {{c}} }}', { "a": 1, "b": 2, "c": 3 }, (set(["c {{c"]), '1 + 2 = {{c {{c}} }}')), - - # Doesn't replace or empty missing variables - ('Hello {{missing}}', {}, (set(["missing"]), 'Hello {{missing}}')) - ]) - def test_replace_variables(self, str_value, vars, expected): - self.assertEqual(replace_variables(str_value, vars), expected) - @parameterized.expand([ ({ "a": 1 }, int, True), ({ "a": 1 }, float, False), From 69aa431672c8f320a63995b4234c66fce360bd33 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Mon, 6 Oct 2025 20:08:18 +0200 Subject: [PATCH 5/7] Fix for python 3.8 --- .github/workflows/python-tests.yml | 28 ++++++++++++++-------------- basalt/objects/prompt.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 5e3acbe..7ceafc9 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -2,7 +2,7 @@ name: Python SDK Tests on: pull_request: - + # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -12,19 +12,19 @@ jobs: pull-requests: write runs-on: ubuntu-latest - + strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] - + python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] + steps: - uses: actions/checkout@v3 - + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: 
python-version: ${{ matrix.python-version }} - + - name: Install dependencies working-directory: ./ run: | @@ -32,7 +32,7 @@ jobs: pip install pytest pytest-cov pip install -e . pip install -r dev-requirements.txt - + - name: Run tests working-directory: ./ run: | @@ -41,13 +41,13 @@ jobs: echo "COVERAGE_PCT=$(python -c "import xml.etree.ElementTree as ET; tree = ET.parse('coverage.xml'); root = tree.getroot(); print(f'{float(root.attrib[\"line-rate\"]) * 100:.2f}')")" >> $GITHUB_OUTPUT echo "TEST_COUNT=$(python -c "import xml.etree.ElementTree as ET; tree = ET.parse('coverage.xml'); root = tree.getroot(); print(root.find('.//metrics').attrib['tests'])")" >> $GITHUB_OUTPUT id: test_results - + - name: Create result file run: | mkdir -p test-results echo "${{ steps.test_results.outputs.COVERAGE_PCT }}" > test-results/coverage.txt echo "${{ steps.test_results.outputs.TEST_COUNT }}" > test-results/test-count.txt - + - name: Upload test results uses: actions/upload-artifact@v4 with: @@ -58,11 +58,11 @@ jobs: needs: test runs-on: ubuntu-latest if: github.event_name == 'pull_request' - + steps: - name: Download all artifacts uses: actions/download-artifact@v4 - + - name: Prepare comment id: prepare_comment run: | @@ -71,7 +71,7 @@ jobs: echo "" >> $GITHUB_ENV echo "| Python Version | Status | Coverage | Tests Run |" >> $GITHUB_ENV echo "| -------------- | ------ | -------- | --------- |" >> $GITHUB_ENV - + for version in 3.8 3.9 3.10 3.11 3.12; do version_path="test-results-$version" if [ -d "$version_path" ] && [ -f "$version_path/coverage.txt" ]; then @@ -82,11 +82,11 @@ jobs: echo "| Python $version | ❌ Failed or not run | - | - |" >> $GITHUB_ENV fi done - + echo "" >> $GITHUB_ENV echo "*Last updated: $(date -u '+%Y-%m-%d %H:%M:%S UTC')*" >> $GITHUB_ENV echo "EOF" >> $GITHUB_ENV - + - name: Find Comment uses: peter-evans/find-comment@v3 id: fc diff --git a/basalt/objects/prompt.py b/basalt/objects/prompt.py index 2c39f50..1300763 100644 --- a/basalt/objects/prompt.py 
+++ b/basalt/objects/prompt.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Set from jinja2 import Template, Environment, meta from ..ressources.prompts.prompt_types import PromptParams, PromptModel @@ -85,7 +85,7 @@ def compile_variables(self, variables: Dict[str, Any]) -> 'Prompt': return self @staticmethod - def _find_undeclared_variables(template: str) -> set[str]: + def _find_undeclared_variables(template: str) -> Set[str]: env = Environment() ast = env.parse(template) variables = meta.find_undeclared_variables(ast) From eb501963633625e5da90f68e0d3e2b74c33eaafb Mon Sep 17 00:00:00 2001 From: Guillaume Date: Mon, 6 Oct 2025 20:13:11 +0200 Subject: [PATCH 6/7] fix python versions --- .github/workflows/python-tests.yml | 50 +----------------------------- setup.py | 2 +- 2 files changed, 2 insertions(+), 50 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 7ceafc9..306836f 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: - python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] steps: - uses: actions/checkout@v3 @@ -54,51 +54,3 @@ jobs: name: test-results-${{ matrix.python-version }} path: test-results/ - comment: - needs: test - runs-on: ubuntu-latest - if: github.event_name == 'pull_request' - - steps: - - name: Download all artifacts - uses: actions/download-artifact@v4 - - - name: Prepare comment - id: prepare_comment - run: | - echo "COMMENT_BODY<> $GITHUB_ENV - echo "## Python SDK Test Results" >> $GITHUB_ENV - echo "" >> $GITHUB_ENV - echo "| Python Version | Status | Coverage | Tests Run |" >> $GITHUB_ENV - echo "| -------------- | ------ | -------- | --------- |" >> $GITHUB_ENV - - for version in 3.8 3.9 3.10 3.11 3.12; do - version_path="test-results-$version" - if [ -d "$version_path" 
] && [ -f "$version_path/coverage.txt" ]; then - coverage=$(cat "$version_path/coverage.txt") - test_count=$(cat "$version_path/test-count.txt") - echo "| Python $version | ✅ Passed | $coverage% | $test_count |" >> $GITHUB_ENV - else - echo "| Python $version | ❌ Failed or not run | - | - |" >> $GITHUB_ENV - fi - done - - echo "" >> $GITHUB_ENV - echo "*Last updated: $(date -u '+%Y-%m-%d %H:%M:%S UTC')*" >> $GITHUB_ENV - echo "EOF" >> $GITHUB_ENV - - - name: Find Comment - uses: peter-evans/find-comment@v3 - id: fc - with: - issue-number: ${{ github.event.pull_request.number }} - comment-author: 'github-actions[bot]' - body-includes: Python SDK Test Results - - - name: Create or update comment - uses: peter-evans/create-or-update-comment@v4 - with: - comment-id: ${{ steps.fc.outputs.comment-id }} - issue-number: ${{ github.event.pull_request.number }} - body: ${{ env.COMMENT_BODY }} - edit-mode: replace diff --git a/setup.py b/setup.py index 496765a..ec70b2b 100644 --- a/setup.py +++ b/setup.py @@ -30,5 +30,5 @@ def get_version(): "aiohttp>=3.8.0", "jinja2>=3.1.0", ], - python_requires=">=3.6" + python_requires=">=3.8" ) From 71efc987d9ed8bf854f93a85357f193c570159c9 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Mon, 6 Oct 2025 20:17:36 +0200 Subject: [PATCH 7/7] version change --- basalt/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basalt/_version.py b/basalt/_version.py index 493f741..6a9beea 100644 --- a/basalt/_version.py +++ b/basalt/_version.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.4.0"