diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml new file mode 100644 index 0000000..23dca11 --- /dev/null +++ b/.github/workflows/code-checks.yml @@ -0,0 +1,30 @@ +name: Code Checks + +on: + pull_request: + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + format: + name: Format check + + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Ruff format check + uses: astral-sh/ruff-action@v3 + with: + args: "format --check --diff" diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 5e3acbe..fbae8ac 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -2,7 +2,7 @@ name: Python SDK Tests on: pull_request: - + # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -12,93 +12,111 @@ jobs: pull-requests: write runs-on: ubuntu-latest - + strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] - + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + steps: - - uses: actions/checkout@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - working-directory: ./ - run: | - python -m pip install --upgrade pip - pip install pytest pytest-cov - pip install -e . - pip install -r dev-requirements.txt - - - name: Run tests - working-directory: ./ - run: | - pytest tests/ --cov=basalt --cov-report=term --cov-report=xml - echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_OUTPUT - echo "COVERAGE_PCT=$(python -c "import xml.etree.ElementTree as ET; tree = ET.parse('coverage.xml'); root = tree.getroot(); print(f'{float(root.attrib[\"line-rate\"]) * 100:.2f}')")" >> $GITHUB_OUTPUT - echo "TEST_COUNT=$(python -c "import xml.etree.ElementTree as ET; tree = ET.parse('coverage.xml'); root = tree.getroot(); print(root.find('.//metrics').attrib['tests'])")" >> $GITHUB_OUTPUT - id: test_results - - - name: Create result file - run: | - mkdir -p test-results - echo "${{ steps.test_results.outputs.COVERAGE_PCT }}" > test-results/coverage.txt - echo "${{ steps.test_results.outputs.TEST_COUNT }}" > test-results/test-count.txt - - - name: Upload test results - uses: actions/upload-artifact@v4 - with: - name: test-results-${{ matrix.python-version }} - path: test-results/ + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + working-directory: ./ + run: | + python -m pip install --upgrade pip + pip install pytest pytest-cov + pip install -e . + pip install -r dev-requirements.txt + + - name: Run tests + working-directory: ./ + run: | + pytest tests/ \ + --cov=basalt --cov-report=term --cov-report=xml \ + --junitxml=pytest-report.xml + + echo "PYTHON_VERSION=${{ matrix.python-version }}" >> "$GITHUB_OUTPUT" + + echo "COVERAGE_PCT=$(python - <<'PY' + import xml.etree.ElementTree as ET + rate = float(ET.parse('coverage.xml').getroot().attrib['line-rate']) * 100 + print(f'{rate:.2f}') + PY + )" >> "$GITHUB_OUTPUT" + + echo "TEST_COUNT=$(python - <<'PY' + import xml.etree.ElementTree as ET + root = ET.parse('pytest-report.xml').getroot() + if root.tag == 'testsuite': + print(root.attrib['tests']) + else: + print(sum(int(s.attrib['tests']) for s in root.findall('testsuite'))) + PY + )" >> "$GITHUB_OUTPUT" + id: test_results + + - name: Create result file + run: | + mkdir -p test-results + echo "${{ steps.test_results.outputs.COVERAGE_PCT }}" > test-results/coverage.txt + echo "${{ steps.test_results.outputs.TEST_COUNT }}" > test-results/test-count.txt + + - name: Upload test results + uses: actions/upload-artifact@v4 + with: + name: test-results-${{ matrix.python-version }} + path: test-results/ comment: needs: test runs-on: ubuntu-latest if: github.event_name == 'pull_request' - + steps: - - name: Download all artifacts - uses: actions/download-artifact@v4 - - - name: Prepare comment - id: prepare_comment - run: | - echo "COMMENT_BODY<> $GITHUB_ENV - echo "## Python SDK Test Results" >> $GITHUB_ENV - echo "" >> $GITHUB_ENV - echo "| Python Version | Status | Coverage | Tests Run |" >> $GITHUB_ENV - echo "| -------------- | ------ | -------- | --------- |" >> $GITHUB_ENV - - for version in 3.8 3.9 3.10 3.11 3.12; do - version_path="test-results-$version" - if [ -d "$version_path" ] && [ -f "$version_path/coverage.txt" ]; then - coverage=$(cat "$version_path/coverage.txt") - test_count=$(cat "$version_path/test-count.txt") - echo "| Python $version | ✅ Passed | $coverage% | $test_count |" >> $GITHUB_ENV - else - echo "| Python $version | ❌ Failed or not run | - | - |" >> $GITHUB_ENV - fi - done - - echo "" >> $GITHUB_ENV - echo "*Last updated: $(date -u '+%Y-%m-%d %H:%M:%S UTC')*" >> $GITHUB_ENV - echo "EOF" >> $GITHUB_ENV - - - name: Find Comment - uses: peter-evans/find-comment@v3 - id: fc - with: - issue-number: ${{ github.event.pull_request.number }} - comment-author: 'github-actions[bot]' - body-includes: Python SDK Test Results - - - name: Create or update comment - uses: peter-evans/create-or-update-comment@v4 - with: - comment-id: ${{ steps.fc.outputs.comment-id }} - issue-number: ${{ github.event.pull_request.number }} - body: ${{ env.COMMENT_BODY }} - edit-mode: replace + - name: Download all artifacts + uses: actions/download-artifact@v4 + + - name: Prepare comment + id: prepare_comment + run: | + echo "COMMENT_BODY<> $GITHUB_ENV + echo "## Python SDK Test Results" >> $GITHUB_ENV + echo "" >> $GITHUB_ENV + echo "| Python Version | Status | Coverage | Tests Run |" >> $GITHUB_ENV + echo "| -------------- | ------ | -------- | --------- |" >> $GITHUB_ENV + + for version in 3.8 3.9 3.10 3.11 3.12 3.13; do + version_path="test-results-$version" + if [ -d "$version_path" ] && [ -f "$version_path/coverage.txt" ]; then + coverage=$(cat "$version_path/coverage.txt") + test_count=$(cat "$version_path/test-count.txt") + echo "| Python $version | ✅ Passed | $coverage% | $test_count |" >> $GITHUB_ENV + else + echo "| Python $version | ❌ Failed or not run | - | - |" >> $GITHUB_ENV + fi + done + + echo "" >> $GITHUB_ENV + echo "*Last updated: $(date -u '+%Y-%m-%d %H:%M:%S UTC')*" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + - name: Find Comment + uses: peter-evans/find-comment@v3 + id: fc + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: "github-actions[bot]" + body-includes: Python SDK Test Results + + - name: Create or update comment + uses: peter-evans/create-or-update-comment@v4 + with: + comment-id: ${{ steps.fc.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body: ${{ env.COMMENT_BODY }} + edit-mode: replace diff --git a/basalt/basalt_facade.py b/basalt/basalt_facade.py index 4964b2c..5db0c48 100644 --- a/basalt/basalt_facade.py +++ b/basalt/basalt_facade.py @@ -10,12 +10,13 @@ global_fallback_cache = MemoryCache() + class BasaltFacade(IBasaltSDK): """ The Basalt client. """ - def __init__(self, api_key: str, log_level: str = 'all'): + def __init__(self, api_key: str, log_level: str = "all"): """ Initializes the Basalt client with the given API key and log level. @@ -26,14 +27,14 @@ def __init__(self, api_key: str, log_level: str = 'all'): cache = MemoryCache() logger = Logger(log_level=log_level) networker = Networker(logger=logger) - + api = Api( networker=networker, root_url=config["api_url"], api_key=api_key, sdk_version=config["sdk_version"], sdk_type=config["sdk_type"], - logger=logger + logger=logger, ) prompt = PromptSDK(api, cache, global_fallback_cache, logger) diff --git a/basalt/basaltsdk.py b/basalt/basaltsdk.py index dedabf4..6c61715 100644 --- a/basalt/basaltsdk.py +++ b/basalt/basaltsdk.py @@ -1,5 +1,6 @@ from .utils.protocols import IPromptSDK, IBasaltSDK, IMonitorSDK + class BasaltSDK(IBasaltSDK): """ The BasaltSDK class implements the IBasaltSDK interface. @@ -9,7 +10,7 @@ class BasaltSDK(IBasaltSDK): def __init__(self, prompt_sdk: IPromptSDK, monitor_sdk: IMonitorSDK): self._prompt = prompt_sdk self._monitor = monitor_sdk - + @property def prompt(self) -> IPromptSDK: """Read-only access to the PromptSDK instance""" diff --git a/basalt/config.py b/basalt/config.py index 1d9b3c7..a0c79c6 100644 --- a/basalt/config.py +++ b/basalt/config.py @@ -5,7 +5,9 @@ from ._version import __version__ config = { - 'api_url': 'http://localhost:3001' if build == 'development' else 'https://api.getbasalt.ai', - 'sdk_version': __version__, - 'sdk_type': 'python', + "api_url": "http://localhost:3001" + if build == "development" + else "https://api.getbasalt.ai", + "sdk_version": __version__, + "sdk_type": "python", } diff --git a/basalt/endpoints/describe_prompt.py b/basalt/endpoints/describe_prompt.py index 31b1717..ca5d43a 100644 --- a/basalt/endpoints/describe_prompt.py +++ b/basalt/endpoints/describe_prompt.py @@ -3,6 +3,7 @@ from ..utils.dtos import DescribePromptDTO, DescribePromptResponse + @dataclass class DescribePromptEndpointResponse: warning: Optional[str] @@ -24,10 +25,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "DescribePromptEndpointResponse": prompt=DescribePromptResponse.from_dict(data["prompt"]), ) + class DescribePromptEndpoint: """ Endpoint class for fetching a prompt. """ + @staticmethod def prepare_request(dto: DescribePromptDTO) -> Dict[str, Any]: """ @@ -37,19 +40,18 @@ def prepare_request(dto: DescribePromptDTO) -> Dict[str, Any]: dto (DescribePromptDTO): The data transfer object containing the request parameters. Returns: - The path, method, and query parameters for describing a prompt on the API. + The path, method, and query parameters for describing a prompt on the API. """ return { "path": f"/prompts/{dto.slug}/describe", "method": "GET", - "query": { - "version": dto.version, - "tag": dto.tag - } + "query": {"version": dto.version, "tag": dto.tag}, } @staticmethod - def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[DescribePromptEndpointResponse]]: + def decode_response( + response: dict, + ) -> Tuple[Optional[Exception], Optional[DescribePromptEndpointResponse]]: """ Decode the response returned from the API @@ -57,7 +59,7 @@ def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[Descr response (dict): The JSON response to encode into a DescribePromptEndpointResponse Returns: - A tuple containing an optional exception and an optional DescribePromptEndpointResponse. + A tuple containing an optional exception and an optional DescribePromptEndpointResponse. """ try: return None, DescribePromptEndpointResponse.from_dict(response) diff --git a/basalt/endpoints/get_prompt.py b/basalt/endpoints/get_prompt.py index d6ee408..f09ee97 100644 --- a/basalt/endpoints/get_prompt.py +++ b/basalt/endpoints/get_prompt.py @@ -3,6 +3,7 @@ from ..utils.dtos import GetPromptDTO, PromptResponse + @dataclass class GetPromptEndpointResponse: warning: Optional[str] @@ -24,10 +25,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetPromptEndpointResponse": prompt=PromptResponse.from_dict(data["prompt"]), ) + class GetPromptEndpoint: """ Endpoint class for fetching a prompt. """ + @staticmethod def prepare_request(dto: GetPromptDTO) -> Dict[str, Any]: """ @@ -37,19 +40,18 @@ def prepare_request(dto: GetPromptDTO) -> Dict[str, Any]: dto (GetPromptDTO): The data transfer object containing the request parameters. Returns: - The path, method, and query parameters for getting a prompt on the API. + The path, method, and query parameters for getting a prompt on the API. """ return { "path": f"/prompts/{dto.slug}", "method": "GET", - "query": { - "version": dto.version, - "tag": dto.tag - } + "query": {"version": dto.version, "tag": dto.tag}, } @staticmethod - def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[GetPromptEndpointResponse]]: + def decode_response( + response: dict, + ) -> Tuple[Optional[Exception], Optional[GetPromptEndpointResponse]]: """ Decode the response returned from the API @@ -57,7 +59,7 @@ def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[GetPr response (dict): The JSON response to encode into a GetPromptEndpointResponse Returns: - A tuple containing an optional exception and an optional GetPromptEndpointResponse. + A tuple containing an optional exception and an optional GetPromptEndpointResponse. """ try: return None, GetPromptEndpointResponse.from_dict(response) diff --git a/basalt/endpoints/list_prompts.py b/basalt/endpoints/list_prompts.py index 78fff06..17974d3 100644 --- a/basalt/endpoints/list_prompts.py +++ b/basalt/endpoints/list_prompts.py @@ -3,6 +3,7 @@ from ..utils.dtos import PromptListResponse + @dataclass class ListPromptsEndpointResponse: warning: Optional[str] @@ -21,28 +22,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "ListPromptsEndpointResponse": """ return cls( warning=data.get("warning"), - prompts=[PromptListResponse.from_dict(prompt) for prompt in data["prompts"]], + prompts=[ + PromptListResponse.from_dict(prompt) for prompt in data["prompts"] + ], ) + class ListPromptsEndpoint: """ Endpoint class for fetching a prompt. """ + @staticmethod def prepare_request() -> Dict[str, Any]: """ Prepare the request dictionary for the ListPrompts endpoint. Returns: - The path, method, and query parameters for getting a prompt on the API. + The path, method, and query parameters for getting a prompt on the API. """ - return { - "path": "/prompts", - "method": "GET" - } + return {"path": "/prompts", "method": "GET"} @staticmethod - def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[ListPromptsEndpointResponse]]: + def decode_response( + response: dict, + ) -> Tuple[Optional[Exception], Optional[ListPromptsEndpointResponse]]: """ Decode the response returned from the API @@ -50,7 +54,7 @@ def decode_response(response: dict) -> Tuple[Optional[Exception], Optional[ListP response (dict): The JSON response to encode into a ListPromptsEndpointResponse Returns: - A tuple containing an optional exception and an optional ListPromptsEndpointResponse. + A tuple containing an optional exception and an optional ListPromptsEndpointResponse. """ try: return None, ListPromptsEndpointResponse.from_dict(response) diff --git a/basalt/endpoints/monitor/__init__.py b/basalt/endpoints/monitor/__init__.py index 8d24c29..13258e6 100644 --- a/basalt/endpoints/monitor/__init__.py +++ b/basalt/endpoints/monitor/__init__.py @@ -1,3 +1,3 @@ """ Monitor endpoints for the Basalt API. -""" \ No newline at end of file +""" diff --git a/basalt/endpoints/monitor/send_trace.py b/basalt/endpoints/monitor/send_trace.py index 528b960..f42ec55 100644 --- a/basalt/endpoints/monitor/send_trace.py +++ b/basalt/endpoints/monitor/send_trace.py @@ -1,36 +1,35 @@ """ Endpoint for sending a trace to the API. """ + from typing import Dict, Any, Optional, TypeVar, Tuple from datetime import datetime # Define type variables for the endpoint -Input = TypeVar('Input', bound=Dict[str, Any]) -Output = TypeVar('Output', bound=Dict[str, Any]) +Input = TypeVar("Input", bound=Dict[str, Any]) +Output = TypeVar("Output", bound=Dict[str, Any]) + class SendTraceEndpoint: """ Endpoint for sending a trace to the API. """ + def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: """ Prepares the request for sending a trace. - + Args: dto (Optional[Dict[str, Any]]): The data transfer object containing the trace. - + Returns: Dict[str, Any]: The request information. """ if not dto or "trace" not in dto: - return { - "method": "post", - "path": "/monitor/trace", - "body": {} - } - + return {"method": "post", "path": "/monitor/trace", "body": {}} + trace = dto["trace"] - + # Check if trace is already a dictionary or an object if isinstance(trace, dict): trace_data = trace @@ -41,11 +40,20 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: logs = [] for log in trace_data["logs"]: log_data = log.to_dict() - + # Convert dates to ISO format - log_data["startTime"] = log_data["start_time"].isoformat() if isinstance(log_data["start_time"], datetime) else log_data["start_time"] - log_data["endTime"] = log_data["end_time"].isoformat() if isinstance(log_data["end_time"], datetime) and log_data["end_time"] else None - + log_data["startTime"] = ( + log_data["start_time"].isoformat() + if isinstance(log_data["start_time"], datetime) + else log_data["start_time"] + ) + log_data["endTime"] = ( + log_data["end_time"].isoformat() + if isinstance(log_data["end_time"], datetime) + and log_data["end_time"] + else None + ) + # Remove old format keys del log_data["start_time"] del log_data["end_time"] @@ -55,28 +63,31 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: log_data["input"] = log.input if hasattr(log, "output"): log_data["output"] = log.output - + # Add prompt and variables if it's a generation if hasattr(log, "prompt"): log_data["prompt"] = log.prompt if hasattr(log, "variables") and log.variables: - log_data["variables"] = [{"label": key, "value": value} for key, value in log.variables.items()] + log_data["variables"] = [ + {"label": key, "value": value} + for key, value in log.variables.items() + ] if hasattr(log, "inputTokens"): log_data["inputTokens"] = log.inputTokens if hasattr(log, "outputTokens"): log_data["outputTokens"] = log.outputTokens if hasattr(log, "cost"): log_data["cost"] = log.cost - + # Extract parent ID if log_data["parent"]: log_data["parentId"] = log_data["parent"]["id"] del log_data["parent"] else: log_data["parentId"] = None - + logs.append(log_data) - + # Process logs if they're already in dictionary format processed_logs = [] for log_data in logs: @@ -91,21 +102,30 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: # Convert dates to ISO format if they're in the old format processed_log = dict(log_data) if "start_time" in processed_log: - processed_log["startTime"] = processed_log["start_time"].isoformat() if isinstance(processed_log["start_time"], datetime) else processed_log["start_time"] + processed_log["startTime"] = ( + processed_log["start_time"].isoformat() + if isinstance(processed_log["start_time"], datetime) + else processed_log["start_time"] + ) del processed_log["start_time"] if "end_time" in processed_log: - processed_log["endTime"] = processed_log["end_time"].isoformat() if isinstance(processed_log["end_time"], datetime) and processed_log["end_time"] else None + processed_log["endTime"] = ( + processed_log["end_time"].isoformat() + if isinstance(processed_log["end_time"], datetime) + and processed_log["end_time"] + else None + ) del processed_log["end_time"] - + # Extract parent ID if "parent" in processed_log and processed_log["parent"]: processed_log["parentId"] = processed_log["parent"]["id"] del processed_log["parent"] else: processed_log["parentId"] = None - + processed_logs.append(processed_log) - + # Create the request body body = { "chainSlug": trace_data.get("chain_slug", trace_data.get("chainSlug")), @@ -116,32 +136,30 @@ def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: "user": trace_data.get("user"), "startTime": trace_data.get("start_time", trace_data.get("startTime")), "endTime": trace_data.get("end_time", trace_data.get("endTime")), - "logs": processed_logs + "logs": processed_logs, } - + # Convert dates to ISO format if they're datetime objects if isinstance(body["startTime"], datetime): body["startTime"] = body["startTime"].isoformat() if isinstance(body["endTime"], datetime): body["endTime"] = body["endTime"].isoformat() - - return { - "method": "post", - "path": "/monitor/trace", - "body": body - } - - def decode_response(self, response: Any) -> Tuple[Optional[Exception], Optional[Output]]: + + return {"method": "post", "path": "/monitor/trace", "body": body} + + def decode_response( + self, response: Any + ) -> Tuple[Optional[Exception], Optional[Output]]: """ Decodes the response from sending a trace. - + Args: response (Any): The response from the API. - + Returns: Tuple[Optional[Exception], Optional[Dict[str, Any]]]: The decoded response. """ if not isinstance(response, dict): return Exception("Failed to decode response (invalid body format)"), None - - return None, response.get("trace", {}) \ No newline at end of file + + return None, response.get("trace", {}) diff --git a/basalt/objects/__init__.py b/basalt/objects/__init__.py index a56a1af..f38c4f0 100644 --- a/basalt/objects/__init__.py +++ b/basalt/objects/__init__.py @@ -3,4 +3,4 @@ from .log import Log from .base_log import BaseLog -__all__ = ['Trace', 'Generation', 'Log', 'BaseLog'] \ No newline at end of file +__all__ = ["Trace", "Generation", "Log", "BaseLog"] diff --git a/basalt/objects/base_log.py b/basalt/objects/base_log.py index 200a95b..248704e 100644 --- a/basalt/objects/base_log.py +++ b/basalt/objects/base_log.py @@ -1,15 +1,17 @@ -from datetime import datetime -from typing import Dict, Optional, Any, TYPE_CHECKING import uuid +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, Optional if TYPE_CHECKING: from .log import Log from .trace import Trace + class BaseLog: """ Base class for logs and generations. """ + def __init__(self, params: Dict[str, Any]): self._id = f"log-{uuid.uuid4().hex[:8]}" self._type = params.get("type") @@ -30,12 +32,12 @@ def id(self) -> str: return self._id @property - def parent(self) -> Optional['Log']: + def parent(self) -> Optional["Log"]: """Get the parent log.""" return self._parent @parent.setter - def parent(self, parent: 'Log'): + def parent(self, parent: "Log"): """Set the parent log.""" self._parent = parent @@ -65,39 +67,39 @@ def metadata(self) -> Optional[Dict[str, Any]]: return self._metadata @property - def trace(self) -> 'Trace': + def trace(self) -> "Trace": """Get the trace.""" return self._trace @trace.setter - def trace(self, trace: 'Trace'): + def trace(self, trace: "Trace"): """Set the trace.""" self._trace = trace - def start(self) -> 'BaseLog': + def start(self) -> "BaseLog": """Start the log.""" self._start_time = datetime.now() return self - def set_metadata(self, metadata: Dict[str, Any]) -> 'BaseLog': + def set_metadata(self, metadata: Dict[str, Any]) -> "BaseLog": """Set the metadata.""" self._metadata = metadata return self - def update(self, params: Dict[str, Any]) -> 'BaseLog': + def update(self, params: Dict[str, Any]) -> "BaseLog": """Update the log.""" self._name = params.get("name", self._name) self._metadata = params.get("metadata", self._metadata) - + if params.get("start_time"): self._start_time = params.get("start_time") - + if params.get("end_time"): self._end_time = params.get("end_time") - + return self - def end(self) -> 'BaseLog': + def end(self) -> "BaseLog": """End the log.""" self._end_time = datetime.now() return self @@ -112,4 +114,4 @@ def to_dict(self) -> Dict[str, Any]: "end_time": self._end_time, "metadata": self._metadata, "parent": {"id": self._parent.id} if self._parent else None, - } \ No newline at end of file + } diff --git a/basalt/objects/generation.py b/basalt/objects/generation.py index bbe887e..3ece4ee 100644 --- a/basalt/objects/generation.py +++ b/basalt/objects/generation.py @@ -2,36 +2,41 @@ from .base_log import BaseLog + class Generation(BaseLog): """ Class representing a generation in the monitoring system. """ + def __init__(self, params: Dict[str, Any]): - params_with_type = { - "type": "generation", - **params - } + params_with_type = {"type": "generation", **params} super().__init__(params_with_type) - + self._prompt = params.get("prompt") self._input = params.get("input") self._output = params.get("output") self._inputTokens = params.get("inputTokens") self._outputTokens = params.get("outputTokens") self._cost = params.get("cost") - + # Convert variables to array format if needed variables = params.get("variables") if variables is not None: if isinstance(variables, dict): - self._variables = [{"label": str(k), "value": str(v)} for k, v in variables.items()] + self._variables = [ + {"label": str(k), "value": str(v)} for k, v in variables.items() + ] elif isinstance(variables, list): - self._variables = [{"label": str(v.get("label")), "value": str(v.get("value"))} for v in variables if v.get("label")] + self._variables = [ + {"label": str(v.get("label")), "value": str(v.get("value"))} + for v in variables + if v.get("label") + ] else: self._variables = [] else: self._variables = [] - + self._options = params.get("options") @property @@ -79,53 +84,53 @@ def options(self, options: Dict[str, Any]): """Set the generation options.""" self._options = options - def start(self, input: Optional[str] = None) -> 'Generation': + def start(self, input: Optional[str] = None) -> "Generation": """ Start the generation with an optional input. - + Args: input (Optional[str]): The input to the generation. - + Returns: Generation: The generation instance. """ if input: self._input = input - + super().start() return self - def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> 'Generation': + def end(self, output: Optional[Union[str, Dict[str, Any]]] = None) -> "Generation": """ End the generation with an optional output or update parameters. - + Args: output (Optional[Union[str, Dict[str, Any]]]): The output of the generation or a dictionary of parameters to update. - + Returns: Generation: The generation instance. """ super().end() - + if isinstance(output, dict): self.update(output) elif isinstance(output, str): self._output = output - + # If this is a single generation, end the trace as well if self._options and self._options.get("type") == "single": self.trace.end(self._output) - + return self - def update(self, params: Dict[str, Any]) -> 'Generation': + def update(self, params: Dict[str, Any]) -> "Generation": """ Update the generation. - + Args: params (Dict[str, Any]): Parameters to update. - + Returns: Generation: The generation instance. """ @@ -135,16 +140,22 @@ def update(self, params: Dict[str, Any]) -> 'Generation': self._inputTokens = params.get("inputTokens", self._inputTokens) self._outputTokens = params.get("outputTokens", self._outputTokens) self._cost = params.get("cost", self._cost) - + # Update variables if provided variables = params.get("variables") if variables is not None: if isinstance(variables, dict): - self._variables = [{"label": str(k), "value": str(v)} for k, v in variables.items()] + self._variables = [ + {"label": str(k), "value": str(v)} for k, v in variables.items() + ] elif isinstance(variables, list): - self._variables = [{"label": str(v.get("label")), "value": str(v.get("value"))} for v in variables if v.get("label")] + self._variables = [ + {"label": str(v.get("label")), "value": str(v.get("value"))} + for v in variables + if v.get("label") + ] else: self._variables = [] - + super().update(params) - return self \ No newline at end of file + return self diff --git a/basalt/objects/log.py b/basalt/objects/log.py index 0d38db8..bf8d4be 100644 --- a/basalt/objects/log.py +++ b/basalt/objects/log.py @@ -5,10 +5,12 @@ if TYPE_CHECKING: from .generation import Generation + class Log(BaseLog): """ Class representing a log in the monitoring system. """ + def __init__(self, params: Dict[str, Any]): super().__init__(params) self._input = params.get("input") @@ -24,114 +26,107 @@ def output(self) -> Optional[str]: """Get the log output.""" return self._output - def start(self, input: Optional[str] = None) -> 'Log': + def start(self, input: Optional[str] = None) -> "Log": """ Start the log with an optional input. - + Args: input (Optional[str]): The input to the log. - + Returns: Log: The log instance. """ if input: self._input = input - + super().start() return self - def end(self, output: Optional[str] = None) -> 'Log': + def end(self, output: Optional[str] = None) -> "Log": """ End the log with an optional output. - + Args: output (Optional[str]): The output of the log. - + Returns: Log: The log instance. """ super().end() - + if output: self._output = output - + return self - def append(self, generation: 'Generation') -> 'Log': + def append(self, generation: "Generation") -> "Log": """ Append a generation to this log. - + Args: generation (Generation): The generation to append. - + Returns: Log: The log instance. """ generation.parent = self generation.trace = self.trace - + return self - def update(self, params: Dict[str, Any]) -> 'Log': + def update(self, params: Dict[str, Any]) -> "Log": """ Update the log with new parameters. - + Args: params (Dict[str, Any]): Parameters to update. - + Returns: Log: The log instance. """ super().update(params) - + if "output" in params: self._output = params["output"] - + if "input" in params: self._input = params["input"] - + return self - def create_generation(self, params: Dict[str, Any]) -> 'Generation': + def create_generation(self, params: Dict[str, Any]) -> "Generation": """ Create a new generation as a child of this log. - + Args: params (Dict[str, Any]): Parameters for the generation. - + Returns: Generation: The new generation instance. """ from .generation import Generation - + # Set the name to the prompt slug only if name is not provided name = params.get("name") if not name and params.get("prompt") and params["prompt"].get("slug"): name = params["prompt"]["slug"] - - generation = Generation({ - **params, - "name": name, - "trace": self.trace, - "parent": self - }) - + + generation = Generation( + {**params, "name": name, "trace": self.trace, "parent": self} + ) + return generation - def create_log(self, params: Dict[str, Any]) -> 'Log': + def create_log(self, params: Dict[str, Any]) -> "Log": """ Create a new log as a child of this log. - + Args: params (Dict[str, Any]): Parameters for the log. - + Returns: Log: The new log instance. """ - log = Log({ - **params, - "trace": self.trace, - "parent": self - }) - - return log \ No newline at end of file + log = Log({**params, "trace": self.trace, "parent": self}) + + return log diff --git a/basalt/objects/trace.py b/basalt/objects/trace.py index d9d32b3..91bb669 100644 --- a/basalt/objects/trace.py +++ b/basalt/objects/trace.py @@ -6,13 +6,15 @@ from .generation import Generation from ..utils.flusher import Flusher + class Trace: """ Class representing a trace in the monitoring system. """ - def __init__(self, slug: str, params: Dict[str, Any], flusher: 'Flusher'): + + def __init__(self, slug: str, params: Dict[str, Any], flusher: "Flusher"): self._chain_slug = slug - + self._input = params.get("input") self._output = params.get("output") self._name = params.get("name") @@ -21,9 +23,9 @@ def __init__(self, slug: str, params: Dict[str, Any], flusher: 'Flusher'): self._user = params.get("user") self._organization = params.get("organization") self._metadata = params.get("metadata") - - self._logs: List['BaseLog'] = [] - + + self._logs: List["BaseLog"] = [] + self._flusher = flusher self._flushed_promise = None @@ -58,12 +60,12 @@ def metadata(self) -> Optional[Dict[str, Any]]: return self._metadata @property - def logs(self) -> List['BaseLog']: + def logs(self) -> List["BaseLog"]: """Get the logs.""" return self._logs @logs.setter - def logs(self, logs: List['BaseLog']): + def logs(self, logs: List["BaseLog"]): """Set the logs.""" self._logs = logs @@ -77,29 +79,29 @@ def end_time(self) -> Optional[datetime]: """Get the end time.""" return self._end_time - def start(self, input: Optional[str] = None) -> 'Trace': + def start(self, input: Optional[str] = None) -> "Trace": """ Start the trace with an optional input. - + Args: input (Optional[str]): The input to the trace. - + Returns: Trace: The trace instance. """ if input: self._input = input - + self._start_time = datetime.now() return self - def identify(self, params: Dict[str, Any]) -> 'Trace': + def identify(self, params: Dict[str, Any]) -> "Trace": """ Set identification information for the trace. - + Args: params (Dict[str, Any]): Identification parameters. - + Returns: Trace: The trace instance. """ @@ -107,26 +109,26 @@ def identify(self, params: Dict[str, Any]) -> 'Trace': self._organization = params.get("organization") return self - def set_metadata(self, metadata: Dict[str, Any]) -> 'Trace': + def set_metadata(self, metadata: Dict[str, Any]) -> "Trace": """ Set metadata for the trace. - + Args: metadata (Dict[str, Any]): The metadata to set. - + Returns: Trace: The trace instance. """ self._metadata = metadata return self - def update(self, params: Dict[str, Any]) -> 'Trace': + def update(self, params: Dict[str, Any]) -> "Trace": """ Update the trace. - + Args: params (Dict[str, Any]): Parameters to update. - + Returns: Trace: The trace instance. """ @@ -135,99 +137,94 @@ def update(self, params: Dict[str, Any]) -> 'Trace': self._output = params.get("output", self._output) self._organization = params.get("organization", self._organization) self._user = params.get("user", self._user) - + if params.get("start_time"): self._start_time = params.get("start_time") - + if params.get("end_time"): self._end_time = params.get("end_time") - + self._name = params.get("name", self._name) - + return self - def append(self, generation: 'Generation') -> 'Trace': + def append(self, generation: "Generation") -> "Trace": """ Append a generation to this trace. - + Args: generation (Generation): The generation to append. - + Returns: Trace: The trace instance. """ # Remove child log from the list of its previous trace if generation.trace: - generation.trace.logs = [log for log in generation.trace.logs if log.id != generation.id] - + generation.trace.logs = [ + log for log in generation.trace.logs if log.id != generation.id + ] + # Add child to the new trace list self._logs.append(generation) generation.trace = self - + return self - def create_generation(self, params: Dict[str, Any]) -> 'Generation': + def create_generation(self, params: Dict[str, Any]) -> "Generation": """ Create a new generation in this trace. - + Args: params (Dict[str, Any]): Parameters for the generation. - + Returns: Generation: The new generation instance. """ from .generation import Generation - + # Set the name to the prompt slug if available name = params.get("name") if params.get("prompt") and params["prompt"].get("slug"): name = params["prompt"]["slug"] - - generation = Generation({ - **params, - "name": name, - "trace": self - }) - + + generation = Generation({**params, "name": name, "trace": self}) + return generation - def create_log(self, params: Dict[str, Any]) -> 'BaseLog': + def create_log(self, params: Dict[str, Any]) -> "BaseLog": """ Create a new log in this trace. - + Args: params (Dict[str, Any]): Parameters for the log. - + Returns: Log: The new log instance. """ from .log import Log - - log = Log({ - **params, - "trace": self - }) - + + log = Log({**params, "trace": self}) + return log - def end(self, output: Optional[str] = None) -> 'Trace': + def end(self, output: Optional[str] = None) -> "Trace": """ End the trace with an optional output. - + Args: output (Optional[str]): The output of the trace. - + Returns: Trace: The trace instance. """ self._output = output if output is not None else self._output self._end_time = datetime.now() - + # Send to the API using the flusher if not self._flushed_promise: self._flusher.flush_trace(self) - - return self + + return self def to_dict(self) -> Dict[str, Any]: """Convert the trace to a dictionary for API serialization.""" @@ -241,5 +238,5 @@ def to_dict(self) -> Dict[str, Any]: "user": self._user, "organization": self._organization, "metadata": self._metadata, - "logs": self._logs - } \ No newline at end of file + "logs": self._logs, + } diff --git a/basalt/sdk/monitorsdk.py b/basalt/sdk/monitorsdk.py index a410752..53dd36d 100644 --- a/basalt/sdk/monitorsdk.py +++ b/basalt/sdk/monitorsdk.py @@ -8,23 +8,17 @@ from ..objects.log import Log from ..utils.flusher import Flusher + class MonitorSDK: """ SDK for monitoring and tracing in Basalt. """ - def __init__( - self, - api: IApi, - logger: ILogger - ): + + def __init__(self, api: IApi, logger: ILogger): self._api = api self._logger = logger - def create_trace( - self, - slug: str, - params: Optional[Dict[str, Any]] = None - ) -> Trace: + def create_trace(self, slug: str, params: Optional[Dict[str, Any]] = None) -> Trace: """ Creates a new trace for monitoring. @@ -37,14 +31,11 @@ def create_trace( """ if params is None: params = {} - + trace_params = TraceParams(**params) return self._create_trace(slug, trace_params) - def create_generation( - self, - params: Dict[str, Any] - ) -> Generation: + def create_generation(self, params: Dict[str, Any]) -> Generation: """ Creates a new generation for monitoring. @@ -57,10 +48,7 @@ def create_generation( generation_params = GenerationParams(**params) return self._create_generation(generation_params) - def create_log( - self, - params: Dict[str, Any] - ) -> Log: + def create_log(self, params: Dict[str, Any]) -> Log: """ Creates a new log for monitoring. @@ -73,11 +61,7 @@ def create_log( log_params = LogParams(**params) return self._create_log(log_params) - def _create_trace( - self, - slug: str, - params: TraceParams - ) -> Trace: + def _create_trace(self, slug: str, params: TraceParams) -> Trace: """ Internal implementation for creating a trace. @@ -98,15 +82,12 @@ def _create_trace( "end_time": params.end_time, "user": params.user, "organization": params.organization, - "metadata": params.metadata + "metadata": params.metadata, } trace = Trace(slug, params_dict, flusher) return trace - def _create_generation( - self, - params: GenerationParams - ) -> Generation: + def _create_generation(self, params: GenerationParams) -> Generation: """ Internal implementation for creating a generation. @@ -128,14 +109,11 @@ def _create_generation( "metadata": params.metadata, "start_time": params.start_time, "end_time": params.end_time, - "options": params.options + "options": params.options, } return Generation(params_dict) - def _create_log( - self, - params: LogParams - ) -> Log: + def _create_log(self, params: LogParams) -> Log: """ Internal implementation for creating a log. @@ -154,6 +132,6 @@ def _create_log( "parent": params.parent, "metadata": params.metadata, "start_time": params.start_time, - "end_time": params.end_time + "end_time": params.end_time, } - return Log(params_dict) \ No newline at end of file + return Log(params_dict) diff --git a/basalt/sdk/promptsdk.py b/basalt/sdk/promptsdk.py index 5da4cbe..36bf01d 100644 --- a/basalt/sdk/promptsdk.py +++ b/basalt/sdk/promptsdk.py @@ -1,6 +1,15 @@ from typing import Optional, Dict, Tuple, Any -from ..utils.dtos import GetPromptDTO, PromptResponse, DescribePromptResponse, DescribePromptDTO, GetResult, DescribeResult, ListResult, PromptListResponse +from ..utils.dtos import ( + GetPromptDTO, + PromptResponse, + DescribePromptResponse, + DescribePromptDTO, + GetResult, + DescribeResult, + ListResult, + PromptListResponse, +) from ..utils.protocols import ICache, IApi, ILogger from ..endpoints.get_prompt import GetPromptEndpoint @@ -12,17 +21,15 @@ from ..utils.flusher import Flusher from datetime import datetime + class PromptSDK: """ SDK for interacting with Basalt prompts. """ + def __init__( - self, - api: IApi, - cache: ICache, - fallback_cache: ICache, - logger: ILogger - ): + self, api: IApi, cache: ICache, fallback_cache: ICache, logger: ILogger + ): self._api = api self._cache = cache self._fallback_cache = fallback_cache @@ -37,7 +44,7 @@ def get( version: Optional[str] = None, tag: Optional[str] = None, variables: Dict[str, str] = {}, - cache_enabled: bool = True + cache_enabled: bool = True, ) -> Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: """ Retrieve a prompt by slug, optionally specifying version and tag. @@ -50,21 +57,19 @@ def get( cache_enabled (bool): Enable or disable cache for this request. Returns: - Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: + Tuple[Optional[Exception], Optional[PromptResponse], Optional[Generation]]: A tuple containing an optional exception, an optional PromptResponse, and an optional Generation object. """ - dto = GetPromptDTO( - slug=slug, - version=version, - tag=tag - ) + dto = GetPromptDTO(slug=slug, version=version, tag=tag) cached = self._cache.get(dto) if cache_enabled else None if cached: original_prompt_text = cached.text err, prompt_response = self._replace_vars(cached, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + generation = self._prepare_monitoring( + prompt_response, slug, version, tag, variables, original_prompt_text + ) return err, prompt_response, generation err, result = self._api.invoke(GetPromptEndpoint, dto) @@ -75,7 +80,9 @@ def get( self._fallback_cache.put(dto, result.prompt) err, prompt_response = self._replace_vars(result.prompt, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + generation = self._prepare_monitoring( + prompt_response, slug, version, tag, variables, original_prompt_text + ) return err, prompt_response, generation fallback = self._fallback_cache.get(dto) if cache_enabled else None @@ -83,23 +90,25 @@ def get( if fallback: original_prompt_text = fallback.text err, prompt_response = self._replace_vars(fallback, variables) - generation = self._prepare_monitoring(prompt_response, slug, version, tag, variables, original_prompt_text) + generation = self._prepare_monitoring( + prompt_response, slug, version, tag, variables, original_prompt_text + ) return err, prompt_response, generation return err, None, None def _prepare_monitoring( - self, - prompt: PromptResponse, - slug: str, - version: Optional[str] = None, + self, + prompt: PromptResponse, + slug: str, + version: Optional[str] = None, tag: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, - original_prompt_text: Optional[str] = None + original_prompt_text: Optional[str] = None, ) -> Generation: """ Prepare monitoring by creating a trace and generation object. - + Args: prompt (PromptResponse): The prompt response. slug (str): The slug identifier for the prompt. @@ -107,33 +116,35 @@ def _prepare_monitoring( tag (Optional[str]): The tag associated with the prompt. variables (Optional[Dict[str, Any]]): Variables used in the prompt. original_prompt_text (Optional[str]): The original prompt text. - + Returns: Generation: The generation object. """ # Create a flusher flusher = Flusher(self._api, self._logger) - + # Create a trace - trace = Trace(slug, { - "input": original_prompt_text or prompt.text, - "start_time": datetime.now() - }, flusher) - - # Create a generation - generation = Generation({ - "name": slug, - "trace": trace, - "prompt": { - "slug": slug, - "version": version, - "tag": tag + trace = Trace( + slug, + { + "input": original_prompt_text or prompt.text, + "start_time": datetime.now(), }, - "input": original_prompt_text or prompt.text, - "variables": variables, - "options": {"type": "single"} - }) - + flusher, + ) + + # Create a generation + generation = Generation( + { + "name": slug, + "trace": trace, + "prompt": {"slug": slug, "version": version, "tag": tag}, + "input": original_prompt_text or prompt.text, + "variables": variables, + "options": {"type": "single"}, + } + ) + return generation def describe( @@ -154,11 +165,7 @@ def describe( Returns: Tuple[Optional[Exception], Optional[DescribePromptResponse]]: A tuple containing an optional exception and an optional DescribePromptResponse. """ - dto = DescribePromptDTO( - slug=slug, - version=version, - tag=tag - ) + dto = DescribePromptDTO(slug=slug, version=version, tag=tag) err, result = self._api.invoke(DescribePromptEndpoint, dto) @@ -172,7 +179,7 @@ def describe( description=prompt.description, available_versions=prompt.available_versions, available_tags=prompt.available_tags, - variables=prompt.variables + variables=prompt.variables, ) return err, None @@ -183,14 +190,17 @@ def list(self) -> ListResult: if err is not None: return err, None - return None, [PromptListResponse( - slug=prompt.slug, - status=prompt.status, - name=prompt.name, - description=prompt.description, - available_versions=prompt.available_versions, - available_tags=prompt.available_tags - ) for prompt in result.prompts] + return None, [ + PromptListResponse( + slug=prompt.slug, + status=prompt.status, + name=prompt.name, + description=prompt.description, + available_versions=prompt.available_versions, + available_tags=prompt.available_tags, + ) + for prompt in result.prompts + ] def _replace_vars(self, prompt: PromptResponse, variables: Dict[str, str] = {}): missing_vars, replaced = replace_variables(prompt.text, variables) @@ -203,5 +213,5 @@ def _replace_vars(self, prompt: PromptResponse, variables: Dict[str, str] = {}): text=replaced, systemText=prompt.systemText, version=prompt.version, - model=prompt.model - ) \ No newline at end of file + model=prompt.model, + ) diff --git a/basalt/utils/api.py b/basalt/utils/api.py index b19a1e9..c809825 100644 --- a/basalt/utils/api.py +++ b/basalt/utils/api.py @@ -3,13 +3,14 @@ from .protocols import IEndpoint, INetworker, ILogger from .networker import Networker -Input = TypeVar('Input') -Output = TypeVar('Output') +Input = TypeVar("Input") +Output = TypeVar("Output") + class Api: """ A class to interact with the Basalt API. - + Attributes: root_url (str): The root URL of the API. api_key (str): The API key for authentication. @@ -17,10 +18,19 @@ class Api: sdk_type (str): The SDK type (ex: py-pip) networker (INetworker): The networker instance to handle network requests. """ - def __init__(self, root_url: str, networker: INetworker, api_key: str, sdk_version: str, sdk_type: str, logger: Optional[ILogger] = None): + + def __init__( + self, + root_url: str, + networker: INetworker, + api_key: str, + sdk_version: str, + sdk_type: str, + logger: Optional[ILogger] = None, + ): """ Initialize the Api class with the given parameters. - + Args: root_url (str): The root URL of the API. networker (INetworker): The networker instance to handle network requests. @@ -39,9 +49,7 @@ def __init__(self, root_url: str, networker: INetworker, api_key: str, sdk_versi networker._logger = logger def invoke( - self, - endpoint: IEndpoint[Input, Output], - dto: Optional[Input] = None + self, endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None ) -> Tuple[Optional[Exception], Optional[Output]]: """ Invoke an API endpoint with the given data transfer object (DTO). @@ -61,10 +69,10 @@ def invoke( # Fetch the result from the network using the prepared request information error, result = self._network.fetch( - self._root + request_info['path'], - request_info['method'], - request_info.get('body'), - params=request_info.get('query', {}), + self._root + request_info["path"], + request_info["method"], + request_info.get("body"), + params=request_info.get("query", {}), headers=self._headers(), ) @@ -78,8 +86,8 @@ def _headers(self) -> Dict[str, str]: Generate headers for the request including authorization and SDK information. """ return { - 'Authorization': f'Bearer {self._api_key}', - 'X-BASALT-SDK-VERSION': self._sdk_version, - 'X-BASALT-SDK-TYPE': self._sdk_type, - 'Content-Type': 'application/json' + "Authorization": f"Bearer {self._api_key}", + "X-BASALT-SDK-VERSION": self._sdk_version, + "X-BASALT-SDK-TYPE": self._sdk_type, + "Content-Type": "application/json", } diff --git a/basalt/utils/dtos.py b/basalt/utils/dtos.py index d9b2031..c049469 100644 --- a/basalt/utils/dtos.py +++ b/basalt/utils/dtos.py @@ -3,6 +3,7 @@ from .utils import pick_typed, pick_number + # ------------------------------ Get Prompt ----------------------------- # @dataclass class PromptModelParameters: @@ -21,8 +22,12 @@ class PromptModelParameters: def from_dict(cls, data: Dict[str, Any]): return cls( temperature=pick_number(data, "temperature"), - frequency_penalty=pick_number(data, 'frequencyPenalty') if data.get("frequencyPenalty") else None, - presence_penalty=pick_number(data, "presencePenalty") if data.get("presencePenalty") else None, + frequency_penalty=pick_number(data, "frequencyPenalty") + if data.get("frequencyPenalty") + else None, + presence_penalty=pick_number(data, "presencePenalty") + if data.get("presencePenalty") + else None, top_p=pick_number(data, "topP"), top_k=pick_number(data, "topK") if data.get("topK") else None, max_length=data["maxLength"], @@ -30,6 +35,7 @@ def from_dict(cls, data: Dict[str, Any]): json_object=data.get("jsonObject"), ) + @dataclass(frozen=True) class PromptModel: provider: str @@ -46,6 +52,7 @@ def from_dict(cls, data: Dict[str, Any]): parameters=PromptModelParameters.from_dict(data.get("parameters")), ) + @dataclass(frozen=True) class PromptResponse: text: str @@ -62,36 +69,42 @@ def from_dict(cls, data: Dict[str, Any]): version=pick_typed(data, "version", str), ) + @dataclass(frozen=True) class GetPromptDTO: slug: str tag: Optional[str] = None version: Optional[str] = None + GetResult = Tuple[Optional[Exception], Optional[PromptResponse]] + # ------------------------------ Describe Prompt ----------------------------- # @dataclass(frozen=True) class DescribePromptResponse: - slug: str - status: str - name: str - description: str - available_versions: List[str] - available_tags: List[str] - variables: List[Dict[str, str]] - - @classmethod - def from_dict(cls, data: Dict[str, Any]): - return cls( - slug=pick_typed(data, "slug", str) if data.get("slug") else None, - status=pick_typed(data, "status", str), - name=pick_typed(data, "name", str), - description=pick_typed(data, "description", str) if data.get("description") else None, - available_versions=pick_typed(data, "availableVersions", list), - available_tags=pick_typed(data, "availableTags", list), - variables=pick_typed(data, "variables", list), - ) + slug: str + status: str + name: str + description: str + available_versions: List[str] + available_tags: List[str] + variables: List[Dict[str, str]] + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + return cls( + slug=pick_typed(data, "slug", str) if data.get("slug") else None, + status=pick_typed(data, "status", str), + name=pick_typed(data, "name", str), + description=pick_typed(data, "description", str) + if data.get("description") + else None, + available_versions=pick_typed(data, "availableVersions", list), + available_tags=pick_typed(data, "availableTags", list), + variables=pick_typed(data, "variables", list), + ) + @dataclass(frozen=True) class DescribePromptDTO: @@ -99,8 +112,10 @@ class DescribePromptDTO: tag: Optional[str] = None version: Optional[str] = None + DescribeResult = Tuple[Optional[Exception], Optional[DescribePromptResponse]] + # ------------------------------ List Prompts ----------------------------- # @dataclass(frozen=True) class PromptListResponse: @@ -117,17 +132,22 @@ def from_dict(cls, data: Dict[str, Any]): slug=pick_typed(data, "slug", str) if data.get("slug") else None, status=pick_typed(data, "status", str), name=pick_typed(data, "name", str), - description=pick_typed(data, "description", str) if data.get("description") else None, + description=pick_typed(data, "description", str) + if data.get("description") + else None, available_versions=pick_typed(data, "availableVersions", list), available_tags=pick_typed(data, "availableTags", list), ) + ListResult = Tuple[Optional[Exception], Optional[List[PromptListResponse]]] + # ------------------------------ Monitor ----------------------------- # @dataclass class TraceParams: """Parameters for creating a trace.""" + input: Optional[str] = None output: Optional[str] = None name: Optional[str] = None @@ -137,9 +157,11 @@ class TraceParams: organization: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None + @dataclass class GenerationParams: """Parameters for creating a generation.""" + name: str trace: Any prompt: Optional[Dict[str, Any]] = None @@ -152,9 +174,11 @@ class GenerationParams: end_time: Optional[Any] = None options: Optional[Dict[str, Any]] = None + @dataclass class LogParams: """Parameters for creating a log.""" + name: str trace: Any input: Optional[str] = None @@ -164,4 +188,5 @@ class LogParams: start_time: Optional[Any] = None end_time: Optional[Any] = None -MonitorResult = Tuple[Optional[Exception], Optional[Any]] \ No newline at end of file + +MonitorResult = Tuple[Optional[Exception], Optional[Any]] diff --git a/basalt/utils/errors.py b/basalt/utils/errors.py index c5032c1..a93425f 100644 --- a/basalt/utils/errors.py +++ b/basalt/utils/errors.py @@ -2,17 +2,22 @@ class FetchError(Exception): def __init__(self, message: str): self.message = message + class BadRequest(FetchError): pass + class Unauthorized(FetchError): pass + class Forbidden(FetchError): pass + class NotFound(FetchError): pass + class NetworkBaseError(FetchError): pass diff --git a/basalt/utils/flusher.py b/basalt/utils/flusher.py index ca11c85..f49ac10 100644 --- a/basalt/utils/flusher.py +++ b/basalt/utils/flusher.py @@ -8,21 +8,23 @@ from ..endpoints.monitor.send_trace import SendTraceEndpoint + class Flusher: """ Class for flushing traces to the API. """ - def __init__(self, api: 'IApi', logger: 'ILogger'): + + def __init__(self, api: "IApi", logger: "ILogger"): self._api = api self._logger = logger - def _trace_to_dict(self, trace: 'Trace') -> Dict[str, Any]: + def _trace_to_dict(self, trace: "Trace") -> Dict[str, Any]: """ Convert a trace to a dictionary. - + Args: trace (Trace): The trace to convert. - + Returns: Dict[str, Any]: The trace as a dictionary. """ @@ -39,16 +41,18 @@ def _trace_to_dict(self, trace: 'Trace') -> Dict[str, Any]: "user": trace.user, "organization": trace.organization, "metadata": trace.metadata, - "logs": [self._log_to_dict(log) for log in trace.logs] if trace.logs else [] + "logs": [self._log_to_dict(log) for log in trace.logs] + if trace.logs + else [], } def _log_to_dict(self, log: Any) -> Dict[str, Any]: """ Convert a log to a dictionary. - + Args: log (Any): The log to convert. - + Returns: Dict[str, Any]: The log as a dictionary. """ @@ -63,26 +67,36 @@ def _log_to_dict(self, log: Any) -> Dict[str, Any]: "name": log._name, "input": log.input, "output": output, - "start_time": log.start_time.isoformat() if hasattr(log, 'start_time') and log.start_time else None, - "end_time": log.end_time.isoformat() if hasattr(log, 'end_time') and log.end_time else None, - "metadata": log.metadata if hasattr(log, 'metadata') else None, - "parent": {"id": log.parent.id} if hasattr(log, 'parent') and log.parent else None + "start_time": log.start_time.isoformat() + if hasattr(log, "start_time") and log.start_time + else None, + "end_time": log.end_time.isoformat() + if hasattr(log, "end_time") and log.end_time + else None, + "metadata": log.metadata if hasattr(log, "metadata") else None, + "parent": {"id": log.parent.id} + if hasattr(log, "parent") and log.parent + else None, } # Add generation-specific fields if it's a generation if log.type == "generation" and hasattr(log, "prompt"): - base_dict.update({ - "prompt": log.prompt, - "variables": log.variables if log.variables else [], # Ensure variables is always a list - "options": log.options if hasattr(log, "options") else None - }) + base_dict.update( + { + "prompt": log.prompt, + "variables": log.variables + if log.variables + else [], # Ensure variables is always a list + "options": log.options if hasattr(log, "options") else None, + } + ) return base_dict - def flush_trace(self, trace: 'Trace') -> None: + def flush_trace(self, trace: "Trace") -> None: """ Flush a trace to the API. - + Args: trace (Trace): The trace to flush. """ @@ -90,24 +104,26 @@ def flush_trace(self, trace: 'Trace') -> None: if not self._api: self._logger.warn("Cannot flush trace: no API instance available") return - + # Create an endpoint instance endpoint = SendTraceEndpoint() - + # Convert trace to dictionary trace_dict = self._trace_to_dict(trace) - + # Create the DTO with the trace dictionary dto = {"trace": trace_dict} - + # Invoke the API with the endpoint and DTO error, result = self._api.invoke(endpoint, dto) - + if error: self._logger.warn(f"Failed to flush trace: {error}") return - - self._logger.warn(f"Successfully flushed trace {trace.chain_slug} to the API") - + + self._logger.warn( + f"Successfully flushed trace {trace.chain_slug} to the API" + ) + except Exception as e: - self._logger.warn(f"Exception while flushing trace: {str(e)}") \ No newline at end of file + self._logger.warn(f"Exception while flushing trace: {str(e)}") diff --git a/basalt/utils/logger.py b/basalt/utils/logger.py index e49e25d..b6f5133 100644 --- a/basalt/utils/logger.py +++ b/basalt/utils/logger.py @@ -1,7 +1,8 @@ from .protocols import ILogger + class Logger(ILogger): - def __init__(self, log_level: str = 'all'): + def __init__(self, log_level: str = "all"): self._log_level = log_level def warn(self, *args): @@ -13,7 +14,7 @@ def debug(self, *args): print(*args) def _can_warn(self): - return self._log_level in ['all', 'warning', 'debug'] + return self._log_level in ["all", "warning", "debug"] def _can_debug(self): - return self._log_level in ['debug'] + return self._log_level in ["debug"] diff --git a/basalt/utils/memcache.py b/basalt/utils/memcache.py index 4aa7a47..0b66956 100644 --- a/basalt/utils/memcache.py +++ b/basalt/utils/memcache.py @@ -1,6 +1,7 @@ import time from typing import Any, Dict, Hashable + class MemoryCache: """ MemoryCache is a simple in-memory cache that stores values for a given key. @@ -29,7 +30,7 @@ def get(self, key: Hashable): return None - def put(self, key: Hashable, value: Any, ttl: float = float('inf')) -> None: + def put(self, key: Hashable, value: Any, ttl: float = float("inf")) -> None: """ Stores a value in the cache with an associated time-to-live (TTL). diff --git a/basalt/utils/networker.py b/basalt/utils/networker.py index b89e8b1..1386983 100644 --- a/basalt/utils/networker.py +++ b/basalt/utils/networker.py @@ -1,25 +1,29 @@ import requests from typing import Any, Dict, Optional, Tuple -from .errors import BadRequest, FetchError, Forbidden, NetworkBaseError, NotFound, Unauthorized +from .errors import ( + BadRequest, + FetchError, + Forbidden, + NetworkBaseError, + NotFound, + Unauthorized, +) from .protocols import INetworker, ILogger + class Networker(INetworker): """ Networker class that implements the INetworker protocol. Provides a method to fetch data from a given URL using HTTP methods. """ + def __init__(self, logger: Optional[ILogger] = None): self._logger = logger def fetch( - self, - url: str, - method: str, - body = None, - headers = None, - params = None - ) -> Tuple[Optional[FetchError], Optional[Dict[str, Any]]]: + self, url: str, method: str, body=None, headers=None, params=None + ) -> Tuple[Optional[FetchError], Optional[Dict[str, Any]]]: """ Fetch data from a given URL using the specified HTTP method. This method should never throw. @@ -43,11 +47,7 @@ def fetch( self._logger.debug(f"[DEBUG] Body: {body}") response = requests.request( - method, - url, - params=params, - json=body, - headers=headers + method, url, params=params, json=body, headers=headers ) if self._logger: @@ -60,17 +60,17 @@ def fetch( self._logger.debug(f"[DEBUG] Response body: {json_response}") if response.status_code == 400: - return BadRequest(json_response.get('error', 'Bad Request')), None + return BadRequest(json_response.get("error", "Bad Request")), None if response.status_code == 401: - return Unauthorized(json_response.get('error', 'Unauthorized')), None + return Unauthorized(json_response.get("error", "Unauthorized")), None if response.status_code == 403: - return Forbidden(json_response.get('error', 'Forbidden')), None + return Forbidden(json_response.get("error", "Forbidden")), None if response.status_code == 404: - return NotFound(json_response.get('error', 'Not Found')), None - + return NotFound(json_response.get("error", "Not Found")), None + response.raise_for_status() return None, json_response diff --git a/basalt/utils/protocols.py b/basalt/utils/protocols.py index d98c903..851ab05 100644 --- a/basalt/utils/protocols.py +++ b/basalt/utils/protocols.py @@ -1,44 +1,68 @@ from typing import Any, Optional, Protocol, Hashable, Tuple, TypeVar, Dict, Mapping from .dtos import GetResult, DescribeResult, ListResult, MonitorResult -Input = TypeVar('Input') -Output = TypeVar('Output') +Input = TypeVar("Input") +Output = TypeVar("Output") + class ICache(Protocol): def get(self, key: Hashable) -> Optional[Any]: ... def put(self, key: Hashable, value: Any, duration: int) -> None: ... + class IEndpoint(Protocol[Input, Output]): def prepare_request(self, dto: Optional[Input] = None) -> Dict[str, Any]: ... - def decode_response(self, response: Any) -> Tuple[Optional[Exception], Optional[Output]]: ... + def decode_response( + self, response: Any + ) -> Tuple[Optional[Exception], Optional[Output]]: ... + class IApi(Protocol): - def invoke(self,endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None) -> Tuple[Optional[Exception], Optional[Output]]: ... + def invoke( + self, endpoint: IEndpoint[Input, Output], dto: Optional[Input] = None + ) -> Tuple[Optional[Exception], Optional[Output]]: ... + class INetworker(Protocol): - def fetch(self, - url: str, - method: str, - body: Optional[Any] = None, - params: Optional[Mapping[str, str]] = None, - headers: Optional[Mapping[str, str]] = None - ) -> Tuple[Optional[Exception], Optional[Output]]: ... + def fetch( + self, + url: str, + method: str, + body: Optional[Any] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> Tuple[Optional[Exception], Optional[Output]]: ... + class IPromptSDK(Protocol): - def get(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None, variables: Dict[str, str] = {}, cache_enabled: bool = True) -> GetResult: ... - def describe(self, slug: str, tag: Optional[str] = None, version: Optional[str] = None) -> DescribeResult: ... + def get( + self, + slug: str, + tag: Optional[str] = None, + version: Optional[str] = None, + variables: Dict[str, str] = {}, + cache_enabled: bool = True, + ) -> GetResult: ... + def describe( + self, slug: str, tag: Optional[str] = None, version: Optional[str] = None + ) -> DescribeResult: ... def list(self) -> ListResult: ... + class IMonitorSDK(Protocol): - def create_trace(self, slug: str, params: Optional[Dict[str, Any]] = None) -> Any: ... + def create_trace( + self, slug: str, params: Optional[Dict[str, Any]] = None + ) -> Any: ... def create_generation(self, params: Dict[str, Any]) -> Any: ... def create_log(self, params: Dict[str, Any]) -> Any: ... + class IBasaltSDK(Protocol): @property def prompt(self) -> IPromptSDK: ... @property def monitor(self) -> IMonitorSDK: ... + class ILogger: def warn(self, message: str): ... diff --git a/basalt/utils/utils.py b/basalt/utils/utils.py index 51113ca..b1c1a8e 100644 --- a/basalt/utils/utils.py +++ b/basalt/utils/utils.py @@ -1,19 +1,20 @@ from re import sub as reg_replace from typing import Tuple, Set, Dict, Any + def replace_variables(template: str, replacements: dict) -> Tuple[Set[str], str]: missing_keys = set([]) count = lambda key: missing_keys.add(key) get = lambda key: replacements.get(key) if replacements.get(key) else count(key) - replacer = lambda match: str(get(match.group(1))) if get(match.group(1)) else match.group(0) - - replaced = reg_replace( - r'{{(.*?)}}', - replacer, - template + replacer = ( + lambda match: str(get(match.group(1))) + if get(match.group(1)) + else match.group(0) ) + replaced = reg_replace(r"{{(.*?)}}", replacer, template) + return missing_keys, replaced @@ -21,7 +22,9 @@ def pick_typed(dict: Dict[str, Any], field_name: str, expected_type: Any) -> Any value = dict.get(field_name) if not isinstance(value, expected_type): - raise TypeError(f"Field '{field_name}' must be of type {expected_type.__name__}, got {type(value).__name__}.") + raise TypeError( + f"Field '{field_name}' must be of type {expected_type.__name__}, got {type(value).__name__}." + ) # Additional check for int, because isinstance(True, int) == True if expected_type == int and isinstance(value, bool): @@ -29,14 +32,17 @@ def pick_typed(dict: Dict[str, Any], field_name: str, expected_type: Any) -> Any return value + def pick_number(dict: Dict[str, Any], field_name: str) -> float: value = dict.get(field_name) - + if isinstance(value, float): return float(value) # Additional check for int, because isinstance(True, int) == True if isinstance(value, bool) == False and isinstance(value, int): return int(value) - - raise TypeError(f"Field '{field_name}' must be a number (int or float), got {type(value).__name__}.") \ No newline at end of file + + raise TypeError( + f"Field '{field_name}' must be a number (int or float), got {type(value).__name__}." + ) diff --git a/examples/monitor_sdk_demo.ipynb b/examples/monitor_sdk_demo.ipynb index a5d7233..029f997 100644 --- a/examples/monitor_sdk_demo.ipynb +++ b/examples/monitor_sdk_demo.ipynb @@ -17,7 +17,10 @@ "source": [ "import sys\n", "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "\n", + "sys.path.append(\n", + " os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + ") # Needed to make notebook work in VSCode\n", "\n", "os.environ[\"BASALT_BUILD\"] = \"development\"\n", "\n", @@ -25,9 +28,9 @@ "\n", "# Initialize the SDK\n", "basalt = Basalt(\n", - "\tapi_key=\"sk-d4ef...\", # Replace with your API key\n", - "\tlog_level=\"debug\" # Optional: Set log level\n", - ") " + " api_key=\"sk-d4ef...\", # Replace with your API key\n", + " log_level=\"debug\", # Optional: Set log level\n", + ")" ] }, { @@ -52,8 +55,8 @@ " \"input\": \"What are the benefits of AI in healthcare?\",\n", " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"metadata\": {\"source\": \"web\", \"priority\": \"high\"}\n", - " }\n", + " \"metadata\": {\"source\": \"web\", \"priority\": \"high\"},\n", + " },\n", ")\n", "\n", "print(f\"Created trace with input: {trace.input}\")" @@ -75,15 +78,17 @@ "outputs": [], "source": [ "# Create a log for content moderation\n", - "moderation_log = trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"content-moderation\",\n", - " \"input\": trace.input,\n", - " \"metadata\": {\"model\": \"text-moderation-latest\"},\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - "\t\t\"metadata\": {\"source\": \"web\", \"priority\": \"high\"}\n", - "})\n", + "moderation_log = trace.create_log(\n", + " {\n", + " \"type\": \"span\",\n", + " \"name\": \"content-moderation\",\n", + " \"input\": trace.input,\n", + " \"metadata\": {\"model\": \"text-moderation-latest\"},\n", + " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", + " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", + " \"metadata\": {\"source\": \"web\", \"priority\": \"high\"},\n", + " }\n", + ")\n", "\n", "# Simulate moderation check\n", "moderation_result = {\"flagged\": False, \"categories\": [], \"scores\": {}}\n", @@ -111,31 +116,37 @@ "outputs": [], "source": [ "# Create a log for the main processing\n", - "main_log = trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"main-processing\",\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"input\": trace.input\n", - "})\n", + "main_log = trace.create_log(\n", + " {\n", + " \"type\": \"span\",\n", + " \"name\": \"main-processing\",\n", + " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", + " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", + " \"input\": trace.input,\n", + " }\n", + ")\n", "\n", "# Create a generation within the main log found in Basalt\n", - "generation = main_log.create_generation({\n", - " \"name\": \"healthcare-benefits-generation\",\n", - " \"prompt\": {\n", - " \"slug\": \"prompt-slug\", # This tells the SDK to fetch the prompt from Basalt\n", - " \"version\": \"0.1\" # This specifies the version to use\n", - " },\n", - "\t\t\"variables\": {\"variable_example\": \"test variable\"}\n", - "})\n", + "generation = main_log.create_generation(\n", + " {\n", + " \"name\": \"healthcare-benefits-generation\",\n", + " \"prompt\": {\n", + " \"slug\": \"prompt-slug\", # This tells the SDK to fetch the prompt from Basalt\n", + " \"version\": \"0.1\", # This specifies the version to use\n", + " },\n", + " \"variables\": {\"variable_example\": \"test variable\"},\n", + " }\n", + ")\n", "\n", "# Create a generation within the main log not managed in Basalt\n", - "generation = main_log.create_generation({\n", - " \"name\": \"healthcare-benefits-generation\",\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"input\": trace.input\n", - "})\n", + "generation = main_log.create_generation(\n", + " {\n", + " \"name\": \"healthcare-benefits-generation\",\n", + " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", + " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", + " \"input\": trace.input,\n", + " }\n", + ")\n", "\n", "# Simulate AI response\n", "ai_response = \"\"\"\n", @@ -178,45 +189,55 @@ " \"theo-slug\",\n", " {\n", " \"input\": \"Patient presents with frequent headaches and fatigue.\",\n", - "\t\t\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"}\n", - " }\n", + " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", + " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", + " \"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", + " },\n", ")\n", "\n", "# Initial analysis log\n", - "analysis_log = complex_trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"symptom-analysis\",\n", - "\t\t\"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", - " \"input\": complex_trace.input\n", - "})\n", + "analysis_log = complex_trace.create_log(\n", + " {\n", + " \"type\": \"span\",\n", + " \"name\": \"symptom-analysis\",\n", + " \"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", + " \"input\": complex_trace.input,\n", + " }\n", + ")\n", "\n", "# Generate initial analysis\n", - "analysis_gen = analysis_log.create_generation({\n", - " \"name\": \"symptom-classification\",\n", - "\t\t\"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", - " \"prompt\": {\"slug\": \"generate-test-cases\", \"version\": \"0.1\"},\n", - "\t\t\"variables\": {\"variable_example\": \"test variable\"}\n", - "})\n", - "analysis_gen.end(\"Primary symptoms suggest possible migraine or chronic fatigue syndrome\")\n", + "analysis_gen = analysis_log.create_generation(\n", + " {\n", + " \"name\": \"symptom-classification\",\n", + " \"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", + " \"prompt\": {\"slug\": \"generate-test-cases\", \"version\": \"0.1\"},\n", + " \"variables\": {\"variable_example\": \"test variable\"},\n", + " }\n", + ")\n", + "analysis_gen.end(\n", + " \"Primary symptoms suggest possible migraine or chronic fatigue syndrome\"\n", + ")\n", "\n", "# Create a nested log for recommendations\n", - "recommendations_log = analysis_log.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"treatment-recommendations\",\n", - "\t\t\"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", - "\t\t\"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", - "\t\t\"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", - " \"input\": analysis_gen.output\n", - "})\n", + "recommendations_log = analysis_log.create_log(\n", + " {\n", + " \"type\": \"span\",\n", + " \"name\": \"treatment-recommendations\",\n", + " \"metadata\": {\"department\": \"neurology\", \"priority\": \"high\"},\n", + " \"user\": {\"id\": \"user123\", \"name\": \"John Doe\"},\n", + " \"organization\": {\"id\": \"org123\", \"name\": \"Healthcare Inc\"},\n", + " \"input\": analysis_gen.output,\n", + " }\n", + ")\n", "\n", "# Generate treatment recommendations\n", - "treatment_gen = recommendations_log.create_generation({\n", - " \"name\": \"treatment-suggestions\",\n", - " \"prompt\": {\"slug\": \"generate-test-cases\", \"version\": \"0.1\"},\n", - "\t\t\"variables\": {\"variable_example\": \"test variable\"}\n", - "})\n", + "treatment_gen = recommendations_log.create_generation(\n", + " {\n", + " \"name\": \"treatment-suggestions\",\n", + " \"prompt\": {\"slug\": \"generate-test-cases\", \"version\": \"0.1\"},\n", + " \"variables\": {\"variable_example\": \"test variable\"},\n", + " }\n", + ")\n", "\n", "treatment_response = \"\"\"\n", "Recommended treatments:\n", diff --git a/examples/prompt_sdk_demo.ipynb b/examples/prompt_sdk_demo.ipynb index 824ee3a..bcd201d 100644 --- a/examples/prompt_sdk_demo.ipynb +++ b/examples/prompt_sdk_demo.ipynb @@ -17,7 +17,10 @@ "source": [ "import sys\n", "import os\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "\n", + "sys.path.append(\n", + " os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + ") # Needed to make notebook work in VSCode\n", "\n", "os.environ[\"BASALT_BUILD\"] = \"development\"\n", "\n", @@ -25,8 +28,8 @@ "\n", "# Initialize the SDK\n", "basalt = Basalt(\n", - "\tapi_key=\"sk-d4ef...\", # Replace with your API key\n", - "\tlog_level=\"debug\" # Optional: Set log level\n", + " api_key=\"sk-d4ef...\", # Replace with your API key\n", + " log_level=\"debug\", # Optional: Set log level\n", ")" ] }, @@ -46,7 +49,7 @@ "outputs": [], "source": [ "# Get a prompt by slug (default is production version)\n", - "error, result = basalt.prompt.get('prompt-slug')\n", + "error, result = basalt.prompt.get(\"prompt-slug\")\n", "\n", "if error:\n", " print(f\"Error fetching prompt: {error}\")\n", @@ -70,20 +73,24 @@ "outputs": [], "source": [ "# Get a prompt with a specific tag\n", - "error, result_tag = basalt.prompt.get(slug='prompt-slug', tag='latest')\n", + "error, result_tag = basalt.prompt.get(slug=\"prompt-slug\", tag=\"latest\")\n", "\n", "if error:\n", " print(f\"Error fetching prompt with tag: {error}\")\n", "else:\n", - " print(f\"Successfully fetched prompt with tag 'latest': {result_tag.prompt[:100]}...\")\n", + " print(\n", + " f\"Successfully fetched prompt with tag 'latest': {result_tag.prompt[:100]}...\"\n", + " )\n", "\n", "# Get a prompt with a specific version\n", - "error, result_version = basalt.prompt.get(slug='prompt-slug', version='1.0.0')\n", + "error, result_version = basalt.prompt.get(slug=\"prompt-slug\", version=\"1.0.0\")\n", "\n", "if error:\n", " print(f\"Error fetching prompt with version: {error}\")\n", "else:\n", - " print(f\"Successfully fetched prompt with version '1.0.0': {result_version.prompt[:100]}...\")" + " print(\n", + " f\"Successfully fetched prompt with version '1.0.0': {result_version.prompt[:100]}...\"\n", + " )" ] }, { @@ -103,12 +110,8 @@ "source": [ "# Get a prompt with variables\n", "error, result_vars = basalt.prompt.get(\n", - " slug='prompt-slug-with-vars', \n", - " variables={\n", - " 'name': 'John Doe',\n", - " 'role': 'Developer',\n", - " 'company': 'Acme Inc'\n", - " }\n", + " slug=\"prompt-slug-with-vars\",\n", + " variables={\"name\": \"John Doe\", \"role\": \"Developer\", \"company\": \"Acme Inc\"},\n", ")\n", "\n", "if error:\n", @@ -135,13 +138,15 @@ "# Example with OpenAI (you'll need to install the openai package)\n", "try:\n", " import openai\n", - " \n", + "\n", " # Set up OpenAI client\n", - " client = openai.OpenAI(api_key=\"your-openai-api-key\") # Replace with your OpenAI API key\n", - " \n", + " client = openai.OpenAI(\n", + " api_key=\"your-openai-api-key\"\n", + " ) # Replace with your OpenAI API key\n", + "\n", " # Get a prompt from Basalt\n", - " error, result = basalt.prompt.get('prompt-slug')\n", - " \n", + " error, result = basalt.prompt.get(\"prompt-slug\")\n", + "\n", " if error:\n", " print(f\"Error fetching prompt: {error}\")\n", " else:\n", @@ -150,10 +155,10 @@ " model=\"gpt-4\",\n", " messages=[\n", " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", - " {\"role\": \"user\", \"content\": result.prompt}\n", - " ]\n", + " {\"role\": \"user\", \"content\": result.prompt},\n", + " ],\n", " )\n", - " \n", + "\n", " print(f\"OpenAI Response: {response.choices[0].message.content}\")\n", "except ImportError:\n", " print(\"OpenAI package not installed. Install with: pip install openai\")" @@ -178,12 +183,9 @@ "def get_prompt_safely(slug, tag=None, version=None, variables=None):\n", " try:\n", " error, result = basalt.prompt.get(\n", - " slug=slug,\n", - " tag=tag,\n", - " version=version,\n", - " variables=variables\n", + " slug=slug, tag=tag, version=version, variables=variables\n", " )\n", - " \n", + "\n", " if error:\n", " print(f\"Error fetching prompt '{slug}': {error}\")\n", " return None\n", @@ -192,6 +194,7 @@ " print(f\"Unexpected error: {str(e)}\")\n", " return None\n", "\n", + "\n", "# Test with a non-existent prompt\n", "prompt_text = get_prompt_safely(\"non-existent-prompt\")\n", "print(f\"Result: {prompt_text}\")\n", @@ -223,31 +226,27 @@ " {\n", " \"input\": \"Tell me about artificial intelligence\",\n", " \"user\": {\"id\": \"user123\", \"name\": \"Jane Smith\"},\n", - " \"metadata\": {\"source\": \"web\"}\n", - " }\n", + " \"metadata\": {\"source\": \"web\"},\n", + " },\n", ")\n", "\n", "# Create a log for processing\n", - "processing_log = trace.create_log({\n", - " \"type\": \"span\",\n", - " \"name\": \"ai-response-generation\",\n", - " \"input\": trace.input\n", - "})\n", + "processing_log = trace.create_log(\n", + " {\"type\": \"span\", \"name\": \"ai-response-generation\", \"input\": trace.input}\n", + ")\n", "\n", "# Get a prompt from Basalt\n", - "error, prompt_result = basalt.prompt.get('ai-explanation-prompt')\n", + "error, prompt_result = basalt.prompt.get(\"ai-explanation-prompt\")\n", "\n", "if error:\n", " processing_log.end({\"error\": str(error)})\n", " trace.end({\"status\": \"error\", \"message\": f\"Failed to get prompt: {error}\"})\n", "else:\n", " # Create a generation using the retrieved prompt\n", - " generation = processing_log.create_generation({\n", - " \"name\": \"ai-explanation\",\n", - " \"input\": trace.input,\n", - " \"prompt\": prompt_result.prompt\n", - " })\n", - " \n", + " generation = processing_log.create_generation(\n", + " {\"name\": \"ai-explanation\", \"input\": trace.input, \"prompt\": prompt_result.prompt}\n", + " )\n", + "\n", " # Simulate AI response\n", " ai_response = \"\"\"\n", " Artificial Intelligence (AI) refers to systems or machines that mimic human intelligence\n", @@ -259,14 +258,14 @@ " 3. Computer Vision\n", " 4. Robotics\n", " \"\"\"\n", - " \n", + "\n", " # End the generation\n", " generation.end(ai_response)\n", - " \n", + "\n", " # End the log and trace\n", " processing_log.end({\"status\": \"success\", \"output\": ai_response})\n", " trace.end({\"status\": \"success\"})\n", - " \n", + "\n", " print(f\"Generated response using Basalt prompt:\\n{ai_response}\")" ] } diff --git a/setup.py b/setup.py index 83d073f..ed2e5f2 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() + def get_version(): version_file = "basalt/_version.py" with open(version_file) as f: @@ -13,20 +14,21 @@ def get_version(): return line.split(delim)[1] raise RuntimeError("Unable to find version string.") + setup( - name="basalt_sdk", - version=get_version(), - description="Basalt SDK for python", - long_description=long_description, - long_description_content_type='text/markdown', - license="MIT", - keywords="basalt, ai, sdk, python", - author="Basalt", - author_email="support@getbasalt.ai", - url="https://github.com/basalt-ai/basalt-python", - packages=find_packages(), - install_requires=[ - "requests>=2.32", - ], - python_requires=">=3.6" + name="basalt_sdk", + version=get_version(), + description="Basalt SDK for python", + long_description=long_description, + long_description_content_type="text/markdown", + license="MIT", + keywords="basalt, ai, sdk, python", + author="Basalt", + author_email="support@getbasalt.ai", + url="https://github.com/basalt-ai/basalt-python", + packages=find_packages(), + install_requires=[ + "requests>=2.32", + ], + python_requires=">=3.6", ) diff --git a/tests/test_api.py b/tests/test_api.py index 12b16f7..e6c67cc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,225 +4,230 @@ from basalt.utils.api import Api -class TestApi(unittest.TestCase): - def test_uses_endpoint_to_encode_request(self): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, {}) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/some-path", - "method": "GET", - "query": { "tag": "abc" } - } - mocked_endpoint.decode_response.return_value = (None, {}) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - api.invoke(mocked_endpoint, { "some": "dto" }) - - mocked_endpoint.prepare_request.assert_called_once_with({ "some": "dto" }) - - def test_uses_endpoint_to_decode_response(self): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/some-path", - "method": "GET", - "query": { "tag": "abc" } - } - mocked_endpoint.decode_response.return_value = (None, { "decoded": "response" }) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - err, res = api.invoke(mocked_endpoint, { "some": "dto" }) - - mocked_endpoint.decode_response.assert_called_once_with({ "some": "response" }) - - self.assertIsNone(err) - self.assertEqual(res, { "decoded": "response" }) - - def test_forwards_decoder_error(self): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/some-path", - "method": "GET", - "query": { "tag": "abc" } - } - mocked_endpoint.decode_response.return_value = (Exception("Bad response format"), None) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - err, res = api.invoke(mocked_endpoint, { "some": "dto" }) - - mocked_endpoint.decode_response.assert_called_once_with({ "some": "response" }) - - self.assertIsNone(res) - self.assertIsInstance(err, Exception) - self.assertEqual(str(err), "Bad response format") - - @parameterized.expand(["GET", "POST", "PUT", "DELETE"]) - def test_uses_http_verb(self, http_verb): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/", - "method": http_verb, - "query": {} - } - mocked_endpoint.decode_response.return_value = (None, { "some": "response" }) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - api.invoke(mocked_endpoint, { "some": "dto" }) - - call_args = mocked_network.fetch.call_args[0] - - self.assertEqual(call_args[1], http_verb) - - def test_prefixes_api_root(self): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/test-path", - "method": "GET", - "query": {} - } - mocked_endpoint.decode_response.return_value = (None, { "some": "response" }) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - api.invoke(mocked_endpoint, { "some": "dto" }) - - call_args = mocked_network.fetch.call_args[0] - - self.assertTrue(call_args[0].startswith("https://basalt-test/")) - - def test_includes_path_in_url(self): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/test-path", - "method": "GET", - "query": {} - } - mocked_endpoint.decode_response.return_value = (None, { "some": "response" }) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - api.invoke(mocked_endpoint, { "some": "dto" }) - - call_args = mocked_network.fetch.call_args[0] - - self.assertIn("/test-path", call_args[0]) - - @parameterized.expand([ - (None), - ({ "tag": "abc" }), - ]) - def test_includes_path_in_url(self, params): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/test-path", - "method": "GET", - "query": params - } - mocked_endpoint.decode_response.return_value = (None, { "some": "response" }) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="1.0.0", - sdk_type="py-test" - ) - - api.invoke(mocked_endpoint, { "some": "dto" }) - - call_args = mocked_network.fetch.call_args - - self.assertEqual(call_args.kwargs["params"], params) - - def test_passes_headers_to_network(self): - mocked_network = MagicMock() - mocked_network.fetch.return_value = (None, { "some": "response" }) - - mocked_endpoint = MagicMock() - mocked_endpoint.prepare_request.return_value = { - "path": "/test-path", - "method": "GET", - "query": {} - } - mocked_endpoint.decode_response.return_value = (None, { "some": "response" }) - - api = Api( - networker=mocked_network, - root_url="https://basalt-test/", - api_key="my-api-key", - sdk_version="test-sdk-version", - sdk_type="test-sdk-type" - ) - - api.invoke(mocked_endpoint, { "some": "dto" }) - - headers = mocked_network.fetch.call_args.kwargs["headers"] - - self.assertIn("Authorization", headers) - self.assertIn("my-api-key", headers["Authorization"]) - - self.assertIn("X-BASALT-SDK-VERSION", headers) - self.assertIn("test-sdk-version", headers["X-BASALT-SDK-VERSION"]) - - self.assertIn("X-BASALT-SDK-TYPE", headers) - self.assertIn("test-sdk-type", headers["X-BASALT-SDK-TYPE"]) \ No newline at end of file +class TestApi(unittest.TestCase): + def test_uses_endpoint_to_encode_request(self): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/some-path", + "method": "GET", + "query": {"tag": "abc"}, + } + mocked_endpoint.decode_response.return_value = (None, {}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + api.invoke(mocked_endpoint, {"some": "dto"}) + + mocked_endpoint.prepare_request.assert_called_once_with({"some": "dto"}) + + def test_uses_endpoint_to_decode_response(self): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/some-path", + "method": "GET", + "query": {"tag": "abc"}, + } + mocked_endpoint.decode_response.return_value = (None, {"decoded": "response"}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + err, res = api.invoke(mocked_endpoint, {"some": "dto"}) + + mocked_endpoint.decode_response.assert_called_once_with({"some": "response"}) + + self.assertIsNone(err) + self.assertEqual(res, {"decoded": "response"}) + + def test_forwards_decoder_error(self): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/some-path", + "method": "GET", + "query": {"tag": "abc"}, + } + mocked_endpoint.decode_response.return_value = ( + Exception("Bad response format"), + None, + ) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + err, res = api.invoke(mocked_endpoint, {"some": "dto"}) + + mocked_endpoint.decode_response.assert_called_once_with({"some": "response"}) + + self.assertIsNone(res) + self.assertIsInstance(err, Exception) + self.assertEqual(str(err), "Bad response format") + + @parameterized.expand(["GET", "POST", "PUT", "DELETE"]) + def test_uses_http_verb(self, http_verb): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/", + "method": http_verb, + "query": {}, + } + mocked_endpoint.decode_response.return_value = (None, {"some": "response"}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + api.invoke(mocked_endpoint, {"some": "dto"}) + + call_args = mocked_network.fetch.call_args[0] + + self.assertEqual(call_args[1], http_verb) + + def test_prefixes_api_root(self): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/test-path", + "method": "GET", + "query": {}, + } + mocked_endpoint.decode_response.return_value = (None, {"some": "response"}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + api.invoke(mocked_endpoint, {"some": "dto"}) + + call_args = mocked_network.fetch.call_args[0] + + self.assertTrue(call_args[0].startswith("https://basalt-test/")) + + def test_includes_path_in_url(self): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/test-path", + "method": "GET", + "query": {}, + } + mocked_endpoint.decode_response.return_value = (None, {"some": "response"}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + api.invoke(mocked_endpoint, {"some": "dto"}) + + call_args = mocked_network.fetch.call_args[0] + + self.assertIn("/test-path", call_args[0]) + + @parameterized.expand( + [ + (None), + ({"tag": "abc"}), + ] + ) + def test_includes_path_in_url(self, params): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/test-path", + "method": "GET", + "query": params, + } + mocked_endpoint.decode_response.return_value = (None, {"some": "response"}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="1.0.0", + sdk_type="py-test", + ) + + api.invoke(mocked_endpoint, {"some": "dto"}) + + call_args = mocked_network.fetch.call_args + + self.assertEqual(call_args.kwargs["params"], params) + + def test_passes_headers_to_network(self): + mocked_network = MagicMock() + mocked_network.fetch.return_value = (None, {"some": "response"}) + + mocked_endpoint = MagicMock() + mocked_endpoint.prepare_request.return_value = { + "path": "/test-path", + "method": "GET", + "query": {}, + } + mocked_endpoint.decode_response.return_value = (None, {"some": "response"}) + + api = Api( + networker=mocked_network, + root_url="https://basalt-test/", + api_key="my-api-key", + sdk_version="test-sdk-version", + sdk_type="test-sdk-type", + ) + + api.invoke(mocked_endpoint, {"some": "dto"}) + + headers = mocked_network.fetch.call_args.kwargs["headers"] + + self.assertIn("Authorization", headers) + self.assertIn("my-api-key", headers["Authorization"]) + + self.assertIn("X-BASALT-SDK-VERSION", headers) + self.assertIn("test-sdk-version", headers["X-BASALT-SDK-VERSION"]) + + self.assertIn("X-BASALT-SDK-TYPE", headers) + self.assertIn("test-sdk-type", headers["X-BASALT-SDK-TYPE"]) diff --git a/tests/test_get_prompt_endpoint.py b/tests/test_get_prompt_endpoint.py index e0f1350..72eeb03 100644 --- a/tests/test_get_prompt_endpoint.py +++ b/tests/test_get_prompt_endpoint.py @@ -4,49 +4,52 @@ class TestGetPromptEndpoint(unittest.TestCase): - - def test_includes_slug_in_path(self): - result = GetPromptEndpoint.prepare_request(GetPromptDTO(slug="my-complex-slug-that-should-be-unique")) - - self.assertIn("my-complex-slug-that-should-be-unique", result["path"]) - - def test_includes_tags_as_queryparam(self): - result = GetPromptEndpoint.prepare_request(GetPromptDTO(slug="slug", tag="abc")) - - self.assertEqual(result["query"].get("tag"), "abc") - - def test_includes_version_as_queryparam(self): - result = GetPromptEndpoint.prepare_request(GetPromptDTO(slug="slug", version="2.0")) - - self.assertEqual(result["query"].get("version"), "2.0") - - def test_decodes_valid_response(self): - response = { - "warning": "This is a warning", - "prompt": { - "text": "Valid prompt text", - "systemText": "Some system prompt", - "version": "0.1", - "model": { - "provider": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - } - } - } - - exception, decoded = GetPromptEndpoint.decode_response(response) - - self.assertIsNone(exception) - self.assertEqual(decoded.warning, "This is a warning") - self.assertEqual(decoded.prompt.text, "Valid prompt text") - self.assertEqual(decoded.prompt.systemText, "Some system prompt") - self.assertEqual(decoded.prompt.version, "0.1") - self.assertEqual(decoded.prompt.model.model, "gpt-4o") - self.assertEqual(decoded.prompt.model.provider, "open-ai") + def test_includes_slug_in_path(self): + result = GetPromptEndpoint.prepare_request( + GetPromptDTO(slug="my-complex-slug-that-should-be-unique") + ) + + self.assertIn("my-complex-slug-that-should-be-unique", result["path"]) + + def test_includes_tags_as_queryparam(self): + result = GetPromptEndpoint.prepare_request(GetPromptDTO(slug="slug", tag="abc")) + + self.assertEqual(result["query"].get("tag"), "abc") + + def test_includes_version_as_queryparam(self): + result = GetPromptEndpoint.prepare_request( + GetPromptDTO(slug="slug", version="2.0") + ) + + self.assertEqual(result["query"].get("version"), "2.0") + + def test_decodes_valid_response(self): + response = { + "warning": "This is a warning", + "prompt": { + "text": "Valid prompt text", + "systemText": "Some system prompt", + "version": "0.1", + "model": { + "provider": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + }, + }, + } + + exception, decoded = GetPromptEndpoint.decode_response(response) + + self.assertIsNone(exception) + self.assertEqual(decoded.warning, "This is a warning") + self.assertEqual(decoded.prompt.text, "Valid prompt text") + self.assertEqual(decoded.prompt.systemText, "Some system prompt") + self.assertEqual(decoded.prompt.version, "0.1") + self.assertEqual(decoded.prompt.model.model, "gpt-4o") + self.assertEqual(decoded.prompt.model.provider, "open-ai") diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 34e76bc..654f8c2 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -7,40 +7,44 @@ time_mock = MagicMock() + # A dataclass is always hashable @dataclass(frozen=True) class SomeHashableDTO: - name: str - age: int + name: str + age: int + class TestMemCache(unittest.TestCase): - @parameterized.expand([ - ("key", "value"), - ("key2", 123), - ("key3", { "bar": 1 }), - ("key4", ["a", "b", "c"]), - ]) - def test_can_put_with_str_key(self, key, val): - cache = MemoryCache() - cache.put(key, val) + @parameterized.expand( + [ + ("key", "value"), + ("key2", 123), + ("key3", {"bar": 1}), + ("key4", ["a", "b", "c"]), + ] + ) + def test_can_put_with_str_key(self, key, val): + cache = MemoryCache() + cache.put(key, val) - self.assertEqual(cache.get(key), val) + self.assertEqual(cache.get(key), val) - def test_can_use_hashable_as_key(self): - key = SomeHashableDTO(name="John", age=30) + def test_can_use_hashable_as_key(self): + key = SomeHashableDTO(name="John", age=30) - cache = MemoryCache() - cache.put(key, 1) + cache = MemoryCache() + cache.put(key, 1) - self.assertEqual(cache.get(key), 1) + self.assertEqual(cache.get(key), 1) - @patch("time.time") - def test_cache_times_out_after_ttl(self, time_mock): - time_mock.return_value = 0.0 + @patch("time.time") + def test_cache_times_out_after_ttl(self, time_mock): + time_mock.return_value = 0.0 - cache = MemoryCache() - cache.put("abc123", "value", ttl=200.0) + cache = MemoryCache() + cache.put("abc123", "value", ttl=200.0) - time_mock.return_value = 201.0 + time_mock.return_value = 201.0 - self.assertIsNone(cache.get("abc123")) \ No newline at end of file + self.assertIsNone(cache.get("abc123")) diff --git a/tests/test_monitor_sdk.py b/tests/test_monitor_sdk.py index 60f5f9f..f580fa4 100644 --- a/tests/test_monitor_sdk.py +++ b/tests/test_monitor_sdk.py @@ -10,43 +10,46 @@ from basalt.objects.generation import Generation from basalt.objects.log import Log + # Mock classes for testing class MockOpenAI: """Mock OpenAI client for demonstration purposes.""" - + def generate_text(self, prompt: str) -> str: """Generate text using a mock OpenAI.""" return f"Generated response for: {prompt[:50]}..." - + def classify_content(self, content: str) -> str: """Classify content using a mock OpenAI.""" return "Classification: Technology, Healthcare, AI" - + def translate_text(self, text: str) -> str: """Translate text using a mock OpenAI.""" return "Traducción: Este es un texto traducido al español." - + def summarize_text(self, text: str) -> str: """Summarize text using a mock OpenAI.""" return "Summary: This is a concise summary of the provided content." + # Setup common test objects logger = Logger() mocked_api = MagicMock() mocked_api.invoke.return_value = (None, None) # Default return value + class TestMonitorSDK(unittest.TestCase): """Test cases for the MonitorSDK class.""" - + def setUp(self): """Set up test fixtures before each test method.""" self.monitor = MonitorSDK(mocked_api, logger) self.openai = MockOpenAI() - + # Common test data self.user = {"id": "user123", "name": "John Doe"} self.content = "Create a technical article about machine learning applications in healthcare" - + def test_create_trace(self): """Test creating a trace.""" trace = self.monitor.create_trace( @@ -56,10 +59,10 @@ def test_create_trace(self): "user": self.user, "organization": {"id": "org-123", "name": "Basalt"}, "metadata": {"property1": "value1", "property2": "value2"}, - "name": "Test Trace" - } + "name": "Test Trace", + }, ) - + # Assert trace was created correctly self.assertIsNotNone(trace) self.assertIsInstance(trace, Trace) @@ -68,18 +71,20 @@ def test_create_trace(self): self.assertEqual(trace.organization, {"id": "org-123", "name": "Basalt"}) self.assertEqual(trace.metadata, {"property1": "value1", "property2": "value2"}) self.assertEqual(trace.chain_slug, "test-slug") - + def test_create_log(self): """Test creating a log within a trace.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - log = trace.create_log({ - "type": "span", - "name": "test-log", - "input": self.content, - "metadata": {"property1": "value1"} - }) - + + log = trace.create_log( + { + "type": "span", + "name": "test-log", + "input": self.content, + "metadata": {"property1": "value1"}, + } + ) + # Assert log was created correctly self.assertIsNotNone(log) self.assertIsInstance(log, Log) @@ -87,27 +92,27 @@ def test_create_log(self): self.assertEqual(log.name, "test-log") self.assertEqual(log.metadata, {"property1": "value1"}) self.assertEqual(log.trace, trace) - + # Assert log is in trace logs self.assertIn(log, trace.logs) - + def test_create_generation(self): """Test creating a generation within a log.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - log = trace.create_log({ - "type": "span", - "name": "test-log", - "input": self.content - }) - - generation = log.create_generation({ - "name": "test-generation", - "input": self.content, - "prompt": {"slug": "test-prompt", "version": "1.0"}, - "variables": [{"label": "var1", "value": "value1"}] - }) - + + log = trace.create_log( + {"type": "span", "name": "test-log", "input": self.content} + ) + + generation = log.create_generation( + { + "name": "test-generation", + "input": self.content, + "prompt": {"slug": "test-prompt", "version": "1.0"}, + "variables": [{"label": "var1", "value": "value1"}], + } + ) + # Assert generation was created correctly self.assertIsNotNone(generation) self.assertIsInstance(generation, Generation) @@ -116,134 +121,125 @@ def test_create_generation(self): self.assertEqual(generation.prompt, {"slug": "test-prompt", "version": "1.0"}) self.assertEqual(generation.variables, [{"label": "var1", "value": "value1"}]) self.assertEqual(generation.trace, trace) - + def test_update_log(self): """Test updating a log.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - log = trace.create_log({ - "type": "span", - "name": "test-log", - "input": self.content - }) - + + log = trace.create_log( + {"type": "span", "name": "test-log", "input": self.content} + ) + # Update the log - log.update({ - "metadata": {"updated": True, "timestamp": "2023-01-01"}, - "output": "Updated output" - }) - + log.update( + { + "metadata": {"updated": True, "timestamp": "2023-01-01"}, + "output": "Updated output", + } + ) + # Assert log was updated correctly self.assertEqual(log.output, "Updated output") self.assertEqual(log.metadata, {"updated": True, "timestamp": "2023-01-01"}) - + def test_update_generation(self): """Test updating a generation.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - log = trace.create_log({ - "type": "span", - "name": "test-log", - "input": self.content - }) - - generation = log.create_generation({ - "name": "test-generation", - "input": self.content - }) - + + log = trace.create_log( + {"type": "span", "name": "test-log", "input": self.content} + ) + + generation = log.create_generation( + {"name": "test-generation", "input": self.content} + ) + # Update the generation - generation.update({ - "metadata": {"updated": True, "timestamp": "2023-01-01"}, - "output": "Updated output", - "prompt": {"slug": "updated-prompt"} - }) - + generation.update( + { + "metadata": {"updated": True, "timestamp": "2023-01-01"}, + "output": "Updated output", + "prompt": {"slug": "updated-prompt"}, + } + ) + # Assert generation was updated correctly self.assertEqual(generation.output, "Updated output") - self.assertEqual(generation.metadata, {"updated": True, "timestamp": "2023-01-01"}) + self.assertEqual( + generation.metadata, {"updated": True, "timestamp": "2023-01-01"} + ) self.assertEqual(generation.prompt, {"slug": "updated-prompt"}) - + def test_end_log(self): """Test ending a log.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - log = trace.create_log({ - "type": "span", - "name": "test-log", - "input": self.content - }) - + + log = trace.create_log( + {"type": "span", "name": "test-log", "input": self.content} + ) + # End the log log.end("Log output") - + # Assert log was ended correctly self.assertEqual(log.output, "Log output") self.assertIsNotNone(log.end_time) - + def test_end_generation(self): """Test ending a generation.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - log = trace.create_log({ - "type": "span", - "name": "test-log", - "input": self.content - }) - - generation = log.create_generation({ - "name": "test-generation", - "input": self.content - }) - + + log = trace.create_log( + {"type": "span", "name": "test-log", "input": self.content} + ) + + generation = log.create_generation( + {"name": "test-generation", "input": self.content} + ) + # End the generation generation.end("Generation output") - + # Assert generation was ended correctly self.assertEqual(generation.output, "Generation output") self.assertIsNotNone(generation.end_time) - + def test_end_trace(self): """Test ending a trace.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - + # End the trace trace.end("Trace output") - + # Assert trace was ended correctly self.assertEqual(trace.output, "Trace output") self.assertIsNotNone(trace.end_time) - + def test_nested_logs(self): """Test creating nested logs.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - - parent_log = trace.create_log({ - "type": "span", - "name": "parent-log", - "input": self.content - }) - - child_log = parent_log.create_log({ - "type": "span", - "name": "child-log", - "input": "Child input" - }) - + + parent_log = trace.create_log( + {"type": "span", "name": "parent-log", "input": self.content} + ) + + child_log = parent_log.create_log( + {"type": "span", "name": "child-log", "input": "Child input"} + ) + # Assert parent-child relationship self.assertEqual(child_log.parent, parent_log) self.assertEqual(child_log.trace, trace) - + def test_trace_identify(self): """Test identifying a trace with user and organization.""" trace = self.monitor.create_trace("test-slug", {"input": self.content}) - + # Identify the trace - trace.identify({ - "user": self.user, - "organization": {"id": "org-123", "name": "Basalt"} - }) - + trace.identify( + {"user": self.user, "organization": {"id": "org-123", "name": "Basalt"}} + ) + # Assert trace was identified correctly self.assertEqual(trace.user, self.user) self.assertEqual(trace.organization, {"id": "org-123", "name": "Basalt"}) @@ -251,16 +247,18 @@ def test_trace_identify(self): class TestMonitorSDKIntegration(unittest.TestCase): """Integration tests for the MonitorSDK with PromptSDK.""" - + def setUp(self): """Set up test fixtures before each test method.""" self.monitor = MonitorSDK(mocked_api, logger) self.openai = MockOpenAI() - + # Common test data self.user = {"id": "user456", "name": "Jane Smith"} - self.query = "What are the best practices for machine learning model deployment?" - + self.query = ( + "What are the best practices for machine learning model deployment?" + ) + def test_prompt_generation_integration(self): """Test the integration between PromptSDK and MonitorSDK.""" # Create a main trace @@ -271,18 +269,20 @@ def test_prompt_generation_integration(self): "user": self.user, "organization": {"id": "org-456", "name": "Basalt Testing"}, "metadata": {"source": "test", "environment": "development"}, - "name": "Prompt Generation Test" - } + "name": "Prompt Generation Test", + }, ) - + # Create a span for the prompt generation - prompt_span = main_trace.create_log({ - "type": "span", - "name": "prompt-retrieval", - "input": self.query, - "metadata": {"action": "retrieve-prompt"} - }) - + prompt_span = main_trace.create_log( + { + "type": "span", + "name": "prompt-retrieval", + "input": self.query, + "metadata": {"action": "retrieve-prompt"}, + } + ) + # Mock the API response for get_prompt mock_prompt_response = GetPromptEndpointResponse( warning=None, @@ -298,104 +298,106 @@ def test_prompt_generation_integration(self): "temperature": 0.7, "topP": 1, "maxLength": 4096, - "responseFormat": "text" - } - ) - ) + "responseFormat": "text", + }, + ), + ), ) - + # Create a mock API for PromptSDK prompt_api = MagicMock() prompt_api.invoke.return_value = (None, mock_prompt_response) - + # Create a PromptSDK instance from basalt.utils.memcache import MemoryCache - + # Create a PromptSDK instance prompt_sdk = PromptSDK( api=prompt_api, cache=MemoryCache(), fallback_cache=MemoryCache(), - logger=logger + logger=logger, ) - + # Get prompt from Basalt err, prompt_response, generation = prompt_sdk.get( - "ml-best-practices", + "ml-best-practices", variables={"topic": "machine learning", "question": self.query}, - version="1.0" + version="1.0", ) - + # Verify the prompt was retrieved successfully self.assertIsNone(err) self.assertIsNotNone(prompt_response) self.assertIsNotNone(generation) - + # Verify prompt response properties expected_text = "Answer the following question about machine learning: What are the best practices for machine learning model deployment?" self.assertEqual(prompt_response.text, expected_text) self.assertEqual(prompt_response.model.provider, "open-ai") self.assertEqual(prompt_response.model.model, "gpt-4o") - + # Verify generation object properties self.assertEqual(generation.prompt["slug"], "ml-best-practices") self.assertEqual(generation.prompt["version"], "1.0") - self.assertEqual(generation.input, "Answer the following question about {{topic}}: {{question}}") - self.assertEqual(generation.variables, [ - {"label": "topic", "value": "machine learning"}, - {"label": "question", "value": self.query} - ]) + self.assertEqual( + generation.input, + "Answer the following question about {{topic}}: {{question}}", + ) + self.assertEqual( + generation.variables, + [ + {"label": "topic", "value": "machine learning"}, + {"label": "question", "value": self.query}, + ], + ) self.assertEqual(generation.options["type"], "single") - + # End the prompt span prompt_span.end(prompt_response.text) - + # Create a span for the model generation - model_span = main_trace.create_log({ - "type": "span", - "name": "model-generation", - "input": prompt_response.text, - "metadata": {"action": "generate-response"} - }) - + model_span = main_trace.create_log( + { + "type": "span", + "name": "model-generation", + "input": prompt_response.text, + "metadata": {"action": "generate-response"}, + } + ) + # Generate text using OpenAI model_response = self.openai.generate_text(prompt_response.text) - + # Update the generation with the output generation.end(model_response) - + # End the model span model_span.end(model_response) - + # End the main trace main_trace.end("Completed prompt generation test") - + # Verify trace structure # Filter logs to only include those of type "span" span_logs = [log for log in main_trace.logs if log.type == "span"] self.assertEqual(len(span_logs), 2) self.assertEqual(span_logs[0], prompt_span) self.assertEqual(span_logs[1], model_span) - + def test_complex_workflow(self): """Test a complex workflow with multiple generations and spans.""" # Create a main trace main_trace = self.monitor.create_trace( "complex-workflow-test", - { - "input": self.query, - "user": self.user, - "name": "Complex Workflow Test" - } + {"input": self.query, "user": self.user, "name": "Complex Workflow Test"}, ) - + # Step 1: Content generation - generation_span = main_trace.create_log({ - "type": "span", - "name": "content-generation", - "input": self.query - }) - + generation_span = main_trace.create_log( + {"type": "span", "name": "content-generation", "input": self.query} + ) + # Mock the API response for get_prompt mock_generate_prompt_response = GetPromptEndpointResponse( warning=None, @@ -411,58 +413,56 @@ def test_complex_workflow(self): "temperature": 0.7, "topP": 1, "maxLength": 4096, - "responseFormat": "text" - } - ) - ) + "responseFormat": "text", + }, + ), + ), ) - + # Create a mock API for PromptSDK prompt_api = MagicMock() prompt_api.invoke.return_value = (None, mock_generate_prompt_response) - + # Create a PromptSDK instance from basalt.utils.memcache import MemoryCache + prompt_sdk = PromptSDK( api=prompt_api, cache=MemoryCache(), fallback_cache=MemoryCache(), - logger=logger + logger=logger, ) - + # Get prompt from Basalt err, prompt_response, generation = prompt_sdk.get( - "generate-content", - variables={"query": self.query}, - version="1.0" + "generate-content", variables={"query": self.query}, version="1.0" ) - + # Create generation log - generation_log = generation_span.create_generation({ - "name": "text-generation", - "input": self.query, - "prompt": {"slug": "generate-content", "version": "1.0"}, - "variables": [{"label": "query", "value": self.query}] - }) - + generation_log = generation_span.create_generation( + { + "name": "text-generation", + "input": self.query, + "prompt": {"slug": "generate-content", "version": "1.0"}, + "variables": [{"label": "query", "value": self.query}], + } + ) + # Generate text generated_text = self.openai.generate_text(prompt_response.text) - + # Update generation with output - generation_log.update({ - "output": generated_text, - "metadata": {"processingTime": 500} - }) - + generation_log.update( + {"output": generated_text, "metadata": {"processingTime": 500}} + ) + generation_span.end(generated_text) - + # Step 2: Classification - classification_span = main_trace.create_log({ - "type": "span", - "name": "classification", - "input": generated_text - }) - + classification_span = main_trace.create_log( + {"type": "span", "name": "classification", "input": generated_text} + ) + # Mock the API response for classification prompt mock_classify_prompt_response = GetPromptEndpointResponse( warning=None, @@ -478,50 +478,48 @@ def test_complex_workflow(self): "temperature": 0.3, "topP": 1, "maxLength": 2048, - "responseFormat": "text" - } - ) - ) + "responseFormat": "text", + }, + ), + ), ) - + # Update the mock API for the classification prompt prompt_api.invoke.return_value = (None, mock_classify_prompt_response) - + # Get prompt from Basalt err, classify_prompt_response, classify_generation = prompt_sdk.get( - "classify-content", - variables={"content": generated_text}, - version="1.0" + "classify-content", variables={"content": generated_text}, version="1.0" ) - + # Create generation log - class_gen = classification_span.create_generation({ - "name": "content-classification", - "input": generated_text, - "prompt": {"slug": "classify-content", "version": "1.0"}, - "variables": [{"label": "content", "value": generated_text}] - }) - + class_gen = classification_span.create_generation( + { + "name": "content-classification", + "input": generated_text, + "prompt": {"slug": "classify-content", "version": "1.0"}, + "variables": [{"label": "content", "value": generated_text}], + } + ) + # Classify content categories = self.openai.classify_content(generated_text) - + # Update generation with output - class_gen.update({ - "output": categories - }) - + class_gen.update({"output": categories}) + classification_span.end(categories) - + # End the main trace main_trace.end("Workflow completed") - + # Verify trace structure # Filter logs to only include those of type "span" span_logs = [log for log in main_trace.logs if log.type == "span"] self.assertEqual(len(span_logs), 2) self.assertEqual(span_logs[0], generation_span) self.assertEqual(span_logs[1], classification_span) - + # Verify outputs self.assertEqual(generation_span.output, generated_text) self.assertEqual(classification_span.output, categories) @@ -529,4 +527,4 @@ def test_complex_workflow(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_networker.py b/tests/test_networker.py index 7d0098b..0b186d1 100644 --- a/tests/test_networker.py +++ b/tests/test_networker.py @@ -5,63 +5,68 @@ from basalt.utils.networker import Networker from basalt.utils.errors import NetworkBaseError, FetchError -class TestNetworker(unittest.TestCase): - - @patch('requests.request') - def test_uses_requests_to_make_http_calls(self, request_mock): - networker = Networker() - - networker.fetch('http://test/abc', 'GET') - - request_mock.assert_called_once_with('GET', 'http://test/abc', params=None, json=None, headers=None) - - @patch('requests.request') - def test_captures_requests_exceptions(self, request_mock): - networker = Networker() - request_mock.side_effect = Exception('Some unknown error') - - err, res = networker.fetch('http://test/abc', 'GET') - - self.assertIsNone(res) - self.assertEqual(err.message, 'Some unknown error') - self.assertIsInstance(err, NetworkBaseError) - - @patch('requests.request') - def test_rejects_non_json_responses(self, request_mock): - networker = Networker() - request_mock.return_value = Mock() - request_mock.return_value.json.side_effect = Exception('No JSON object could be decoded') - - err, res = networker.fetch('http://test/abc', 'GET') - - self.assertIsNone(res) - self.assertIsInstance(err, FetchError) - - @patch('requests.request') - def test_returns_valid_json_as_dict(self, request_mock): - networker = Networker() - request_mock.return_value = Mock() - request_mock.return_value.json.return_value = { "some": "data" } - - err, res = networker.fetch('http://test/abc', 'GET') - - self.assertIsNone(err) - self.assertEqual(res, { "some": "data" }) - - @parameterized.expand([ - (400, 'BadRequest'), - (401, 'Unauthorized'), - (403, 'Forbidden'), - (404, 'NotFound'), - ]) - @patch('requests.request') - def test_uses_custom_errors(self, response_code, error_type, request_mock): - networker = Networker() - request_mock.return_value = Mock() - request_mock.return_value.status_code = response_code - - err, _ = networker.fetch('http://test/abc', 'GET') - - self.assertIsInstance(err, FetchError) - self.assertEqual(type(err).__name__, error_type) +class TestNetworker(unittest.TestCase): + @patch("requests.request") + def test_uses_requests_to_make_http_calls(self, request_mock): + networker = Networker() + + networker.fetch("http://test/abc", "GET") + + request_mock.assert_called_once_with( + "GET", "http://test/abc", params=None, json=None, headers=None + ) + + @patch("requests.request") + def test_captures_requests_exceptions(self, request_mock): + networker = Networker() + request_mock.side_effect = Exception("Some unknown error") + + err, res = networker.fetch("http://test/abc", "GET") + + self.assertIsNone(res) + self.assertEqual(err.message, "Some unknown error") + self.assertIsInstance(err, NetworkBaseError) + + @patch("requests.request") + def test_rejects_non_json_responses(self, request_mock): + networker = Networker() + request_mock.return_value = Mock() + request_mock.return_value.json.side_effect = Exception( + "No JSON object could be decoded" + ) + + err, res = networker.fetch("http://test/abc", "GET") + + self.assertIsNone(res) + self.assertIsInstance(err, FetchError) + + @patch("requests.request") + def test_returns_valid_json_as_dict(self, request_mock): + networker = Networker() + request_mock.return_value = Mock() + request_mock.return_value.json.return_value = {"some": "data"} + + err, res = networker.fetch("http://test/abc", "GET") + + self.assertIsNone(err) + self.assertEqual(res, {"some": "data"}) + + @parameterized.expand( + [ + (400, "BadRequest"), + (401, "Unauthorized"), + (403, "Forbidden"), + (404, "NotFound"), + ] + ) + @patch("requests.request") + def test_uses_custom_errors(self, response_code, error_type, request_mock): + networker = Networker() + request_mock.return_value = Mock() + request_mock.return_value.status_code = response_code + + err, _ = networker.fetch("http://test/abc", "GET") + + self.assertIsInstance(err, FetchError) + self.assertEqual(type(err).__name__, error_type) diff --git a/tests/test_promptsdk.py b/tests/test_promptsdk.py index 6ce29bb..ecb4c60 100644 --- a/tests/test_promptsdk.py +++ b/tests/test_promptsdk.py @@ -10,283 +10,271 @@ logger = Logger() mocked_api = MagicMock() -mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( - warning=None, - prompt=PromptResponse( - text="Some prompt", - systemText="Some system prompt", - version="0.1", - model=PromptModel( - provider="open-ai", - model="gpt-4o", - version="latest", - parameters={ - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - ) - ) -)) +mocked_api.invoke.return_value = ( + None, + GetPromptEndpointResponse( + warning=None, + prompt=PromptResponse( + text="Some prompt", + systemText="Some system prompt", + version="0.1", + model=PromptModel( + provider="open-ai", + model="gpt-4o", + version="latest", + parameters={ + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + ), + ), + ), +) mocked_cache = MagicMock() mocked_cache.get.return_value = None fallback_cache = MagicMock() fallback_cache.get.return_value = None -class TestPromptSDK(unittest.TestCase): - def test_uses_correct_endpoint(self): - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - prompt.get("slug") - endpoint = mocked_api.invoke.call_args[0][0] - - self.assertEqual(endpoint, GetPromptEndpoint) - - @parameterized.expand([ - # (slug, version, tag) - ("slug", "version", "tag"), - ("slug", "version", None), - ("slug", None, "tag"), - ("slug", None, None), - ]) - def test_passes_correct_dto(self, slug, version, tag): - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - prompt.get(slug, version=version, tag=tag) - - dto = mocked_api.invoke.call_args[0][1] - - self.assertEqual( - dto, - GetPromptDTO(slug=slug, version=version, tag=tag) - ) - - def test_forwards_api_error(self): - mocked_api = MagicMock() - mocked_api.invoke.return_value = (Exception("Some error"), None) - - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - err, res, generation = prompt.get("slug") - - self.assertIsInstance(err, Exception) - self.assertIsNone(res) - self.assertIsNone(generation) - - def test_replaces_variables(self): - mocked_api = MagicMock() - mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( - warning=None, - prompt=PromptResponse( - text="Say hello {{name}}", - systemText="Some system prompt", - version="0.1", - model=PromptModel( - provider="open-ai", - model="gpt-4o", - version="latest", - parameters={ - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - ) - ) - )) - - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - _, prompt_response, generation = prompt.get("slug", variables={ "name": "Basalt" }) - - self.assertEqual(prompt_response.text, "Say hello Basalt") - self.assertIsInstance(generation, Generation) - self.assertEqual(generation.input, "Say hello {{name}}") - self.assertEqual(generation.prompt["slug"], "slug") - - def test_saves_raw_prompt_to_cache(self): - mocked_api = MagicMock() - mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( - warning=None, - prompt=PromptResponse( - text="Say hello {{name}}", - systemText="Some system prompt", - version="0.1", - model=PromptModel( - provider="open-ai", - model="gpt-4o", - version="latest", - parameters={ - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - ) - ) - )) - - mocked_cache = MagicMock() - mocked_cache.get.return_value = None - - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - prompt.get("slug", variables={ "name": "Basalt" }) - - mocked_cache.put.assert_called_once() - - cached_value = mocked_cache.put.call_args[0][1] - - self.assertEqual(cached_value.text, "Say hello {{name}}") - - def test_does_not_request_when_cache_hit(self): - mocked_api = MagicMock() - - mocked_cache = MagicMock() - mocked_cache.get.return_value = PromptResponse( - text="Say hello {{name}}", - systemText="Some system prompt", - version="0.1", - model=PromptModel( - provider="open-ai", - model="gpt-4o", - version="latest", - parameters={ - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - ) - ) - - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - err, res, generation = prompt.get("slug", variables={ "name": "Cached" }) - - mocked_api.invoke.assert_not_called() - - self.assertIsNone(err) - self.assertEqual(res.text, "Say hello Cached") - self.assertIsInstance(generation, Generation) - self.assertEqual(generation.input, "Say hello {{name}}") - - def test_caches_in_fallback_forever(self): - mocked_api = MagicMock() - mocked_api.invoke.return_value = (None, GetPromptEndpointResponse( - warning=None, - prompt=PromptResponse( - text="Say hello {{name}}", - systemText="Some system prompt", - version="0.1", - model=PromptModel( - provider="open-ai", - model="gpt-4o", - version="latest", - parameters={ - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - ) - ) - )) - - mocked_cache = MagicMock() - mocked_cache.get.return_value = None - - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - prompt.get("slug", variables={ "name": "Cached" }) - - fallback_cache.put.assert_called_once() - - def test_uses_fallback_cache_on_api_failure(self): - mocked_api = MagicMock() - mocked_api.invoke.return_value = (Exception("Some error"), None) - - fallback_cache = MagicMock() - fallback_cache.get.return_value = PromptResponse( - text="From fallback cache", - systemText="Some system prompt", - version="0.1", - model=PromptModel( - provider="open-ai", - model="gpt-4o", - version="latest", - parameters={ - "temperature": 0.7, - "topP": 1, - "maxLength": 4096, - "responseFormat": "text" - } - ) - ) - - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - _, res, generation = prompt.get("slug", variables={ "name": "Cached" }) - - fallback_cache.get.assert_called_once() - self.assertEqual(res.text, "From fallback cache") - self.assertIsInstance(generation, Generation) - - def test_returns_generation_object(self): - prompt = PromptSDK( - mocked_api, - cache=mocked_cache, - fallback_cache=fallback_cache, - logger=logger - ) - - _, _, generation = prompt.get("test-slug", version="1.0", tag="prod", variables={"key": "value"}) - - self.assertIsInstance(generation, Generation) - self.assertEqual(generation.prompt["slug"], "test-slug") - self.assertEqual(generation.prompt["version"], "1.0") - self.assertEqual(generation.prompt["tag"], "prod") - self.assertEqual(generation.variables, [{"label": "key", "value": "value"}]) - self.assertEqual(generation.options["type"], "single") +class TestPromptSDK(unittest.TestCase): + def test_uses_correct_endpoint(self): + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + prompt.get("slug") + endpoint = mocked_api.invoke.call_args[0][0] + + self.assertEqual(endpoint, GetPromptEndpoint) + + @parameterized.expand( + [ + # (slug, version, tag) + ("slug", "version", "tag"), + ("slug", "version", None), + ("slug", None, "tag"), + ("slug", None, None), + ] + ) + def test_passes_correct_dto(self, slug, version, tag): + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + prompt.get(slug, version=version, tag=tag) + + dto = mocked_api.invoke.call_args[0][1] + + self.assertEqual(dto, GetPromptDTO(slug=slug, version=version, tag=tag)) + + def test_forwards_api_error(self): + mocked_api = MagicMock() + mocked_api.invoke.return_value = (Exception("Some error"), None) + + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + err, res, generation = prompt.get("slug") + + self.assertIsInstance(err, Exception) + self.assertIsNone(res) + self.assertIsNone(generation) + + def test_replaces_variables(self): + mocked_api = MagicMock() + mocked_api.invoke.return_value = ( + None, + GetPromptEndpointResponse( + warning=None, + prompt=PromptResponse( + text="Say hello {{name}}", + systemText="Some system prompt", + version="0.1", + model=PromptModel( + provider="open-ai", + model="gpt-4o", + version="latest", + parameters={ + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + ), + ), + ), + ) + + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + _, prompt_response, generation = prompt.get( + "slug", variables={"name": "Basalt"} + ) + + self.assertEqual(prompt_response.text, "Say hello Basalt") + self.assertIsInstance(generation, Generation) + self.assertEqual(generation.input, "Say hello {{name}}") + self.assertEqual(generation.prompt["slug"], "slug") + + def test_saves_raw_prompt_to_cache(self): + mocked_api = MagicMock() + mocked_api.invoke.return_value = ( + None, + GetPromptEndpointResponse( + warning=None, + prompt=PromptResponse( + text="Say hello {{name}}", + systemText="Some system prompt", + version="0.1", + model=PromptModel( + provider="open-ai", + model="gpt-4o", + version="latest", + parameters={ + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + ), + ), + ), + ) + + mocked_cache = MagicMock() + mocked_cache.get.return_value = None + + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + prompt.get("slug", variables={"name": "Basalt"}) + + mocked_cache.put.assert_called_once() + + cached_value = mocked_cache.put.call_args[0][1] + + self.assertEqual(cached_value.text, "Say hello {{name}}") + + def test_does_not_request_when_cache_hit(self): + mocked_api = MagicMock() + + mocked_cache = MagicMock() + mocked_cache.get.return_value = PromptResponse( + text="Say hello {{name}}", + systemText="Some system prompt", + version="0.1", + model=PromptModel( + provider="open-ai", + model="gpt-4o", + version="latest", + parameters={ + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + ), + ) + + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + err, res, generation = prompt.get("slug", variables={"name": "Cached"}) + + mocked_api.invoke.assert_not_called() + + self.assertIsNone(err) + self.assertEqual(res.text, "Say hello Cached") + self.assertIsInstance(generation, Generation) + self.assertEqual(generation.input, "Say hello {{name}}") + + def test_caches_in_fallback_forever(self): + mocked_api = MagicMock() + mocked_api.invoke.return_value = ( + None, + GetPromptEndpointResponse( + warning=None, + prompt=PromptResponse( + text="Say hello {{name}}", + systemText="Some system prompt", + version="0.1", + model=PromptModel( + provider="open-ai", + model="gpt-4o", + version="latest", + parameters={ + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + ), + ), + ), + ) + + mocked_cache = MagicMock() + mocked_cache.get.return_value = None + + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + prompt.get("slug", variables={"name": "Cached"}) + + fallback_cache.put.assert_called_once() + + def test_uses_fallback_cache_on_api_failure(self): + mocked_api = MagicMock() + mocked_api.invoke.return_value = (Exception("Some error"), None) + + fallback_cache = MagicMock() + fallback_cache.get.return_value = PromptResponse( + text="From fallback cache", + systemText="Some system prompt", + version="0.1", + model=PromptModel( + provider="open-ai", + model="gpt-4o", + version="latest", + parameters={ + "temperature": 0.7, + "topP": 1, + "maxLength": 4096, + "responseFormat": "text", + }, + ), + ) + + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + _, res, generation = prompt.get("slug", variables={"name": "Cached"}) + + fallback_cache.get.assert_called_once() + self.assertEqual(res.text, "From fallback cache") + self.assertIsInstance(generation, Generation) + + def test_returns_generation_object(self): + prompt = PromptSDK( + mocked_api, cache=mocked_cache, fallback_cache=fallback_cache, logger=logger + ) + + _, _, generation = prompt.get( + "test-slug", version="1.0", tag="prod", variables={"key": "value"} + ) + + self.assertIsInstance(generation, Generation) + self.assertEqual(generation.prompt["slug"], "test-slug") + self.assertEqual(generation.prompt["version"], "1.0") + self.assertEqual(generation.prompt["tag"], "prod") + self.assertEqual(generation.variables, [{"label": "key", "value": "value"}]) + self.assertEqual(generation.options["type"], "single") diff --git a/tests/test_send_trace_endpoint.py b/tests/test_send_trace_endpoint.py index a35c3f4..04b7579 100644 --- a/tests/test_send_trace_endpoint.py +++ b/tests/test_send_trace_endpoint.py @@ -34,9 +34,9 @@ def test_prepare_request_with_full_trace(self): "input": {"log": "input"}, "output": {"log": "output"}, "prompt": "test prompt", - "variables": [{"label": "var1", "value": "value1"}] + "variables": [{"label": "var1", "value": "value1"}], } - ] + ], } result = SendTraceEndpoint().prepare_request({"trace": trace}) @@ -44,7 +44,7 @@ def test_prepare_request_with_full_trace(self): # Verify the basic request structure self.assertEqual(result["method"], "post") self.assertEqual(result["path"], "/monitor/trace") - + # Verify the body contains all required fields body = result["body"] self.assertEqual(body["chainSlug"], "test-chain") @@ -72,12 +72,7 @@ def test_prepare_request_with_full_trace(self): self.assertEqual(log["variables"], [{"label": "var1", "value": "value1"}]) def test_decode_valid_response(self): - response = { - "trace": { - "id": "trace-123", - "status": "success" - } - } + response = {"trace": {"id": "trace-123", "status": "success"}} exception, decoded = SendTraceEndpoint().decode_response(response) @@ -89,4 +84,6 @@ def test_decode_invalid_response(self): self.assertIsNotNone(exception) self.assertIsNone(decoded) - self.assertEqual(str(exception), "Failed to decode response (invalid body format)") + self.assertEqual( + str(exception), "Failed to decode response (invalid body format)" + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ba71c46..486ba85 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,61 +2,82 @@ from parameterized import parameterized from basalt.utils.utils import replace_variables, pick_typed, pick_number -class TestUtils(unittest.TestCase): - - @parameterized.expand([ - # Replaces a variable - ('Hello {{name}}', { "name": 'Basalt' }, (set([]), 'Hello Basalt')), - # Replaces a multiple occurences of a variable - ('Hello {{name}} {{name}}', { "name": 'Basalt' }, (set([]), 'Hello Basalt Basalt')), +class TestUtils(unittest.TestCase): + @parameterized.expand( + [ + # Replaces a variable + ("Hello {{name}}", {"name": "Basalt"}, (set([]), "Hello Basalt")), + # Replaces a multiple occurences of a variable + ( + "Hello {{name}} {{name}}", + {"name": "Basalt"}, + (set([]), "Hello Basalt Basalt"), + ), + # Can replace multiple variables + ("{{a}} + {{b}} = {{c}}", {"a": 1, "b": 2, "c": 3}, (set([]), "1 + 2 = 3")), + # Can replace multiple variables + ( + "{{a}} + {{b}} = {{c {{c}} }}", + {"a": 1, "b": 2, "c": 3}, + (set(["c {{c"]), "1 + 2 = {{c {{c}} }}"), + ), + # Doesn't replace or empty missing variables + ("Hello {{missing}}", {}, (set(["missing"]), "Hello {{missing}}")), + ] + ) + def test_replace_variables(self, str_value, vars, expected): + self.assertEqual(replace_variables(str_value, vars), expected) - # Can replace multiple variables - ('{{a}} + {{b}} = {{c}}', { "a": 1, "b": 2, "c": 3 }, (set([]), '1 + 2 = 3')), - - # Can replace multiple variables - ('{{a}} + {{b}} = {{c {{c}} }}', { "a": 1, "b": 2, "c": 3 }, (set(["c {{c"]), '1 + 2 = {{c {{c}} }}')), - - # Doesn't replace or empty missing variables - ('Hello {{missing}}', {}, (set(["missing"]), 'Hello {{missing}}')) - ]) - def test_replace_variables(self, str_value, vars, expected): - self.assertEqual(replace_variables(str_value, vars), expected) + @parameterized.expand( + [ + ({"a": 1}, int, True), + ({"a": 1}, float, False), + ({"a": 2.0}, float, True), + ({"a": "1"}, str, True), + ({"a": 1}, bool, False), + ({"a": 0}, bool, False), + ( + { + "a": True, + }, + bool, + True, + ), + ({"a": False}, bool, True), + ({"a": False}, int, False), + ({"a": True}, int, False), + ] + ) + def test_pick_typed(self, d, expected_type, succeeds): + if succeeds: + val = pick_typed(d, "a", expected_type) + self.assertIsInstance(val, expected_type) + else: + with self.assertRaises(Exception): + pick_typed(d, "a", expected_type) - @parameterized.expand([ - ({ "a": 1 }, int, True), - ({ "a": 1 }, float, False), - ({ "a": 2. }, float, True), - ({ "a": "1" }, str, True), - ({ "a": 1 }, bool, False), - ({ "a": 0 }, bool, False), - ({ "a": True, }, bool, True), - ({ "a": False }, bool, True), - ({ "a": False }, int, False), - ({ "a": True }, int, False), - ]) - def test_pick_typed(self, d, expected_type, succeeds): - if succeeds: - val = pick_typed(d, "a", expected_type) - self.assertIsInstance(val, expected_type) - else: - with self.assertRaises(Exception): - pick_typed(d, "a", expected_type) - - @parameterized.expand([ - ({ "a": 1 }, int), - ({ "a": 1. }, float), - ({ "a": 2.78998 }, float), - ({ "a": -1 }, int), - ({ "a": -1.9 }, float), - ({ "a": "1" }, None), - ({ "a": True, }, None), - ({ "a": False }, None), - ]) - def test_pick_number(self, d, expected_type): - if expected_type: - val = pick_number(d, "a") - self.assertIsInstance(val, expected_type) - else: - with self.assertRaises(Exception): - pick_number(d, "a") \ No newline at end of file + @parameterized.expand( + [ + ({"a": 1}, int), + ({"a": 1.0}, float), + ({"a": 2.78998}, float), + ({"a": -1}, int), + ({"a": -1.9}, float), + ({"a": "1"}, None), + ( + { + "a": True, + }, + None, + ), + ({"a": False}, None), + ] + ) + def test_pick_number(self, d, expected_type): + if expected_type: + val = pick_number(d, "a") + self.assertIsInstance(val, expected_type) + else: + with self.assertRaises(Exception): + pick_number(d, "a")