diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000..39e9a7a --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(uv run ruff check:*)" + ] + } +} diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 7fc99fd..750ce34 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -20,22 +20,29 @@ jobs: steps: - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6 with: python-version: ${{ matrix.python-version }} + - name: Install hatch + run: uv tool install hatch + - name: Install dependencies working-directory: ./ run: | - python -m pip install --upgrade pip - pip install pytest pytest-cov anyio pytest-asyncio - pip install -e . + uv pip install --system --editable . + uv pip install --system pytest pytest-cov anyio pytest-asyncio - name: Run tests working-directory: ./ run: | - pytest tests/ --cov=basalt --cov-report=term --cov-report=xml + hatch run test echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_OUTPUT echo "COVERAGE_PCT=$(python -c "import xml.etree.ElementTree as ET; tree = ET.parse('coverage.xml'); root = tree.getroot(); print(f'{float(root.attrib[\"line-rate\"]) * 100:.2f}')")" >> $GITHUB_OUTPUT echo "TEST_COUNT=$(python -c "import xml.etree.ElementTree as ET; tree = ET.parse('coverage.xml'); root = tree.getroot(); print(root.find('.//metrics').attrib['tests'])")" >> $GITHUB_OUTPUT diff --git a/.gitignore b/.gitignore index 23e68ac..0eb063f 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,10 @@ venv/ .env* requirements* .serena +.hatch/ +*.pyc +.pytest_cache/ +.ruff_cache/ +coverage.xml +.coverage +htmlcov/ diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 0000000..47cc54d --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,268 @@ +# Development Guide + +This project uses **Hatch** for build/packaging/environments, **uv** for fast dependency management, and **ruff** for linting/formatting. + +## Prerequisites + +1. **Python 3.10+** (Python 3.12 recommended, specified in `.python-version`) +2. **uv** - Fast Python package installer + ```bash + curl -LsSf https://astral.sh/uv/install.sh | sh + ``` +3. **Hatch** - Modern Python project manager + ```bash + uv tool install hatch + # or: pipx install hatch + ``` + +## Quick Start + +### Install the package in development mode + +```bash +# Using uv (recommended) +uv pip install --editable . + +# Or using pip +pip install -e . +``` + +### Install with all optional dependencies + +```bash +uv pip install --editable ".[dev]" +# or for all LLM/vector/framework instrumentations: +uv pip install --editable ".[all,dev]" +``` + +### Using uv's lock file + +```bash +# Sync dependencies from uv.lock +uv sync + +# Update lock file +uv lock + +# Install with specific extras +uv sync --extra dev --extra openai +``` + +## Development Workflows + +Hatch provides convenient commands for common development tasks: + +### Running Tests + +```bash +# Run all tests with coverage +hatch run test + +# Run tests verbosely +hatch run test-verbose + +# Run tests on specific Python version +hatch run +py=3.10 test +hatch run +py=3.11 test +hatch run +py=3.12 test +hatch run +py=3.13 test +hatch run +py=3.14 test + +# Run tests across all Python versions (matrix) +hatch run test:run +``` + +### Code Quality + +```bash +# Check code with ruff +hatch run lint + +# Auto-fix linting issues +hatch run lint-fix + +# Format code with ruff +hatch run fmt + +# Check formatting without modifying +hatch run fmt-check + +# Type check with mypy +hatch run typecheck + +# Run all quality checks + tests +hatch run all +``` + +### Building + +```bash +# Build wheel and source distribution +hatch build + +# Build only wheel +hatch build --target wheel + +# Build only source distribution +hatch build --target sdist + +# Build to specific directory +hatch build --outdir dist/ +``` + +### Version Management + +Hatch can automatically manage versions in `basalt/_version.py`: + +```bash +# Show current version +hatch version + +# Bump patch version (1.1.0 -> 1.1.1) +hatch version patch + +# Bump minor version (1.1.0 -> 1.2.0) +hatch version minor + +# Bump major version (1.1.0 -> 2.0.0) +hatch version major + +# Set specific version +hatch version 1.2.3 +``` + +### Publishing to PyPI + +```bash +# Build and publish to PyPI (requires credentials) +hatch publish + +# Build and publish to TestPyPI +hatch publish -r test + +# Dry run (build without publishing) +hatch build +``` + +## Hatch Environments + +Hatch creates isolated virtual environments for different purposes: + +- **default** - Development environment with pytest, ruff, mypy +- **test** - Matrix testing across Python 3.10-3.14 +- **full** - Environment with all optional dependencies for comprehensive testing + +You can access environments directly: + +```bash +# Enter a shell in the default environment +hatch shell + +# Run a command in a specific environment +hatch run full:test + +# Show all environments +hatch env show +``` + +## Project Structure + +``` +basalt-python/ +├── basalt/ # Main package source +│ ├── __init__.py +│ ├── _version.py # Version file (managed by Hatch) +│ ├── client.py +│ └── ... +├── tests/ # Test suite +├── docs/ # Documentation +├── examples/ # Example scripts +├── pyproject.toml # Project metadata and configuration +├── uv.lock # Locked dependencies +├── .python-version # Python version for uv +└── README.md +``` + +## CI/CD + +GitHub Actions automatically: +- Runs tests on Python 3.10, 3.11, 3.12, 3.13, 3.14 +- Uses uv for fast dependency installation +- Uses hatch for running test suite +- Generates coverage reports + +## Optional Dependencies + +The SDK provides many optional instrumentation packages: + +### LLM Providers +```bash +pip install basalt-sdk[openai] +pip install basalt-sdk[anthropic] +pip install basalt-sdk[google-generativeai] +pip install basalt-sdk[bedrock] +pip install basalt-sdk[vertex-ai] +pip install basalt-sdk[mistralai] +# Or all at once: +pip install basalt-sdk[llm-all] +``` + +### Vector Databases +```bash +pip install basalt-sdk[chromadb] +pip install basalt-sdk[pinecone] +pip install basalt-sdk[qdrant] +# Or all at once: +pip install basalt-sdk[vector-all] +``` + +### Frameworks +```bash +pip install basalt-sdk[langchain] +pip install basalt-sdk[llamaindex] +# Or all at once: +pip install basalt-sdk[framework-all] +``` + +### Everything +```bash +pip install basalt-sdk[all] +``` + +## Troubleshooting + +### Clean build artifacts + +```bash +# Remove all build artifacts +rm -rf dist/ build/ *.egg-info .hatch/ + +# Rebuild from scratch +hatch build +``` + +### Reset environments + +```bash +# Remove all Hatch environments +hatch env prune + +# Recreate default environment +hatch env create +``` + +### Lock file issues + +```bash +# Regenerate uv.lock +uv lock --upgrade + +# Force reinstall all dependencies +uv sync --reinstall +``` + +## Additional Resources + +- [Hatch Documentation](https://hatch.pypa.io/) +- [uv Documentation](https://docs.astral.sh/uv/) +- [Ruff Documentation](https://docs.astral.sh/ruff/) diff --git a/PROJECT_INDEX.md b/PROJECT_INDEX.md new file mode 100644 index 0000000..2b1c509 --- /dev/null +++ b/PROJECT_INDEX.md @@ -0,0 +1,423 @@ +# Project Index: basalt-python + +**Generated:** 2026-01-22 +**Version:** 1.1.0 +**Language:** Python 3.10+ + +## 📊 Repository Overview + +The Basalt SDK is a comprehensive Python client for managing AI prompts, monitoring AI applications, and tracking experiments via OpenTelemetry. The SDK provides async/sync APIs for prompts, datasets, and experiments with built-in observability. + +**Key Statistics:** +- Total Python files: 78 +- Repository size: ~45MB (excluding .venv) +- Main package: `basalt/` (26 modules) +- Test coverage: 35+ test files +- Examples: 8 examples + 2 notebooks + +--- + +## 📁 Project Structure + +``` +basalt-python/ +├── basalt/ # Main SDK package +│ ├── __init__.py # Package entry (lazy imports) +│ ├── client.py # Main Basalt client +│ ├── config.py # Global configuration +│ ├── _version.py # Version: 1.1.0 +│ ├── _internal/ # Internal utilities +│ │ ├── base_client.py # Base API client +│ │ └── http.py # HTTP client wrapper +│ ├── observability/ # OpenTelemetry integration (14 modules) +│ │ ├── __init__.py # Observability facade +│ │ ├── api.py # High-level observe decorators +│ │ ├── config.py # TelemetryConfig +│ │ ├── decorators.py # @observe, @evaluate +│ │ ├── context_managers.py # Span context managers +│ │ ├── trace.py # Trace API +│ │ ├── trace_context.py # Identity & experiment tracking +│ │ ├── instrumentation.py # LLM/DB auto-instrumentation +│ │ ├── processors.py # OTEL span processors +│ │ ├── spans.py # Basalt span wrappers +│ │ ├── evaluators.py # Custom evaluators +│ │ ├── request_tracing.py # Request span tracking +│ │ ├── resilient_exporters.py # Error-resilient exporters +│ │ ├── semconv.py # Semantic conventions +│ │ ├── types.py # Type definitions +│ │ └── utils.py # Utility functions +│ ├── prompts/ # Prompts API client +│ │ ├── __init__.py +│ │ ├── client.py # PromptsClient (list, get, describe) +│ │ └── models.py # Prompt, PromptResponse models +│ ├── datasets/ # Datasets API client +│ │ ├── __init__.py +│ │ ├── client.py # DatasetsClient (list, get, add_row) +│ │ ├── models.py # Dataset, DatasetRow models +│ │ └── file_upload.py # File attachment handling +│ ├── experiments/ # Experiments API client +│ │ ├── __init__.py +│ │ ├── client.py # ExperimentsClient +│ │ └── models.py # Experiment models +│ ├── types/ # Shared types +│ │ ├── __init__.py +│ │ ├── exceptions.py # API exceptions +│ │ └── cache.py # Cache protocol +│ └── utils/ # Utilities +│ ├── __init__.py +│ └── memcache.py # Memory cache implementation +├── tests/ # Test suite (35+ files) +│ ├── conftest.py # Pytest fixtures +│ ├── observability/ # Observability tests (14 files) +│ ├── prompts/ # Prompts API tests +│ ├── datasets/ # Datasets API tests +│ ├── experiments/ # Experiments API tests +│ ├── internal/ # Internal tests +│ └── otel/ # OTEL integration tests +├── examples/ # Usage examples +│ ├── openai_example.py # OpenAI + observability +│ ├── async_observe_example.py # Async decorators +│ ├── dataset_api_example.py # Dataset operations +│ ├── gemini_random_data_example.py # Gemini integration +│ ├── multi_exporter_example.py # Multiple OTLP exporters +│ ├── internal_api.py # Internal API demo +│ ├── prompt_sdk_demo.ipynb # Prompt SDK notebook +│ └── dataset_sdk_demo.ipynb # Dataset SDK notebook +├── docs/ # Documentation (13 guides) +│ ├── 01-introduction.md +│ ├── 02-getting-started.md +│ ├── 03-prompts.md +│ ├── 04-datasets.md +│ ├── 05-observability.md +│ ├── 06-manual-tracing.md +│ ├── 07-llm-tracing.md +│ ├── 08-async-observability.md +│ ├── 09-auto-instrumentation.md +│ ├── 10-evaluators.md +│ ├── 11-experiments.md +│ ├── 12-user-org-tracking.md +│ └── 13-trace-context.md +├── pyproject.toml # Project config (Hatch) +├── README.md # Main documentation +├── DEVELOPMENT.md # Development guide +├── AGENTS.md # Agent instructions +└── renovate.json # Dependency updates + +``` + +--- + +## 🚀 Entry Points + +### 1. **Main SDK Entry**: `basalt/__init__.py` +- Exports: `Basalt`, `TelemetryConfig`, `__version__` +- Uses lazy imports via `__getattr__` for faster startup + +### 2. **Primary Client**: `basalt/client.py` +- **Class**: `Basalt` +- **Services**: `prompts`, `datasets`, `experiments` +- **Features**: HTTP client, telemetry config, instrumentation, global metadata +- **Key method**: `shutdown()` - flushes telemetry + +### 3. **CLI/Programmatic Usage**: +```python +from basalt import Basalt, TelemetryConfig +basalt = Basalt(api_key="...", telemetry_config=TelemetryConfig(...)) +``` + +--- + +## 📦 Core Modules + +### **basalt.client** - Main Client +- **Exports**: `Basalt` class +- **Purpose**: Central SDK entry point, orchestrates sub-clients +- **Dependencies**: HTTP client, instrumentation, telemetry config + +### **basalt.observability** - Telemetry & Tracing +- **Exports**: + - Decorators: `@observe`, `@start_observe`, `@evaluate` + - Context managers: `LLMSpanHandle`, `RetrievalSpanHandle`, etc. + - Config: `TelemetryConfig` + - API: `Trace`, `trace`, `TraceIdentity`, `TraceExperiment` +- **Purpose**: OpenTelemetry integration with LLM-specific semantics +- **Key features**: + - Auto-instrumentation for OpenAI, Anthropic, Gemini, Bedrock, etc. + - Custom evaluators and processors + - Identity/experiment tracking across traces + - Resilient exporters with error handling + +### **basalt.prompts** - Prompts API +- **Exports**: `PromptsClient`, `Prompt`, `PromptResponse` +- **Methods**: + - `list_sync()` / `list_async()` - List all prompts + - `get_sync(slug, tag?, version?, variables?)` - Get prompt + - `describe_sync(slug)` - Get metadata +- **Purpose**: Fetch and render prompts from Basalt API + +### **basalt.datasets** - Datasets API +- **Exports**: `DatasetsClient`, `Dataset`, `DatasetRow` +- **Methods**: + - `list_sync()` / `list_async()` - List datasets + - `get_sync(slug)` - Get dataset with rows + - `add_row_sync(slug, data, attachments?)` - Add row with file uploads +- **Purpose**: Manage evaluation/test datasets + +### **basalt.experiments** - Experiments API +- **Exports**: `ExperimentsClient`, `Experiment` +- **Methods**: + - `create_sync(name, description)` - Create experiment +- **Purpose**: Track A/B tests and experiments + +### **basalt.types.exceptions** - Exception Hierarchy +- **Base**: `BasaltAPIError` +- **Specific**: `NotFoundError`, `UnauthorizedError`, `NetworkError` +- **Purpose**: Type-safe error handling + +### **basalt._internal.http** - HTTP Client +- **Exports**: `HTTPClient` +- **Features**: Sync/async requests, auth headers, error handling +- **Purpose**: Shared HTTP transport for all API clients + +--- + +## 🔧 Configuration + +### **pyproject.toml** - Project Metadata +- **Build system**: Hatchling +- **Python**: >=3.10 +- **Core deps**: OpenTelemetry, httpx, jinja2, wrapt +- **Optional deps**: + - LLM providers (10): openai, anthropic, google-generativeai, etc. + - Vector DBs (3): chromadb, pinecone, qdrant + - Frameworks (2): langchain, llamaindex + +### **basalt/config.py** - Runtime Config +- Default API base URL +- Environment variable parsing +- Global settings + +### **basalt/observability/config.py** - Telemetry Config +- **Class**: `TelemetryConfig` +- **Fields**: service_name, environment, trace_content, enabled_providers +- **Purpose**: Centralized OTEL configuration + +--- + +## 📚 Documentation + +| File | Topic | +|------|-------| +| `01-introduction.md` | SDK overview | +| `02-getting-started.md` | Installation & setup | +| `03-prompts.md` | Prompts API guide | +| `04-datasets.md` | Datasets API guide | +| `05-observability.md` | Telemetry overview | +| `06-manual-tracing.md` | Custom span creation | +| `07-llm-tracing.md` | LLM provider tracing | +| `08-async-observability.md` | Async patterns | +| `09-auto-instrumentation.md` | Auto-instrumentation setup | +| `10-evaluators.md` | Custom evaluators | +| `11-experiments.md` | Experiment tracking | +| `12-user-org-tracking.md` | Identity tracking | +| `13-trace-context.md` | Context propagation | + +--- + +## 🧪 Test Coverage + +### Test Organization +- **Total test files**: 35+ +- **Coverage areas**: API clients, observability, OTEL integration, decorators + +### Key Test Modules +| Module | Focus | +|--------|-------| +| `tests/prompts/` | Prompts API, context managers | +| `tests/datasets/` | Datasets API, file uploads | +| `tests/experiments/` | Experiments API | +| `tests/observability/` | Decorators, spans, processors, evaluators | +| `tests/otel/` | OTLP export, LLM instrumentation | +| `tests/internal/` | HTTP client | + +### Running Tests +```bash +# Via Hatch +hatch run test + +# Via pytest directly +pytest tests/ --cov=basalt +``` + +--- + +## 🔗 Key Dependencies + +| Dependency | Version | Purpose | +|------------|---------|---------| +| opentelemetry-api | ~1.39.1 | OTEL core API | +| opentelemetry-sdk | ~1.39.1 | OTEL SDK | +| opentelemetry-exporter-otlp | ~1.39.1 | OTLP exporter | +| opentelemetry-instrumentation | ~0.59b0 | Auto-instrumentation base | +| opentelemetry-instrumentation-httpx | ~0.59b0 | HTTP tracing | +| httpx | >=0.28.1 | Async HTTP client | +| jinja2 | >=3.1.6 | Prompt template rendering | +| wrapt | ~1.17.3 | Decorator utilities | +| pytest | - | Testing framework | +| ruff | - | Linting & formatting | + +--- + +## 📝 Quick Start + +### 1. Installation +```bash +pip install basalt-sdk[openai,anthropic] # With LLM providers +``` + +### 2. Basic Usage +```python +from basalt import Basalt + +basalt = Basalt(api_key="your-key") + +# Get a prompt +prompt = basalt.prompts.get_sync("my-prompt") +print(prompt.text) + +# Get a dataset +dataset = basalt.datasets.get_sync("my-dataset") +for row in dataset.rows: + print(row.data) + +# Shutdown (flush telemetry) +basalt.shutdown() +``` + +### 3. Observability +```python +from basalt.observability import observe, start_observe + +@start_observe(name="process_workflow", feature_slug="main") +def main(): + result = generate_text() + return result + +@observe(kind="generation", name="llm.generate") +def generate_text(): + # Your LLM call here + return "Generated text" +``` + +### 4. Run Tests +```bash +hatch run test +``` + +--- + +## 🎯 Architecture Patterns + +### 1. **Lazy Imports** +- `basalt/__init__.py` uses `__getattr__` to defer imports +- Reduces startup time, avoids loading unused dependencies + +### 2. **Base Client Pattern** +- `BaseServiceClient` provides common API functionality +- Subclassed by `PromptsClient`, `DatasetsClient`, `ExperimentsClient` + +### 3. **Observability Facade** +- High-level API (`@observe`) wraps low-level OTEL primitives +- Context managers (`LLMSpanHandle`) simplify span management + +### 4. **Resilient Exporters** +- `ResilientOTLPExporter` wraps OTLP exporter with error handling +- Prevents telemetry failures from breaking app logic + +### 5. **Identity Propagation** +- `TraceIdentity` attaches user/org metadata to root spans +- Automatically propagates to child spans via context + +--- + +## 🔍 Key Symbols Reference + +### Classes +- `Basalt` - Main SDK client (`basalt/client.py:23`) +- `TelemetryConfig` - Telemetry configuration (`basalt/observability/config.py`) +- `PromptsClient` - Prompts API (`basalt/prompts/client.py`) +- `DatasetsClient` - Datasets API (`basalt/datasets/client.py`) +- `ExperimentsClient` - Experiments API (`basalt/experiments/client.py`) +- `HTTPClient` - HTTP transport (`basalt/_internal/http.py`) +- `InstrumentationManager` - Auto-instrumentation (`basalt/observability/instrumentation.py`) +- `Trace` - Low-level trace API (`basalt/observability/trace.py`) + +### Decorators +- `@observe` - Create nested span (`basalt/observability/api.py`) +- `@start_observe` - Create root span with identity (`basalt/observability/api.py`) +- `@evaluate` - Attach evaluators (`basalt/observability/decorators.py`) + +### Exceptions +- `BasaltAPIError` - Base exception (`basalt/types/exceptions.py`) +- `NotFoundError` - 404 errors +- `UnauthorizedError` - 401 errors +- `NetworkError` - Connection failures + +--- + +## 📊 Token Efficiency Impact + +**Before indexing**: Full codebase read = ~58,000 tokens per session +**After indexing**: Read this index = ~3,000 tokens (94% reduction) + +**Expected savings**: +- 10 sessions: 550,000 tokens saved +- 100 sessions: 5,500,000 tokens saved + +--- + +## 🔄 Development Workflow + +### Formatting & Linting +```bash +hatch run fmt # Format with ruff +hatch run lint # Lint code +hatch run lint-fix # Auto-fix lint issues +``` + +### Testing +```bash +hatch run test # Run tests with coverage +hatch run test-verbose # Verbose output +``` + +### Type Checking +```bash +hatch run typecheck # Run mypy +``` + +### All Checks +```bash +hatch run all # fmt + lint-fix + typecheck + test +``` + +--- + +## 📌 Important Files + +| File | Purpose | +|------|---------| +| `pyproject.toml` | Project config, dependencies, Hatch scripts | +| `basalt/__init__.py` | Main package entry (lazy imports) | +| `basalt/client.py` | Primary SDK client | +| `basalt/observability/__init__.py` | Observability facade | +| `basalt/observability/config.py` | Telemetry configuration | +| `basalt/types/exceptions.py` | Exception hierarchy | +| `README.md` | User-facing documentation | +| `DEVELOPMENT.md` | Contributor guide | +| `tests/conftest.py` | Shared pytest fixtures | + +--- + +**End of Index** +*This index provides a comprehensive overview of the basalt-python repository structure, enabling efficient navigation and reducing token usage in future sessions.* diff --git a/basalt/__init__.py b/basalt/__init__.py index db27db3..621fb80 100644 --- a/basalt/__init__.py +++ b/basalt/__init__.py @@ -13,6 +13,7 @@ prompt = await basalt.prompts.get("my-prompt") ``` """ + from typing import TYPE_CHECKING from ._version import __version__ @@ -30,8 +31,10 @@ def __getattr__(name: str): if name == "Basalt": from .client import Basalt # imported only when accessed + return Basalt if name == "TelemetryConfig": from .observability.config import TelemetryConfig # imported only when accessed + return TelemetryConfig raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/basalt/_internal/base_client.py b/basalt/_internal/base_client.py index a8c1fcd..c85409e 100644 --- a/basalt/_internal/base_client.py +++ b/basalt/_internal/base_client.py @@ -6,9 +6,22 @@ import logging import os from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict -from .http import HTTPClient +try: + from typing import Unpack +except ImportError: # pragma: no cover - fallback for Python < 3.11 + from typing_extensions import Unpack + +from .http import HTTPClient, HTTPResponse + + +class HTTPRequestKwargs(TypedDict, total=False): + """Keyword arguments passed through to HTTPClient.fetch methods.""" + + body: Any + params: Mapping[str, str] | None + headers: Mapping[str, str] | None class BaseServiceClient: @@ -52,8 +65,8 @@ async def _request_async( span_attributes: Mapping[str, Any] | None = None, span_variables: Mapping[str, Any] | None = None, cache_hit: bool | None = None, - **request_kwargs: Any, - ): + **request_kwargs: Unpack[HTTPRequestKwargs], + ) -> HTTPResponse | None: # Lazy import to avoid circular dependency from basalt.observability.request_tracing import trace_async_request from basalt.observability.spans import BasaltRequestSpan @@ -84,8 +97,8 @@ def _request_sync( span_attributes: Mapping[str, Any] | None = None, span_variables: Mapping[str, Any] | None = None, cache_hit: bool | None = None, - **request_kwargs: Any, - ): + **request_kwargs: Unpack[HTTPRequestKwargs], + ) -> HTTPResponse | None: # Lazy import to avoid circular dependency from basalt.observability.request_tracing import trace_sync_request from basalt.observability.spans import BasaltRequestSpan diff --git a/basalt/_internal/http.py b/basalt/_internal/http.py index f9627ee..31539fc 100644 --- a/basalt/_internal/http.py +++ b/basalt/_internal/http.py @@ -7,10 +7,13 @@ import time from collections.abc import Iterator, Mapping from dataclasses import dataclass -from typing import Any, Literal +from types import TracebackType +from typing import Any, Literal, cast import httpx +from basalt.types.common import JSONValue + from ..types.exceptions import ( BadRequestError, ForbiddenError, @@ -44,12 +47,12 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self.data or {}) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> object: if not self.data: raise KeyError(key) - return self.data[key] + return cast(object, self.data[key]) - def get(self, key: str, default: Any = None) -> Any: + def get(self, key: str, default: object | None = None) -> object | None: if not self.data: return default return self.data.get(key, default) @@ -78,7 +81,7 @@ def __init__( retry_backoff_factor: float = DEFAULT_RETRY_BACKOFF_FACTOR, async_client: httpx.AsyncClient | None = None, sync_client: httpx.Client | None = None, - ): + ) -> None: """ Initialize HTTPClient with configuration options. @@ -100,29 +103,39 @@ def __init__( self._owns_async_client = async_client is None self._owns_sync_client = sync_client is None - async def __aenter__(self): + async def __aenter__(self) -> HTTPClient: """Async context manager entry.""" return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Async context manager exit.""" await self.aclose() - def __enter__(self): + def __enter__(self) -> HTTPClient: """Sync context manager entry.""" return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Sync context manager exit.""" self.close() - async def aclose(self): + async def aclose(self) -> None: """Close the async session.""" if self._async_client and self._owns_async_client: await self._async_client.aclose() self._async_client = None - def close(self): + def close(self) -> None: """Close the sync session.""" if self._sync_client and self._owns_sync_client: self._sync_client.close() @@ -144,7 +157,7 @@ async def fetch( self, url: str, method: str | HTTPMethod, - body: Any | None = None, + body: JSONValue = None, params: Mapping[str, str] | None = None, headers: Mapping[str, str] | None = None, ) -> HTTPResponse | None: @@ -187,13 +200,15 @@ async def fetch( except (BadRequestError, UnauthorizedError, ForbiddenError, NotFoundError): # Don't retry client errors raise - except (httpx.TimeoutException, asyncio.TimeoutError, httpx.TransportError) as e: + except (TimeoutError, httpx.TimeoutException, httpx.TransportError) as e: # Retry on transient errors if attempt == self.max_retries - 1: - raise NetworkError(f"Request failed after {self.max_retries} attempts: {e}") from e + raise NetworkError( + f"Request failed after {self.max_retries} attempts: {e}" + ) from e # Exponential backoff - wait_time = self.retry_backoff_factor * (2 ** attempt) + wait_time = self.retry_backoff_factor * (2**attempt) await asyncio.sleep(wait_time) except Exception as e: raise NetworkError(str(e)) from e @@ -205,7 +220,7 @@ def fetch_sync( self, url: str, method: str | HTTPMethod, - body: Any | None = None, + body: JSONValue = None, params: Mapping[str, str] | None = None, headers: Mapping[str, str] | None = None, ) -> HTTPResponse | None: @@ -251,10 +266,12 @@ def fetch_sync( except (httpx.TimeoutException, httpx.TransportError) as e: # Retry on transient errors if attempt == self.max_retries - 1: - raise NetworkError(f"Request failed after {self.max_retries} attempts: {e}") from e + raise NetworkError( + f"Request failed after {self.max_retries} attempts: {e}" + ) from e # Exponential backoff - wait_time = self.retry_backoff_factor * (2 ** attempt) + wait_time = self.retry_backoff_factor * (2**attempt) time.sleep(wait_time) except Exception as e: raise NetworkError(str(e)) from e @@ -274,9 +291,7 @@ def _handle_response(response: httpx.Response) -> HTTPResponse | None: headers_obj = getattr(response, "headers", {}) if isinstance(headers_obj, Mapping): raw_content_type = ( - headers_obj.get("content-type") - or headers_obj.get("Content-Type") - or "" + headers_obj.get("content-type") or headers_obj.get("Content-Type") or "" ) else: raw_content_type = str(headers_obj or "") diff --git a/basalt/client.py b/basalt/client.py index a402bf2..775595e 100644 --- a/basalt/client.py +++ b/basalt/client.py @@ -3,6 +3,7 @@ This module provides the main Basalt client class for interacting with the Basalt API. """ + from __future__ import annotations from typing import Any @@ -40,10 +41,7 @@ class Basalt: basalt = Basalt(api_key="your-api-key", telemetry_config=telemetry) # Or use client-level parameters for simple cases - basalt = Basalt( - api_key="your-api-key", - enabled_instruments=["openai", "anthropic"] - ) + basalt = Basalt(api_key="your-api-key", enabled_instruments=["openai", "anthropic"]) ``` """ @@ -55,11 +53,11 @@ def __init__( enable_telemetry: bool = True, base_url: str | None = None, observability_metadata: dict[str, Any] | None = None, - cache : CacheProtocol | None = None, + cache: CacheProtocol | None = None, log_level: str | None = None, enabled_instruments: list[str] | None = None, disabled_instruments: list[str] | None = None, - ): + ) -> None: """ Initialize the Basalt client. @@ -157,7 +155,7 @@ def experiments(self) -> ExperimentsClient: """ return self._experiments_client - def shutdown(self): + def shutdown(self) -> None: """ Shutdown the client and flush any pending telemetry data. diff --git a/basalt/config.py b/basalt/config.py index 634c50f..f294b40 100644 --- a/basalt/config.py +++ b/basalt/config.py @@ -6,8 +6,10 @@ config: dict[str, str] = { - 'api_url': 'http://localhost:3001' if build == 'development' else 'https://api.getbasalt.ai', - 'otel_endpoint': 'http://127.0.0.1:4317' if build == 'development' else 'https://grpc.otel.getbasalt.ai', - 'sdk_version': __version__, - 'sdk_type': 'python', + "api_url": "http://localhost:3001" if build == "development" else "https://api.getbasalt.ai", + "otel_endpoint": "http://127.0.0.1:4317" + if build == "development" + else "https://grpc.otel.getbasalt.ai", + "sdk_version": __version__, + "sdk_type": "python", } diff --git a/basalt/datasets/__init__.py b/basalt/datasets/__init__.py index 906404e..911c27f 100644 --- a/basalt/datasets/__init__.py +++ b/basalt/datasets/__init__.py @@ -5,6 +5,7 @@ with the Basalt Datasets API. The client is lazily imported to avoid module-level circular dependencies during initialization. """ + from typing import TYPE_CHECKING, Any if TYPE_CHECKING: # pragma: no cover @@ -16,7 +17,7 @@ __all__ = ["DatasetsClient", "Dataset", "DatasetRow", "FileAttachment"] -def __getattr__(name: str) -> Any: +def __getattr__(name: str) -> object: if name == "DatasetsClient": from .client import DatasetsClient diff --git a/basalt/datasets/client.py b/basalt/datasets/client.py index 7926371..4d5474c 100644 --- a/basalt/datasets/client.py +++ b/basalt/datasets/client.py @@ -3,6 +3,7 @@ This module provides the DatasetsClient for interacting with the Basalt Datasets API. """ + from __future__ import annotations from typing import Any @@ -28,7 +29,7 @@ def __init__( base_url: str | None = None, http_client: HTTPClient | None = None, log_level: str | None = None, - ): + ) -> None: """ Initialize the DatasetsClient. @@ -38,7 +39,7 @@ def __init__( log_level: Optional log level for the client logger. """ self._api_key = api_key - self._base_url = base_url or config.get("api_url") + self._base_url = base_url or config["api_url"] super().__init__(client_name="datasets", http_client=http_client, log_level=log_level) # Initialize file upload handler @@ -48,6 +49,73 @@ def __init__( api_key=self._api_key, ) + @staticmethod + def _ensure_response(response: object) -> object: + if response is None: + raise BasaltAPIError("Empty response from dataset API") + return response + + def _dataset_from_response(self, response: object) -> Dataset: + # The response is expected to be an HTTP response object with a .json() method. + # Add a runtime check and type ignore for static analysis tools. + response = self._ensure_response(response) + if not hasattr(response, "json") or not callable(getattr(response, "json", None)): + raise BasaltAPIError("Response object does not have a callable .json() method") + payload = response.json() # type: ignore[attr-defined] + payload = payload or {} + + dataset_data = payload.get("dataset", {}) + dataset = Dataset.from_dict(dataset_data) + + # Log warning if present + if warning := payload.get("warning"): + self._logger.warning("Dataset API warning: %s", warning) + + return dataset + + def _dataset_items_url(self, slug: str) -> str: + return f"{self._base_url}/datasets/{slug}/items" + + def _dataset_row_from_response(self, response: object) -> DatasetRow: + # The response is expected to be an HTTP response object with a .json() method. + response = self._ensure_response(response) + if not hasattr(response, "json") or not callable(getattr(response, "json", None)): + raise BasaltAPIError("Response object does not have a callable .json() method") + payload = response.json() # type: ignore[attr-defined] + payload = payload or {} + row_data = payload.get("datasetRow", {}) + + # Log warning if present + if warning := payload.get("warning"): + self._logger.warning("Dataset API warning: %s", warning) + + return DatasetRow.from_dict(row_data) + + def _build_dataset_row_request( + self, + slug: str, + processed_values: dict[str, str], + name: str | None, + ideal_output: str | None, + metadata: dict[str, Any] | None, + ) -> tuple[str, dict[str, Any], dict[str, Any]]: + url = self._dataset_items_url(slug) + body: dict[str, Any] = { + "values": processed_values, + } + if name is not None: + body["name"] = name + if ideal_output is not None: + body["idealOutput"] = ideal_output + if metadata is not None: + body["metadata"] = metadata + + span_attributes = { + "basalt.dataset.slug": slug, + "basalt.dataset.row_name": name, + } + return url, body, span_attributes + async def list(self) -> list[Dataset]: """ List all datasets available in the workspace. @@ -72,11 +140,9 @@ async def list(self) -> list[Dataset]: return [] datasets_data = response.get("datasets", []) - return [ - Dataset.from_dict(ds) - for ds in datasets_data - if isinstance(ds, dict) - ] + if not isinstance(datasets_data, list): + return [] + return [Dataset.from_dict(ds) for ds in datasets_data if isinstance(ds, dict)] def list_sync(self) -> list[Dataset]: """ @@ -102,11 +168,9 @@ def list_sync(self) -> list[Dataset]: return [] datasets_data = response.get("datasets", []) - return [ - Dataset.from_dict(ds) - for ds in datasets_data - if isinstance(ds, dict) - ] + if not isinstance(datasets_data, list): + return [] + return [Dataset.from_dict(ds) for ds in datasets_data if isinstance(ds, dict)] async def get(self, slug: str) -> Dataset: """ @@ -132,19 +196,7 @@ async def get(self, slug: str) -> Dataset: span_attributes={"basalt.dataset.slug": slug}, ) - if response is None: - raise BasaltAPIError("Empty response from dataset API") - - payload = response.json() or {} - - dataset_data = payload.get("dataset", {}) - dataset = Dataset.from_dict(dataset_data) - - # Log warning if present - if warning := payload.get("warning"): - self._logger.warning("Dataset API warning: %s", warning) - - return dataset + return self._dataset_from_response(response) def get_sync(self, slug: str) -> Dataset: """ @@ -170,19 +222,7 @@ def get_sync(self, slug: str) -> Dataset: span_attributes={"basalt.dataset.slug": slug}, ) - if response is None: - raise BasaltAPIError("Empty response from dataset API") - - payload = response.json() or {} - - dataset_data = payload.get("dataset", {}) - dataset = Dataset.from_dict(dataset_data) - - # Log warning if present - if warning := payload.get("warning"): - self._logger.warning("Dataset API warning: %s", warning) - - return dataset + return self._dataset_from_response(response) async def _process_file_uploads( self, values: dict[str, str | FileAttachment] @@ -212,9 +252,7 @@ async def _process_file_uploads( return processed - def _process_file_uploads_sync( - self, values: dict[str, str | FileAttachment] - ) -> dict[str, str]: + def _process_file_uploads_sync(self, values: dict[str, str | FileAttachment]) -> dict[str, str]: """ Process file uploads and return values with S3 keys (synchronous version). @@ -272,17 +310,9 @@ async def add_row( # Process file uploads first processed_values = await self._process_file_uploads(values) - url = f"{self._base_url}/datasets/{slug}/items" - - body: dict[str, Any] = { - "values": processed_values, - } - if name is not None: - body["name"] = name - if ideal_output is not None: - body["idealOutput"] = ideal_output - if metadata is not None: - body["metadata"] = metadata + url, body, span_attributes = self._build_dataset_row_request( + slug, processed_values, name, ideal_output, metadata + ) response = await self._request_async( "add_row", @@ -290,24 +320,10 @@ async def add_row( url=url, body=body, headers=self._get_headers(), - span_attributes={ - "basalt.dataset.slug": slug, - "basalt.dataset.row_name": name, - }, + span_attributes=span_attributes, ) - if response is None: - raise BasaltAPIError("Empty response from dataset add row API") - - payload = response.json() or {} - - row_data = payload.get("datasetRow", {}) - - # Log warning if present - if warning := payload.get("warning"): - self._logger.warning("Dataset API warning: %s", warning) - - return DatasetRow.from_dict(row_data) + return self._dataset_row_from_response(response) def add_row_sync( self, @@ -341,17 +357,9 @@ def add_row_sync( # Process file uploads first processed_values = self._process_file_uploads_sync(values) - url = f"{self._base_url}/datasets/{slug}/items" - - body: dict[str, Any] = { - "values": processed_values, - } - if name is not None: - body["name"] = name - if ideal_output is not None: - body["idealOutput"] = ideal_output - if metadata is not None: - body["metadata"] = metadata + url, body, span_attributes = self._build_dataset_row_request( + slug, processed_values, name, ideal_output, metadata + ) response = self._request_sync( "add_row", @@ -359,24 +367,10 @@ def add_row_sync( url=url, body=body, headers=self._get_headers(), - span_attributes={ - "basalt.dataset.slug": slug, - "basalt.dataset.row_name": name, - }, + span_attributes=span_attributes, ) - if response is None: - raise BasaltAPIError("Empty response from dataset add row API") - - payload = response.json() or {} - - row_data = payload.get("datasetRow", {}) - - # Log warning if present - if warning := payload.get("warning"): - self._logger.warning("Dataset API warning: %s", warning) - - return DatasetRow.from_dict(row_data) + return self._dataset_row_from_response(response) def _get_headers(self) -> dict[str, str]: """ diff --git a/basalt/datasets/file_upload.py b/basalt/datasets/file_upload.py index e893368..86e8cf9 100644 --- a/basalt/datasets/file_upload.py +++ b/basalt/datasets/file_upload.py @@ -8,7 +8,7 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO +from typing import TYPE_CHECKING, Any, BinaryIO, cast import httpx @@ -56,7 +56,7 @@ class FileAttachment: content_type: str | None = None filename: str | None = None - def __post_init__(self): + def __post_init__(self) -> None: """Validate the attachment at construction time.""" # Validate source type if not isinstance(self.source, (str, Path, bytes, io.BytesIO, io.BufferedReader)): @@ -74,7 +74,9 @@ def __post_init__(self): # File-like object with name attribute self.filename = os.path.basename(str(self.source.name)) else: - raise FileValidationError("filename is required for file-like objects without a name attribute") + raise FileValidationError( + "filename is required for file-like objects without a name attribute" + ) @dataclass @@ -159,15 +161,17 @@ def _read_file_bytes(source: str | Path | bytes | BinaryIO) -> bytes: else: # File-like object if hasattr(source, "read"): + # Cast to BinaryIO for type checker - we've verified it has read() + file_obj = cast(BinaryIO, source) # Save current position - current_pos = source.tell() if hasattr(source, "tell") else None + current_pos = file_obj.tell() if hasattr(file_obj, "tell") else None # Seek to beginning if possible - if hasattr(source, "seek"): - source.seek(0) - data = source.read() + if hasattr(file_obj, "seek"): + file_obj.seek(0) + data = file_obj.read() # Restore position if possible - if current_pos is not None and hasattr(source, "seek"): - source.seek(current_pos) + if current_pos is not None and hasattr(file_obj, "seek"): + file_obj.seek(current_pos) if isinstance(data, bytes): return data else: @@ -202,7 +206,7 @@ def _validate_file_size(file_bytes: bytes) -> None: class FileUploadHandler: """Handles file validation, presigned URL requests, and S3 uploads.""" - def __init__(self, http_client: HTTPClient, base_url: str, api_key: str): + def __init__(self, http_client: HTTPClient, base_url: str, api_key: str) -> None: """ Initialize FileUploadHandler. @@ -271,7 +275,7 @@ async def request_presigned_url( BasaltAPIError: If the API request fails """ url = f"{self._base_url}/files/generate-upload-url" - body = {"fileName": filename, "contentType": content_type} + body: dict[str, Any] = {"fileName": filename, "contentType": content_type} headers = {"Authorization": f"Bearer {self._api_key}"} self._logger.debug( @@ -303,7 +307,7 @@ def request_presigned_url_sync( BasaltAPIError: If the API request fails """ url = f"{self._base_url}/files/generate-upload-url" - body = {"fileName": filename, "contentType": content_type} + body: dict[str, Any] = {"fileName": filename, "contentType": content_type} headers = {"Authorization": f"Bearer {self._api_key}"} self._logger.debug( @@ -318,9 +322,7 @@ def request_presigned_url_sync( return PresignedUploadResponse.from_dict(response.data) - async def upload_to_s3( - self, presigned_url: str, file_bytes: bytes, content_type: str - ) -> None: + async def upload_to_s3(self, presigned_url: str, file_bytes: bytes, content_type: str) -> None: """ Upload file to S3 using presigned URL. @@ -369,9 +371,7 @@ async def upload_to_s3( except Exception as e: raise FileUploadError(f"Unexpected error during upload: {e}") from e - def upload_to_s3_sync( - self, presigned_url: str, file_bytes: bytes, content_type: str - ) -> None: + def upload_to_s3_sync(self, presigned_url: str, file_bytes: bytes, content_type: str) -> None: """ Upload file to S3 using presigned URL (synchronous version). diff --git a/basalt/datasets/models.py b/basalt/datasets/models.py index bd204d6..df83938 100644 --- a/basalt/datasets/models.py +++ b/basalt/datasets/models.py @@ -4,6 +4,7 @@ This module contains all data models and data transfer objects used by the DatasetsClient. """ + from __future__ import annotations from collections.abc import Iterable, Mapping @@ -35,6 +36,7 @@ def from_dict(cls, data: Mapping[str, Any] | str) -> DatasetColumn: col_type = data.get("type") if isinstance(data.get("type"), str) else None return cls(name=str(name), type=col_type) + @dataclass(slots=True) class DatasetRow: """ @@ -46,6 +48,7 @@ class DatasetRow: ideal_output: Optional ideal output for evaluation. metadata: Optional metadata dictionary. """ + # store as a plain dict internally for fast access; accept Mapping inputs values: dict[str, str] name: str | None = None @@ -105,6 +108,7 @@ class Dataset: columns: List of column names in the dataset. rows: List of rows in the dataset. """ + slug: str name: str # public attributes; rows kept mutable, columns are immutable objects @@ -150,4 +154,6 @@ def from_dict( slug_val = data.get("slug") if isinstance(data.get("slug"), str) else "" name_val = data.get("name") if isinstance(data.get("name"), str) else "" - return cls(slug=str(slug_val), name=str(name_val), columns=column_definitions, rows=rows_list) + return cls( + slug=str(slug_val), name=str(name_val), columns=column_definitions, rows=rows_list + ) diff --git a/basalt/experiments/__init__.py b/basalt/experiments/__init__.py index 4df73ac..058e2d0 100644 --- a/basalt/experiments/__init__.py +++ b/basalt/experiments/__init__.py @@ -3,6 +3,7 @@ This module provides access to the Basalt Experiments API. """ + from .client import ExperimentsClient from .models import Experiment diff --git a/basalt/experiments/client.py b/basalt/experiments/client.py index 3fcc48b..38e8b06 100644 --- a/basalt/experiments/client.py +++ b/basalt/experiments/client.py @@ -3,6 +3,7 @@ This module provides the ExperimentsClient for interacting with the Basalt Experiments API. """ + from __future__ import annotations from typing import Any @@ -27,7 +28,7 @@ def __init__( base_url: str | None = None, http_client: HTTPClient | None = None, log_level: str | None = None, - ): + ) -> None: """ Initialize the ExperimentsClient. diff --git a/basalt/experiments/models.py b/basalt/experiments/models.py index 9677cf4..dedf36b 100644 --- a/basalt/experiments/models.py +++ b/basalt/experiments/models.py @@ -4,6 +4,7 @@ This module contains all data models and data transfer objects used by the ExperimentsClient. """ + from __future__ import annotations from collections.abc import Mapping @@ -24,6 +25,7 @@ class Experiment: feature_slug: The feature slug associated with the experiment. created_at: ISO 8601 timestamp of when the experiment was created. """ + id: str name: str feature_slug: str @@ -47,7 +49,9 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> Experiment: # Defensive reads with defaults id_val = data.get("id") if isinstance(data.get("id"), str) else "" name_val = data.get("name") if isinstance(data.get("name"), str) else "" - feature_slug_val = data.get("featureSlug") if isinstance(data.get("featureSlug"), str) else "" + feature_slug_val = ( + data.get("featureSlug") if isinstance(data.get("featureSlug"), str) else "" + ) created_at_val = data.get("createdAt") if isinstance(data.get("createdAt"), str) else "" return cls( diff --git a/basalt/observability/api.py b/basalt/observability/api.py index 6dbb38f..c0666ef 100644 --- a/basalt/observability/api.py +++ b/basalt/observability/api.py @@ -2,12 +2,14 @@ import functools import inspect -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from contextlib import ContextDecorator -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from opentelemetry.trace import StatusCode +from ..types.common import JSONValue, SpanAttributeValue + if TYPE_CHECKING: from basalt.experiments.models import Experiment from basalt.prompts.models import Prompt @@ -50,6 +52,77 @@ F = TypeVar("F", bound=Callable[..., Any]) +def _resolve_experiment_id(experiment: str | Experiment | None) -> str | None: + """Resolve an experiment identifier from supported experiment types.""" + if not experiment: + return None + if isinstance(experiment, str): + return experiment + exp_id = getattr(experiment, "id", None) + if isinstance(exp_id, str) and exp_id: + return exp_id + return None + + +def _get_observe_config_for_kind( + kind_str: str, +) -> tuple[type[SpanHandle], str, Callable[[Any], Any] | None, Callable[[Any], Any] | None]: + """Return handle class, tracer name, and default resolvers for the kind.""" + if kind_str == "generation": + return ( + LLMSpanHandle, + "basalt.observability.generation", + default_generation_input, + default_generation_variables, + ) + if kind_str == "retrieval": + return ( + RetrievalSpanHandle, + "basalt.observability.retrieval", + default_retrieval_input, + default_retrieval_variables, + ) + if kind_str == "tool": + return ( + ToolSpanHandle, + "basalt.observability.tool", + None, + None, + ) + if kind_str == "function": + return ( + FunctionSpanHandle, + "basalt.observability.function", + None, + None, + ) + if kind_str == "event": + return ( + EventSpanHandle, + "basalt.observability.event", + None, + None, + ) + return ( + SpanHandle, + "basalt.observability", + None, + None, + ) + + +def _resolve_kind_str(kind: ObserveKind | str) -> str: + if isinstance(kind, ObserveKind): + return kind.value + kind_str = str(kind).lower() + valid_kinds = {k.value for k in ObserveKind} + if kind_str not in valid_kinds: + raise ValueError( + f"Invalid kind '{kind_str}'. Must be one of: {', '.join(sorted(valid_kinds))}" + ) + return kind_str + + class StartObserve(ContextDecorator): """ Entry point for Basalt observability. @@ -66,7 +139,7 @@ def __init__( evaluators: Sequence[Any] | None = None, experiment: str | Experiment | None = None, metadata: dict[str, Any] | None = None, - ): + ) -> None: # Validate feature_slug is provided and non-empty if not feature_slug or not isinstance(feature_slug, str) or not feature_slug.strip(): raise ValueError( @@ -98,6 +171,7 @@ def __enter__(self) -> StartSpanHandle: user_identity, org_identity = resolve_identity_payload(self.identity_resolver, None) # Initialize context manager + experiment_id = _resolve_experiment_id(self.experiment) self._ctx_manager = _with_span_handle( name=span_name, attributes=None, @@ -110,7 +184,7 @@ def __enter__(self) -> StartSpanHandle: feature_slug=self.feature_slug, metadata=self._metadata, evaluate_config=self.evaluate_config, - experiment=self.experiment, + experiment=experiment_id, ) span = self._ctx_manager.__enter__() # Type assertion: we know this is StartSpanHandle since we passed it as handle_cls @@ -127,7 +201,7 @@ def __enter__(self) -> StartSpanHandle: return self._span_handle - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> bool | None: if self._ctx_manager: return self._ctx_manager.__exit__(exc_type, exc_value, traceback) return None @@ -141,6 +215,7 @@ def wrapper(*args, **kwargs): pre_evaluators = resolve_evaluators_payload(self.evaluators, bound) span_name = self.name + experiment_id = _resolve_experiment_id(self.experiment) with _with_span_handle( name=span_name, attributes=None, @@ -153,7 +228,7 @@ def wrapper(*args, **kwargs): feature_slug=self.feature_slug, metadata=self._metadata, evaluate_config=self.evaluate_config, - experiment=self.experiment, + experiment=experiment_id, ) as handle: # Type assertion: we know this is StartSpanHandle since we passed it as handle_cls assert isinstance(handle, StartSpanHandle) @@ -175,9 +250,14 @@ def wrapper(*args, **kwargs): if inspect.iscoroutinefunction(func): @functools.wraps(func) - async def async_wrapper(*args, **kwargs): + async def async_wrapper( + *args: object, + **kwargs: object, + ) -> object: bound = resolve_bound_arguments(func, args, kwargs) - user_identity, org_identity = resolve_identity_payload(self.identity_resolver, bound) + user_identity, org_identity = resolve_identity_payload( + self.identity_resolver, bound + ) pre_evaluators = resolve_evaluators_payload(self.evaluators, bound) span_name = self.name @@ -193,7 +273,7 @@ async def async_wrapper(*args, **kwargs): feature_slug=self.feature_slug, metadata=self._metadata, evaluate_config=self.evaluate_config, - experiment=self.experiment, + experiment=_resolve_experiment_id(self.experiment), ) as handle: # Type assertion: we know this is StartSpanHandle since we passed it as handle_cls assert isinstance(handle, StartSpanHandle) @@ -223,18 +303,10 @@ def _apply_experiment(self, span: StartSpanHandle | None) -> None: Supports either a string experiment ID or an `Experiment` dataclass instance from `basalt.experiments.models`. """ - if span is None or not self.experiment: + if span is None: return - exp_id: str | None = None - - # Handle string ID case - if isinstance(self.experiment, str): - exp_id = self.experiment - # Check if it's an Experiment dataclass by looking for the 'id' attribute - # (avoid isinstance to prevent circular import at runtime) - elif hasattr(self.experiment, "id"): - exp_id = self.experiment.id # type: ignore[attr-defined] + exp_id = _resolve_experiment_id(self.experiment) if not exp_id: return # Must have an id to attach @@ -257,11 +329,11 @@ def __init__( *, metadata: dict[str, Any] | None = None, evaluators: Sequence[Any] | None = None, - input: Any = None, - output: Any = None, - variables: dict[str, Any] | None = None, + input: JSONValue | Callable[[Any], JSONValue] = None, + output: Callable[[Any], JSONValue] | None = None, + variables: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, prompt: Prompt | None = None, - ): + ) -> None: # Validate name is provided and non-empty if not name or not isinstance(name, str) or not name.strip(): raise ValueError( @@ -282,61 +354,12 @@ def __init__( @staticmethod def _get_config_for_kind(kind_str: str): - """Return handle class, tracer name, and default resolvers for the kind.""" - if kind_str == "generation": - return ( - LLMSpanHandle, - "basalt.observability.generation", - default_generation_input, - default_generation_variables, - ) - elif kind_str == "retrieval": - return ( - RetrievalSpanHandle, - "basalt.observability.retrieval", - default_retrieval_input, - default_retrieval_variables, - ) - elif kind_str == "tool": - return ( - ToolSpanHandle, - "basalt.observability.tool", - None, - None, - ) - elif kind_str == "function": - return ( - FunctionSpanHandle, - "basalt.observability.function", - None, - None, - ) - elif kind_str == "event": - return ( - EventSpanHandle, - "basalt.observability.event", - None, - None, - ) - else: - return ( - SpanHandle, - "basalt.observability", - None, - None, - ) + return _get_observe_config_for_kind(kind_str) def __enter__(self) -> SpanHandle: span_name = self.name - if isinstance(self.kind, ObserveKind): - kind_str = self.kind.value - else: - kind_str = str(self.kind).lower() - # Validate that the string kind is valid - valid_kinds = {k.value for k in ObserveKind} - if kind_str not in valid_kinds: - raise ValueError(f"Invalid kind '{kind_str}'. Must be one of: {', '.join(sorted(valid_kinds))}") + kind_str = _resolve_kind_str(self.kind) # Reject ROOT kind if kind_str == ObserveKind.ROOT.value: @@ -371,7 +394,9 @@ def __enter__(self) -> SpanHandle: import logging logger = logging.getLogger(__name__) - logger.warning("Observe used without a preceding start_observe. This may lead to missing trace context.") + logger.warning( + "Observe used without a preceding start_observe. This may lead to missing trace context." + ) self._ctx_manager = _with_span_handle( name=span_name, @@ -389,7 +414,7 @@ def __enter__(self) -> SpanHandle: self._span_handle = self._ctx_manager.__enter__() return self._span_handle - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> bool | None: if self._ctx_manager: return self._ctx_manager.__exit__(exc_type, exc_value, traceback) return None @@ -411,32 +436,49 @@ def __call__(self, func: F) -> F: # Use defaults if not provided input_resolver = self.input_resolver if self.input_resolver is not None else default_input - variables_resolver = self.variables_resolver if self.variables_resolver is not None else default_vars + variables_resolver = ( + self.variables_resolver if self.variables_resolver is not None else default_vars + ) # Process prompt parameter if provided prompt_metadata = {} if self.prompt is not None: import json + prompt = cast("Prompt", self.prompt) + # Override input resolver with prompt.text - def input_resolver(bound): - return self.prompt.text + def prompt_input_resolver(bound: inspect.BoundArguments | None) -> str: + return prompt.text + + input_resolver = prompt_input_resolver # Prepare prompt metadata for span attributes prompt_metadata = { - "basalt.prompt.slug": self.prompt.slug, - "basalt.prompt.version": self.prompt.version, - "basalt.prompt.model.provider": self.prompt.model.provider, - "basalt.prompt.model.model": self.prompt.model.model, + "basalt.prompt.slug": prompt.slug, + "basalt.prompt.version": prompt.version, + "basalt.prompt.model.provider": prompt.model.provider, + "basalt.prompt.model.model": prompt.model.model, } # Store prompt.variables separately if available (must serialize to JSON for OpenTelemetry) - if self.prompt.variables: - prompt_metadata["basalt.prompt.variables"] = json.dumps(self.prompt.variables) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - computed_metadata = resolve_attributes(self._metadata, args, kwargs) + if prompt.variables: + prompt_metadata["basalt.prompt.variables"] = json.dumps(prompt.variables) + + def prepare_call_data( + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[ + Mapping[str, Any] | None, + inspect.BoundArguments | None, + object, + Mapping[str, Any] | None, + list[Any] | None, + ]: + computed_metadata_raw = resolve_attributes(self._metadata, args, kwargs) + computed_metadata = ( + computed_metadata_raw if isinstance(computed_metadata_raw, Mapping) else None + ) bound = resolve_bound_arguments(func, args, kwargs) input_payload = resolve_payload_from_bound(input_resolver, bound) variables_payload = resolve_variables_payload(variables_resolver, bound) @@ -453,19 +495,27 @@ def wrapper(*args, **kwargs): "Observe used without a preceding start_observe. This may lead to missing trace context." ) - # Pre-hooks - def apply_pre(span, bound): - if kind_str == "generation" and isinstance(span, LLMSpanHandle): - apply_llm_request_metadata(span, bound) - elif kind_str == "retrieval" and isinstance(span, RetrievalSpanHandle): - query = resolve_payload_from_bound(input_resolver, bound) - if isinstance(query, str): - span.set_query(query) + return computed_metadata, bound, input_payload, variables_payload, pre_evaluators + + # Pre-hooks + def apply_pre(span, bound): + if kind_str == "generation" and isinstance(span, LLMSpanHandle): + apply_llm_request_metadata(span, bound) + elif kind_str == "retrieval" and isinstance(span, RetrievalSpanHandle): + query = resolve_payload_from_bound(input_resolver, bound) + if isinstance(query, str): + span.set_query(query) - # Post-hooks - def apply_post(span, result): - if kind_str == "generation" and isinstance(span, LLMSpanHandle): - apply_llm_response_metadata(span, result) + # Post-hooks + def apply_post(span, result): + if kind_str == "generation" and isinstance(span, LLMSpanHandle): + apply_llm_response_metadata(span, result) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + computed_metadata, bound, input_payload, variables_payload, pre_evaluators = ( + prepare_call_data(args, kwargs) + ) with _with_span_handle( name=self.name, @@ -473,7 +523,7 @@ def apply_post(span, result): tracer_name=tracer_name, handle_cls=handle_cls, span_type=kind_str, - input_payload=input_payload, + input_payload=cast("JSONValue | None", input_payload), variables=variables_payload, evaluators=pre_evaluators, metadata=computed_metadata, @@ -484,51 +534,31 @@ def apply_post(span, result): try: result = func(*args, **kwargs) - transformed = self.output_resolver(result) if self.output_resolver else result - span.set_output(transformed) + if self.output_resolver and callable(self.output_resolver): + transformed = self.output_resolver(result) + else: + transformed = result + span.set_output(cast("str | dict[str, Any]", transformed)) if apply_post: apply_post(span, result) return result except Exception: - span.set_output({"error": "Exception occurred"}) + error_output: dict[str, JSONValue] = {"error": "Exception occurred"} + span.set_output(error_output) raise if inspect.iscoroutinefunction(func): @functools.wraps(func) - async def async_wrapper(*args, **kwargs): - computed_metadata = resolve_attributes(self._metadata, args, kwargs) - bound = resolve_bound_arguments(func, args, kwargs) - input_payload = resolve_payload_from_bound(input_resolver, bound) - variables_payload = resolve_variables_payload(variables_resolver, bound) - pre_evaluators = resolve_evaluators_payload(self.evaluators, bound) - - # Check for root span - from opentelemetry import context as otel_context - - if not otel_context.get_value(ROOT_SPAN_CONTEXT_KEY): - import logging - - logger = logging.getLogger(__name__) - logger.warning( - "Observe used without a preceding start_observe. This may lead to missing trace context." - ) - - # Pre-hooks (same as sync) - def apply_pre(span, bound): - if kind_str == "generation" and isinstance(span, LLMSpanHandle): - apply_llm_request_metadata(span, bound) - elif kind_str == "retrieval" and isinstance(span, RetrievalSpanHandle): - query = resolve_payload_from_bound(input_resolver, bound) - if isinstance(query, str): - span.set_query(query) - - # Post-hooks (same as sync) - def apply_post(span, result): - if kind_str == "generation" and isinstance(span, LLMSpanHandle): - apply_llm_response_metadata(span, result) + async def async_wrapper( + *args: object, + **kwargs: object, + ) -> object: + computed_metadata, bound, input_payload, variables_payload, pre_evaluators = ( + prepare_call_data(args, kwargs) + ) with _with_span_handle( name=self.name, @@ -536,7 +566,7 @@ def apply_post(span, result): tracer_name=tracer_name, handle_cls=handle_cls, span_type=kind_str, - input_payload=input_payload, + input_payload=cast("JSONValue | None", input_payload), variables=variables_payload, evaluators=pre_evaluators, metadata=computed_metadata, @@ -547,8 +577,11 @@ def apply_post(span, result): try: result = await func(*args, **kwargs) - transformed = self.output_resolver(result) if self.output_resolver else result - span.set_output(transformed) + if self.output_resolver and callable(self.output_resolver): + transformed = self.output_resolver(result) + else: + transformed = result + span.set_output(cast("str | dict[str, Any]", transformed)) if apply_post: apply_post(span, result) @@ -565,7 +598,9 @@ def apply_post(span, result): # Static Domain Methods @staticmethod - def _identify(user: str | dict[str, Any] | None = None, organization: str | dict[str, Any] | None = None) -> None: + def _identify( + user: str | dict[str, Any] | None = None, organization: str | dict[str, Any] | None = None + ) -> None: """Set the user and/or organization identity for the current context.""" if user: if isinstance(user, str): @@ -577,7 +612,9 @@ def _identify(user: str | dict[str, Any] | None = None, organization: str | dict if isinstance(organization, str): _set_trace_organization(organization_id=organization) elif isinstance(organization, dict): - _set_trace_organization(organization_id=organization.get("id", "unknown"), name=organization.get("name")) + _set_trace_organization( + organization_id=organization.get("id", "unknown"), name=organization.get("name") + ) @staticmethod def _root_span() -> StartSpanHandle | None: @@ -592,14 +629,14 @@ def _root_span() -> StartSpanHandle | None: return get_root_span_handle() @staticmethod - def set_input(data: str | dict[str, Any]) -> None: + def set_input(data: JSONValue) -> None: """Set input data for the current span.""" handle = get_current_span_handle() if handle: handle.set_input(data) @staticmethod - def set_output(data: str | dict[str, Any]) -> None: + def set_output(data: JSONValue) -> None: """Set output data for the current span.""" handle = get_current_span_handle() if handle: @@ -667,20 +704,24 @@ def set_attributes(attributes: dict[str, Any]) -> None: @staticmethod def set_io( *, - input_payload: Any | None = None, - output_payload: Any | None = None, + input_payload: JSONValue = None, + output_payload: JSONValue = None, variables: dict[str, Any] | None = None, ) -> None: """Set input, output, and variables for the current span.""" handle = get_current_span_handle() if handle: - handle.set_io(input_payload=input_payload, output_payload=output_payload, variables=variables) + handle.set_io( + input_payload=cast("str | dict[str, Any] | None", input_payload), + output_payload=cast("str | dict[str, Any] | None", output_payload), + variables=variables, + ) @staticmethod def inject_for_auto_instrumentation( *, - input_payload: Any | None = None, - output_payload: Any | None = None, + input_payload: JSONValue = None, + output_payload: JSONValue = None, prompt: Prompt | None = None, metadata: dict[str, Any] | None = None, variables: dict[str, Any] | None = None, @@ -716,7 +757,7 @@ def inject_for_auto_instrumentation( Observe.inject_for_auto_instrumentation( input_payload={"query": "What is AI?"}, prompt=my_prompt, - metadata={"session_id": "abc123"} + metadata={"session_id": "abc123"}, ) # Next auto-instrumented call will have this data attached @@ -816,17 +857,19 @@ def set_identity(identity: Identity | None = None) -> None: Each key should contain a dict with 'id' (required) and 'name' (optional). Example: - >>> Observe.set_identity({ - ... "user": {"id": "user-123", "name": "John Doe"}, - ... "organization": {"id": "org-456", "name": "ACME Corp"} - ... }) + >>> Observe.set_identity( + ... { + ... "user": {"id": "user-123", "name": "John Doe"}, + ... "organization": {"id": "org-456", "name": "ACME Corp"}, + ... } + ... ) """ handle = get_current_span_handle() if handle: handle.set_identity(identity) @staticmethod - def set_attribute(key: str, value: Any) -> None: + def set_attribute(key: str, value: SpanAttributeValue) -> None: """Set a single attribute on the current span. Args: @@ -940,7 +983,7 @@ def __init__( evaluators: Sequence[Any] | None = None, experiment: str | Experiment | None = None, metadata: dict[str, Any] | None = None, - ): + ) -> None: # Validate feature_slug is provided and non-empty if not feature_slug or not isinstance(feature_slug, str) or not feature_slug.strip(): raise ValueError( @@ -1001,24 +1044,17 @@ async def __aenter__(self) -> StartSpanHandle: return self._span_handle - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type, exc_value, traceback) -> bool | None: if self._ctx_manager: return await self._ctx_manager.__aexit__(exc_type, exc_value, traceback) return None def _apply_experiment(self, span: StartSpanHandle | None) -> None: """Apply experiment metadata to the provided span.""" - if span is None or not self.experiment: + if span is None: return - exp_id: str | None = None - - # Handle string ID case - if isinstance(self.experiment, str): - exp_id = self.experiment - # Check if it's an Experiment dataclass by looking for the 'id' attribute - elif hasattr(self.experiment, "id"): - exp_id = self.experiment.id # type: ignore[attr-defined] + exp_id = _resolve_experiment_id(self.experiment) if not exp_id: return # Must have an id to attach @@ -1041,11 +1077,11 @@ def __init__( *, metadata: dict[str, Any] | None = None, evaluators: Sequence[Any] | None = None, - input: Any = None, - output: Any = None, - variables: dict[str, Any] | None = None, + input: JSONValue | Callable[[Any], JSONValue] = None, + output: Callable[[Any], JSONValue] | None = None, + variables: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, prompt: Prompt | None = None, - ): + ) -> None: # Validate name is provided and non-empty if not name or not isinstance(name, str) or not name.strip(): raise ValueError( @@ -1066,61 +1102,12 @@ def __init__( @staticmethod def _get_config_for_kind(kind_str: str): - """Return handle class, tracer name, and default resolvers for the kind.""" - if kind_str == "generation": - return ( - LLMSpanHandle, - "basalt.observability.generation", - default_generation_input, - default_generation_variables, - ) - elif kind_str == "retrieval": - return ( - RetrievalSpanHandle, - "basalt.observability.retrieval", - default_retrieval_input, - default_retrieval_variables, - ) - elif kind_str == "tool": - return ( - ToolSpanHandle, - "basalt.observability.tool", - None, - None, - ) - elif kind_str == "function": - return ( - FunctionSpanHandle, - "basalt.observability.function", - None, - None, - ) - elif kind_str == "event": - return ( - EventSpanHandle, - "basalt.observability.event", - None, - None, - ) - else: - return ( - SpanHandle, - "basalt.observability", - None, - None, - ) + return _get_observe_config_for_kind(kind_str) async def __aenter__(self) -> SpanHandle: span_name = self.name - if isinstance(self.kind, ObserveKind): - kind_str = self.kind.value - else: - kind_str = str(self.kind).lower() - # Validate that the string kind is valid - valid_kinds = {k.value for k in ObserveKind} - if kind_str not in valid_kinds: - raise ValueError(f"Invalid kind '{kind_str}'. Must be one of: {', '.join(sorted(valid_kinds))}") + kind_str = _resolve_kind_str(self.kind) # Reject ROOT kind if kind_str == ObserveKind.ROOT.value: @@ -1160,7 +1147,9 @@ async def __aenter__(self) -> SpanHandle: import logging logger = logging.getLogger(__name__) - logger.warning("AsyncObserve used without a preceding async_start_observe. This may lead to missing trace context.") + logger.warning( + "AsyncObserve used without a preceding async_start_observe. This may lead to missing trace context." + ) self._ctx_manager = _async_with_span_handle( name=span_name, @@ -1174,7 +1163,7 @@ async def __aenter__(self) -> SpanHandle: self._span_handle = await self._ctx_manager.__aenter__() return self._span_handle - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type, exc_value, traceback) -> bool | None: if self._ctx_manager: return await self._ctx_manager.__aexit__(exc_type, exc_value, traceback) return None diff --git a/basalt/observability/config.py b/basalt/observability/config.py index 8bc36c0..6183680 100644 --- a/basalt/observability/config.py +++ b/basalt/observability/config.py @@ -147,7 +147,9 @@ def clone(self) -> TelemetryConfig: cloned = replace(self) cloned.extra_resource_attributes = dict(self.extra_resource_attributes) cloned.enabled_providers = list(self.enabled_providers) if self.enabled_providers else None - cloned.disabled_providers = list(self.disabled_providers) if self.disabled_providers else None + cloned.disabled_providers = ( + list(self.disabled_providers) if self.disabled_providers else None + ) # Clone exporter list if it's a list (shallow copy of list, not exporters themselves) if isinstance(self.exporter, list): cloned.exporter = list(self.exporter) @@ -184,7 +186,9 @@ def with_env_overrides(self) -> TelemetryConfig: disabled_instruments = os.getenv("BASALT_DISABLED_INSTRUMENTS") if disabled_instruments: - cfg.disabled_providers = [p.strip() for p in disabled_instruments.split(",") if p.strip()] + cfg.disabled_providers = [ + p.strip() for p in disabled_instruments.split(",") if p.strip() + ] sample_rate_env = os.getenv("BASALT_SAMPLE_RATE") if sample_rate_env: diff --git a/basalt/observability/context_managers.py b/basalt/observability/context_managers.py index ebe38ea..471a499 100644 --- a/basalt/observability/context_managers.py +++ b/basalt/observability/context_managers.py @@ -20,6 +20,7 @@ from opentelemetry.context import attach, detach, set_value from opentelemetry.trace import Span, Tracer +from ..types.common import JSONValue, SpanAttributeValue from . import semconv from .trace_context import ( ORGANIZATION_CONTEXT_KEY, @@ -81,7 +82,7 @@ def __post_init__(self) -> None: raise TypeError("Evaluator metadata must be a mapping.") -def _normalize_evaluator_entry(entry: Any) -> EvaluatorAttachment: +def _normalize_evaluator_entry(entry: object) -> EvaluatorAttachment: """Convert assorted evaluator payloads into EvaluatorAttachment objects.""" if isinstance(entry, EvaluatorAttachment): return entry @@ -155,10 +156,10 @@ def _attach_attributes(span: Span, attributes: dict[str, Any] | None) -> None: if not attributes: return for key, value in attributes.items(): - span.set_attribute(key, value) + _set_serialized_attribute(span, key, value) -def _serialize_attribute(value: Any) -> Any | None: +def _serialize_attribute(value: JSONValue) -> SpanAttributeValue: if value is None or isinstance(value, (str, bool, int, float)): return value try: @@ -167,7 +168,7 @@ def _serialize_attribute(value: Any) -> Any | None: return str(value) -def _set_serialized_attribute(span: Span, key: str, value: Any) -> None: +def _set_serialized_attribute(span: Span, key: str, value: JSONValue) -> None: serialized = _serialize_attribute(value) if serialized is not None: span.set_attribute(key, serialized) @@ -224,33 +225,37 @@ def __init__( span: Span, parent_span: Span | None = None, defaults: _TraceContextConfig | None = None, - ): + ) -> None: self._span = span self._io_payload: dict[str, Any] = {"input": None, "output": None, "variables": None} - self._parent_span = parent_span if parent_span and parent_span.get_span_context().is_valid else None + self._parent_span = ( + parent_span if parent_span and parent_span.get_span_context().is_valid else None + ) self._evaluators: dict[str, EvaluatorAttachment] = {} self._evaluator_config: EvaluationConfig | None = None self._hydrate_existing_evaluators() - def set_attribute(self, key: str, value: Any) -> None: + def set_attribute(self, key: str, value: str | int | float | bool | None) -> None: """ Sets metadata on the current span. Args: key (str): The metadata key to set. - value (Any): The metadata value. + value (str | int | float | bool | None): The metadata value. Returns: None """ - self._span.set_attribute(key, value) + if value is not None: + self._span.set_attribute(key, value) def set_attributes(self, attributes: dict[str, Any]) -> None: """ Set multiple attributes on the current span. Args: - attributes: Dictionary of attributes to set. + attributes: Dictionary of attributes to set. Complex values (dicts, lists) + will be serialized to JSON strings. """ _attach_attributes(self._span, attributes) @@ -293,7 +298,7 @@ def set_prompt(self, prompt: Prompt) -> None: # ------------------------------------------------------------------ # IO helpers # ------------------------------------------------------------------ - def set_input(self, payload: str | dict[str, Any]) -> None: + def set_input(self, payload: JSONValue) -> None: """ Sets the input payload for the current context manager. Stores the provided payload in the internal `_io_payload` dictionary under the "input" key. @@ -307,7 +312,7 @@ def set_input(self, payload: str | dict[str, Any]) -> None: if trace_content_enabled(): _set_serialized_attribute(self._span, semconv.BasaltSpan.INPUT, payload) - def set_output(self, payload: str | dict[str, Any]) -> None: + def set_output(self, payload: JSONValue) -> None: """ Sets the output payload for the current context manager. Stores the provided payload in the internal I/O payload dictionary under the "output" key. @@ -323,8 +328,8 @@ def set_output(self, payload: str | dict[str, Any]) -> None: def set_io( self, *, - input_payload: str | dict[str, Any] | None = None, - output_payload: str | dict[str, Any] | None = None, + input_payload: JSONValue | None = None, + output_payload: JSONValue | None = None, variables: Mapping[str, Any] | None = None, ) -> None: """ @@ -339,9 +344,11 @@ def set_io( raise TypeError("Span variables must be provided as a mapping.") self._io_payload["variables"] = dict(variables) if trace_content_enabled(): - _set_serialized_attribute(self._span, semconv.BasaltSpan.VARIABLES, variables) + _set_serialized_attribute(self._span, semconv.BasaltSpan.VARIABLES, dict(variables)) if self._parent_span: - _set_serialized_attribute(self._parent_span, semconv.BasaltSpan.VARIABLES, variables) + _set_serialized_attribute( + self._parent_span, semconv.BasaltSpan.VARIABLES, dict(variables) + ) def _io_snapshot(self) -> dict[str, Any]: """Return a shallow copy of the tracked IO payload.""" @@ -350,7 +357,6 @@ def _io_snapshot(self) -> dict[str, Any]: snapshot["variables"] = dict(snapshot["variables"]) return snapshot - def add_evaluator( self, evaluator_slug: str, @@ -378,10 +384,12 @@ def set_identity(self, identity: Identity | None = None) -> None: Each key should contain a dict with 'id' (required) and 'name' (optional). Example: - >>> span.set_identity({ - ... "user": {"id": "user-123", "name": "John Doe"}, - ... "organization": {"id": "org-456", "name": "ACME Corp"} - ... }) + >>> span.set_identity( + ... { + ... "user": {"id": "user-123", "name": "John Doe"}, + ... "organization": {"id": "org-456", "name": "ACME Corp"}, + ... } + ... ) """ if identity is None: return @@ -476,7 +484,7 @@ def set_response_id(self, response_id: str) -> None: def set_finish_reasons(self, reasons: list[str]) -> None: """Set the finish reasons array for the GenAI response.""" - self.set_attribute(semconv.GenAI.RESPONSE_FINISH_REASONS, reasons) + _set_serialized_attribute(self._span, semconv.GenAI.RESPONSE_FINISH_REASONS, list(reasons)) class StartSpanHandle(SpanHandle): @@ -498,7 +506,9 @@ def set_evaluation_config(self, config: EvaluationConfig | Mapping[str, Any]) -> """ if isinstance(config, EvaluationConfig): self._evaluator_config = config - _set_serialized_attribute(self._span, semconv.BasaltSpan.EVALUATION_CONFIG, config.to_dict()) + _set_serialized_attribute( + self._span, semconv.BasaltSpan.EVALUATION_CONFIG, config.to_dict() + ) elif isinstance(config, Mapping): config_dict = dict(config) self._evaluator_config = EvaluationConfig(**config_dict) @@ -594,9 +604,10 @@ def set_results_count(self, count: int) -> None: """Set the number of results returned.""" self.set_attribute(semconv.BasaltRetrieval.RESULTS_COUNT, count) - def set_top_k(self, top_k: int) -> None: + def set_top_k(self, top_k: float) -> None: """Set the top-K parameter for retrieval.""" - self.set_attribute(semconv.BasaltRetrieval.TOP_K, top_k) + value = int(top_k) if isinstance(top_k, float) else top_k + self.set_attribute(semconv.BasaltRetrieval.TOP_K, value) class FunctionSpanHandle(SpanHandle): @@ -612,7 +623,7 @@ def set_stage(self, stage: str) -> None: """Set the stage or phase associated with the execution.""" self.set_attribute(semconv.BasaltFunction.STAGE, stage) - def add_metric(self, key: str, value: Any) -> None: + def add_metric(self, key: str, value: str | int | float | bool) -> None: """Attach custom metric data to the function execution.""" self.set_attribute(f"{semconv.BasaltFunction.METRIC_PREFIX}.{key}", value) @@ -628,7 +639,7 @@ def set_tool_name(self, name: str) -> None: """Set the name of the tool being invoked.""" self.set_attribute(semconv.BasaltTool.NAME, name) - def set_input(self, payload: Any) -> None: + def set_input(self, payload: JSONValue) -> None: """Set the input payload for the tool.""" super().set_input(payload) if trace_content_enabled(): @@ -636,7 +647,7 @@ def set_input(self, payload: Any) -> None: if value is not None: self.set_attribute(semconv.BasaltTool.INPUT, value) - def set_output(self, payload: Any) -> None: + def set_output(self, payload: JSONValue) -> None: """Set the output payload from the tool.""" super().set_output(payload) if trace_content_enabled(): @@ -656,7 +667,7 @@ def set_event_type(self, event_type: str) -> None: """Set the type of custom event.""" self.set_attribute(semconv.BasaltEvent.TYPE, event_type) - def set_payload(self, payload: Any) -> None: + def set_payload(self, payload: JSONValue) -> None: """Set the event payload.""" super().set_input(payload) if trace_content_enabled(): @@ -673,8 +684,8 @@ def _with_span_handle( handle_cls: type[SpanHandle], span_type: str | None = None, *, - input_payload: Any | None = None, - output_payload: Any | None = None, + input_payload: JSONValue | None = None, + output_payload: JSONValue | None = None, variables: Mapping[str, Any] | None = None, evaluators: Sequence[Any] | None = None, user: TraceIdentity | Mapping[str, Any] | None = None, @@ -682,13 +693,15 @@ def _with_span_handle( feature_slug: str | None = None, metadata: Mapping[str, Any] | None = None, evaluate_config: EvaluationConfig | None = None, - experiment: Any = None, + experiment: str | dict[str, Any] | None = None, ) -> Generator[SpanHandle, None, None]: tracer = get_tracer(tracer_name) defaults = _current_trace_defaults() parent_span = trace.get_current_span() - if parent_span and (not parent_span.get_span_context().is_valid or not parent_span.is_recording()): + if parent_span and ( + not parent_span.get_span_context().is_valid or not parent_span.is_recording() + ): parent_span = None # Prepare context tokens for user/org propagation @@ -767,20 +780,12 @@ def _with_span_handle( # Inject prompt context if available try: from basalt.prompts.models import _current_prompt_context + prompt_ctx = _current_prompt_context.get() if prompt_ctx: - # Inject prompt attributes into this span - import json - span.set_attribute("basalt.prompt.slug", prompt_ctx["slug"]) - if prompt_ctx.get("version"): - span.set_attribute("basalt.prompt.version", prompt_ctx["version"]) - if prompt_ctx.get("tag"): - span.set_attribute("basalt.prompt.tag", prompt_ctx["tag"]) - span.set_attribute("basalt.prompt.model.provider", prompt_ctx["provider"]) - span.set_attribute("basalt.prompt.model.model", prompt_ctx["model"]) - if prompt_ctx.get("variables"): - span.set_attribute("basalt.prompt.variables", json.dumps(prompt_ctx["variables"])) - span.set_attribute("basalt.prompt.from_cache", prompt_ctx["from_cache"]) + from .utils import apply_prompt_context_attributes + + apply_prompt_context_attributes(span, prompt_ctx) except ImportError: # Prompts module not available, skip injection pass @@ -790,6 +795,7 @@ def _with_span_handle( # Apply metadata if provided if metadata: from .utils import apply_span_metadata + apply_span_metadata(span, metadata) if span_type: @@ -811,7 +817,7 @@ def _with_span_handle( handle.set_io(variables=variables) if evaluators: handle.add_evaluators(*evaluators) - yield handle # type: ignore[misc] + yield handle if output_payload is not None: handle.set_output(output_payload) @@ -837,8 +843,8 @@ async def _async_with_span_handle( handle_cls: type[SpanHandle], span_type: str | None = None, *, - input_payload: Any | None = None, - output_payload: Any | None = None, + input_payload: JSONValue | None = None, + output_payload: JSONValue | None = None, variables: Mapping[str, Any] | None = None, evaluators: Sequence[Any] | None = None, user: TraceIdentity | Mapping[str, Any] | None = None, @@ -846,7 +852,7 @@ async def _async_with_span_handle( feature_slug: str | None = None, metadata: Mapping[str, Any] | None = None, evaluate_config: EvaluationConfig | None = None, - experiment: Any = None, + experiment: str | Experiment | None = None, ) -> AsyncGenerator[SpanHandle, None]: """Async version of _with_span_handle. @@ -854,149 +860,34 @@ async def _async_with_span_handle( so this async context manager still calls sync OTel APIs internally. The async support is primarily for use with async with statements. """ - tracer = get_tracer(tracer_name) - defaults = _current_trace_defaults() - - parent_span = trace.get_current_span() - if parent_span and (not parent_span.get_span_context().is_valid or not parent_span.is_recording()): - parent_span = None - - # Prepare context tokens for user/org propagation - tokens = [] - if user is not None: - from .trace_context import _coerce_identity - - user_identity = _coerce_identity(user) - if user_identity: - tokens.append(attach(set_value(USER_CONTEXT_KEY, user_identity))) - - if organization is not None: - from .trace_context import _coerce_identity - - org_identity = _coerce_identity(organization) - if org_identity: - tokens.append(attach(set_value(ORGANIZATION_CONTEXT_KEY, org_identity))) - - if feature_slug is not None: - from .trace_context import FEATURE_SLUG_CONTEXT_KEY - - tokens.append(attach(set_value(FEATURE_SLUG_CONTEXT_KEY, feature_slug))) - - # If this is a root span (no parent), store it in context - is_root = parent_span is None - root_span_token = None - - # Check if we're inside a basalt trace - in_basalt_trace = otel_context.get_value(ROOT_SPAN_CONTEXT_KEY) is not None - - # Make trace-level sampling decision - should_evaluate_token = None - if is_root: - # Root span: make new sampling decision - # If experiment is attached, ALWAYS evaluate (should_evaluate=True) - if experiment is not None: - should_evaluate = True - else: - # Get sample_rate from evaluate_config if provided, otherwise use global default - if evaluate_config is not None: - effective_sample_rate = evaluate_config.sample_rate - else: - effective_sample_rate = defaults.sample_rate - should_evaluate = random.random() < effective_sample_rate - should_evaluate_token = attach(set_value(SHOULD_EVALUATE_CONTEXT_KEY, should_evaluate)) - else: - # Check if should_evaluate already exists in context - existing_should_evaluate = otel_context.get_value(SHOULD_EVALUATE_CONTEXT_KEY) - if existing_should_evaluate is None: - # Orphan span without root - make its own decision - # If experiment is attached, ALWAYS evaluate - if experiment is not None: - should_evaluate = True - else: - if evaluate_config is not None: - effective_sample_rate = evaluate_config.sample_rate - else: - effective_sample_rate = defaults.sample_rate - should_evaluate = random.random() < effective_sample_rate - should_evaluate_token = attach(set_value(SHOULD_EVALUATE_CONTEXT_KEY, should_evaluate)) - - try: - with tracer.start_as_current_span(name) as span: - # Store root span in context for retrieval from nested spans - if is_root: - root_span_token = attach(set_value(ROOT_SPAN_CONTEXT_KEY, span)) - # Set basalt.root attribute - span.set_attribute("basalt.root", True) - elif in_basalt_trace: - # Child span inside a basalt trace - span.set_attribute("basalt.trace", True) - - # Mark all basalt spans with basalt.in_trace - span.set_attribute(semconv.BasaltSpan.IN_TRACE, True) - - # Inject prompt context if available - try: - from basalt.prompts.models import _current_prompt_context - prompt_ctx = _current_prompt_context.get() - if prompt_ctx: - # Inject prompt attributes into this span - import json - span.set_attribute("basalt.prompt.slug", prompt_ctx["slug"]) - if prompt_ctx.get("version"): - span.set_attribute("basalt.prompt.version", prompt_ctx["version"]) - if prompt_ctx.get("tag"): - span.set_attribute("basalt.prompt.tag", prompt_ctx["tag"]) - span.set_attribute("basalt.prompt.model.provider", prompt_ctx["provider"]) - span.set_attribute("basalt.prompt.model.model", prompt_ctx["model"]) - if prompt_ctx.get("variables"): - span.set_attribute("basalt.prompt.variables", json.dumps(prompt_ctx["variables"])) - span.set_attribute("basalt.prompt.from_cache", prompt_ctx["from_cache"]) - except ImportError: - # Prompts module not available, skip injection - pass - - _attach_attributes(span, attributes) - - # Apply metadata if provided - if metadata: - from .utils import apply_span_metadata - apply_span_metadata(span, metadata) - - if span_type: - span.set_attribute(SPAN_TYPE_ATTRIBUTE, span_type) - - # Apply user/org from context (either explicit or inherited from parent) - apply_user_from_context(span, user) - apply_organization_from_context(span, organization) - - if is_root and handle_cls == SpanHandle: - actual_handle_cls = StartSpanHandle - else: - actual_handle_cls = handle_cls - - handle = actual_handle_cls(span, parent_span, defaults) - if input_payload is not None: - handle.set_input(input_payload) - if variables: - handle.set_io(variables=variables) - if evaluators: - handle.add_evaluators(*evaluators) - yield handle # type: ignore[misc] - if output_payload is not None: - handle.set_output(output_payload) - - finally: - # Detach should_evaluate token if it was set - if should_evaluate_token is not None: - detach(should_evaluate_token) - - # Detach root span token if it was set - if root_span_token is not None: - detach(root_span_token) - - # Detach context tokens in reverse order - for token in reversed(tokens): - detach(token) + # Coerce Experiment to str or dict if needed + experiment_arg = experiment + if experiment is not None and not isinstance(experiment, (str, dict)): + # Prefer id if available, else fallback to str + experiment_arg = getattr(experiment, "id", str(experiment)) + + # Ensure experiment_arg is str, dict, or None + if experiment_arg is not None and not isinstance(experiment_arg, (str, dict)): + experiment_arg = getattr(experiment_arg, "id", str(experiment_arg)) + + with _with_span_handle( + name=name, + attributes=attributes, + tracer_name=tracer_name, + handle_cls=handle_cls, + span_type=span_type, + input_payload=input_payload, + output_payload=output_payload, + variables=variables, + evaluators=evaluators, + user=user, + organization=organization, + feature_slug=feature_slug, + metadata=metadata, + evaluate_config=evaluate_config, + experiment=experiment_arg, + ) as handle: + yield handle def _set_trace_user(user_id: str, name: str | None = None) -> None: diff --git a/basalt/observability/evaluators.py b/basalt/observability/evaluators.py index 2de10fa..3f5c77e 100644 --- a/basalt/observability/evaluators.py +++ b/basalt/observability/evaluators.py @@ -16,14 +16,13 @@ from collections.abc import Generator from contextlib import contextmanager -from typing import Any from opentelemetry import trace from .context_managers import SpanHandle, normalize_evaluator_specs, with_evaluators -def _flatten_evaluator_specs(*evaluators: Any) -> list[str]: +def _flatten_evaluator_specs(*evaluators: object) -> list[str]: """Normalize evaluator specifications into a flat list of slugs.""" slugs: list[str] = [] for attachment in normalize_evaluator_specs(evaluators): @@ -34,7 +33,7 @@ def _flatten_evaluator_specs(*evaluators: Any) -> list[str]: @contextmanager def attach_evaluator( - *evaluators: Any, + *evaluators: object, span: SpanHandle | None = None, ) -> Generator[None, None, None]: """ @@ -79,7 +78,7 @@ def attach_evaluator( yield -def attach_evaluators_to_span(span_handle: SpanHandle, *evaluators: Any) -> None: +def attach_evaluators_to_span(span_handle: SpanHandle, *evaluators: object) -> None: """ Directly attach evaluators to a span handle, respecting sample rates. @@ -96,7 +95,7 @@ def attach_evaluators_to_span(span_handle: SpanHandle, *evaluators: Any) -> None span_handle.add_evaluator(slug) -def attach_evaluators_to_current_span(*evaluators: Any) -> None: +def attach_evaluators_to_current_span(*evaluators: object) -> None: """ Attach evaluators to the current active span, respecting sample rates. diff --git a/basalt/observability/instrumentation.py b/basalt/observability/instrumentation.py index 886c332..0555d71 100644 --- a/basalt/observability/instrumentation.py +++ b/basalt/observability/instrumentation.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) -def _safe_import(module: str, target: str) -> Any | None: +def _safe_import(module: str, target: str) -> object | None: """Safely import a target from a module, returning None on failure.""" try: mod = __import__(module, fromlist=[target]) @@ -130,7 +130,9 @@ def create_tracer_provider( # Add a span processor for each exporter for exp in exporters: - processor_cls = SimpleSpanProcessor if isinstance(exp, ConsoleSpanExporter) else BatchSpanProcessor + processor_cls = ( + SimpleSpanProcessor if isinstance(exp, ConsoleSpanExporter) else BatchSpanProcessor + ) provider.add_span_processor(processor_cls(exp)) return provider @@ -168,7 +170,7 @@ def setup_tracing( # Check if a tracer provider is already set globally existing_provider = trace.get_tracer_provider() # If it's a real TracerProvider (not the default proxy), reuse it - if hasattr(existing_provider, 'add_span_processor'): + if hasattr(existing_provider, "add_span_processor"): provider_type = type(existing_provider).__name__ provider_module = type(existing_provider).__module__ logger.info( @@ -219,14 +221,14 @@ def initialize( # Normalize user_exporters to list if user_exporters is None: - exporters_list = [] + exporters_list: list[SpanExporter] = [] elif isinstance(user_exporters, list): - exporters_list = user_exporters.copy() + exporters_list = list(user_exporters) else: exporters_list = [user_exporters] # Add environment exporter ONLY if no user exporters were provided - if user_exporters is None and env_exporter: + if user_exporters is None and env_exporter is not None: exporters_list.append(env_exporter) # Pass to setup_tracing (will handle None/empty list → ConsoleSpanExporter) @@ -330,7 +332,7 @@ def _should_use_http_exporter(endpoint: str) -> bool: return False # Check if hostname contains 'grpc' - indicates gRPC endpoint - if parsed.hostname and 'grpc' in parsed.hostname.lower(): + if parsed.hostname and "grpc" in parsed.hostname.lower(): return False if parsed.port == 4317 and parsed.path in {"", "/"}: @@ -363,13 +365,21 @@ def _instrument_providers(self, config: TelemetryConfig) -> None: "openai": ("opentelemetry.instrumentation.openai", "OpenAIInstrumentor"), "anthropic": ("opentelemetry.instrumentation.anthropic", "AnthropicInstrumentor"), # NEW Google GenAI SDK (from google import genai) - "google_genai": ("opentelemetry.instrumentation.google_genai", "GoogleGenAiSdkInstrumentor"), + "google_genai": ( + "opentelemetry.instrumentation.google_genai", + "GoogleGenAiSdkInstrumentor", + ), # OLD Google Generative AI SDK (import google.generativeai) - "google_generativeai": ("opentelemetry.instrumentation.google_generativeai", - "GoogleGenerativeAiInstrumentor"), + "google_generativeai": ( + "opentelemetry.instrumentation.google_generativeai", + "GoogleGenerativeAiInstrumentor", + ), "bedrock": ("opentelemetry.instrumentation.bedrock", "BedrockInstrumentor"), "vertexai": ("opentelemetry.instrumentation.vertexai", "VertexAIInstrumentor"), - "vertex-ai": ("opentelemetry.instrumentation.vertexai", "VertexAIInstrumentor"), # Alias + "vertex-ai": ( + "opentelemetry.instrumentation.vertexai", + "VertexAIInstrumentor", + ), # Alias "ollama": ("opentelemetry.instrumentation.ollama", "OllamaInstrumentor"), "mistralai": ("opentelemetry.instrumentation.mistralai", "MistralAiInstrumentor"), # Frameworks @@ -404,7 +414,7 @@ def _instrument_providers(self, config: TelemetryConfig) -> None: try: instrumentor = instrumentor_cls() # Check if already instrumented to avoid double instrumentation - if hasattr(instrumentor, 'is_instrumented_by_opentelemetry'): + if hasattr(instrumentor, "is_instrumented_by_opentelemetry"): if not instrumentor.is_instrumented_by_opentelemetry: instrumentor.instrument() self._provider_instrumentors[provider_key] = instrumentor @@ -434,6 +444,7 @@ def _initialize_instrumentation(self, config: TelemetryConfig) -> None: """ # Set global sample rate from config from .trace_context import set_global_sample_rate + set_global_sample_rate(config.sample_rate) # Set environment variables for third-party OpenTelemetry instrumentors @@ -449,9 +460,9 @@ def _initialize_instrumentation(self, config: TelemetryConfig) -> None: # - TRACELOOP_TRACE_CONTENT: Used by most OpenLLMetry instrumentors # - OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT: Used by Google GenAI instrumentor os.environ["TRACELOOP_TRACE_CONTENT"] = "true" if config.trace_content else "false" - os.environ[ - "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT" - ] = "true" if config.trace_content else "false" + os.environ["OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"] = ( + "true" if config.trace_content else "false" + ) # Instrument providers directly without using Traceloop.init() self._instrument_providers(config) @@ -481,12 +492,11 @@ def _install_basalt_processors(self, provider: TracerProvider) -> None: self._span_processors = processors logger.debug(f"Successfully installed Basalt processors on {provider_type}") - def _uninstrument_providers(self) -> None: for provider_key, instrumentor in list(self._provider_instrumentors.items()): try: # Check if it's actually instrumented before trying to uninstrument - if hasattr(instrumentor, 'is_instrumented_by_opentelemetry'): + if hasattr(instrumentor, "is_instrumented_by_opentelemetry"): if instrumentor.is_instrumented_by_opentelemetry: instrumentor.uninstrument() logger.debug(f"Uninstrumented provider: {provider_key}") diff --git a/basalt/observability/processors.py b/basalt/observability/processors.py index a866b61..6dce23f 100644 --- a/basalt/observability/processors.py +++ b/basalt/observability/processors.py @@ -8,7 +8,7 @@ from typing import Any, Final from opentelemetry import context as otel_context -from opentelemetry.context import attach +from opentelemetry.context import Context, attach from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor from . import semconv @@ -71,25 +71,19 @@ def _set_default_metadata(span: Span, defaults: _TraceContextConfig) -> None: hasattr(parent_ctx, "is_valid") and not parent_ctx.is_valid ) - experiment = ( - defaults.experiment - if isinstance(defaults.experiment, TraceExperiment) - else None - ) + experiment = defaults.experiment if isinstance(defaults.experiment, TraceExperiment) else None if experiment and is_root_span: span.set_attribute(semconv.BasaltExperiment.ID, experiment.id) if experiment.name: span.set_attribute(semconv.BasaltExperiment.NAME, experiment.name) if experiment.feature_slug: - span.set_attribute( - semconv.BasaltExperiment.FEATURE_SLUG, experiment.feature_slug - ) + span.set_attribute(semconv.BasaltExperiment.FEATURE_SLUG, experiment.feature_slug) for key, value in (defaults.observe_metadata or {}).items(): span.set_attribute(f"{semconv.BASALT_META_PREFIX}{key}", value) -def _apply_user_org_from_context(span: Span, parent_context: Any | None = None) -> None: +def _apply_user_org_from_context(span: Span, parent_context: Context | None = None) -> None: """Apply user and organization from OpenTelemetry context to the span.""" if not span.is_recording(): return @@ -109,7 +103,7 @@ def _apply_user_org_from_context(span: Span, parent_context: Any | None = None) span.set_attribute(semconv.BasaltOrganization.NAME, org.name) -def _apply_feature_slug_from_context(span: Span, parent_context: Any | None = None) -> None: +def _apply_feature_slug_from_context(span: Span, parent_context: Context | None = None) -> None: """Apply feature slug from OpenTelemetry context to the span.""" if not span.is_recording(): return @@ -127,7 +121,7 @@ def _apply_feature_slug_from_context(span: Span, parent_context: Any | None = No class BasaltContextProcessor(SpanProcessor): """Apply Basalt trace defaults to every started span.""" - def on_start(self, span: Span, parent_context: Any | None = None) -> None: # type: ignore[override] + def on_start(self, span: Span, parent_context: Context | None = None) -> None: if not span.is_recording(): return defaults = _current_trace_defaults() @@ -137,13 +131,13 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty # Apply feature_slug from OpenTelemetry context (enables propagation to child spans) _apply_feature_slug_from_context(span, parent_context) - def on_end(self, span: ReadableSpan) -> None: # type: ignore[override] + def on_end(self, span: ReadableSpan) -> None: return - def shutdown(self) -> None: # type: ignore[override] + def shutdown(self) -> None: return - def force_flush(self, timeout_millis: int = 30000) -> bool: # type: ignore[override] + def force_flush(self, timeout_millis: int = 30000) -> bool: return True @@ -153,7 +147,7 @@ class BasaltCallEvaluatorProcessor(SpanProcessor): def __init__(self, context_key: str = EVALUATOR_CONTEXT_KEY) -> None: self._context_key = context_key - def on_start(self, span: Span, parent_context: Any | None = None) -> None: # type: ignore[override] + def on_start(self, span: Span, parent_context: Context | None = None) -> None: if not span.is_recording(): return @@ -182,15 +176,11 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty except Exception as exc: # pragma: no cover - defensive logger.debug("Failed to normalize call evaluators: %s", exc) else: - slugs = [ - attachment.slug for attachment in attachments if attachment.slug - ] + slugs = [attachment.slug for attachment in attachments if attachment.slug] _merge_evaluators(span, slugs) # Attach evaluator config from context - context_config = otel_context.get_value( - EVALUATOR_CONFIG_CONTEXT_KEY, parent_context - ) + context_config = otel_context.get_value(EVALUATOR_CONFIG_CONTEXT_KEY, parent_context) if context_config and isinstance(context_config, EvaluationConfig): try: import json @@ -202,13 +192,13 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty except Exception as exc: # pragma: no cover - defensive logger.debug("Failed to set evaluator config: %s", exc) - def on_end(self, span: ReadableSpan) -> None: # type: ignore[override] + def on_end(self, span: ReadableSpan) -> None: return - def shutdown(self) -> None: # type: ignore[override] + def shutdown(self) -> None: return - def force_flush(self, timeout_millis: int = 30000) -> bool: # type: ignore[override] + def force_flush(self, timeout_millis: int = 30000) -> bool: return True @@ -221,7 +211,7 @@ class BasaltShouldEvaluateProcessor(SpanProcessor): should_evaluate value, enabling trace-level sampling for evaluators. """ - def on_start(self, span: Span, parent_context: Any | None = None) -> None: # type: ignore[override] + def on_start(self, span: Span, parent_context: Context | None = None) -> None: if not span.is_recording(): return @@ -235,33 +225,35 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty if should_evaluate is not None: span.set_attribute(semconv.BasaltSpan.SHOULD_EVALUATE, bool(should_evaluate)) - def on_end(self, span: ReadableSpan) -> None: # type: ignore[override] + def on_end(self, span: ReadableSpan) -> None: return - def shutdown(self) -> None: # type: ignore[override] + def shutdown(self) -> None: return - def force_flush(self, timeout_millis: int = 30000) -> bool: # type: ignore[override] + def force_flush(self, timeout_millis: int = 30000) -> bool: return True # Known auto-instrumentation scope names -KNOWN_AUTO_INSTRUMENTATION_SCOPES: Final[frozenset[str]] = frozenset({ - "opentelemetry.instrumentation.openai", - "opentelemetry.instrumentation.openai.v1", # OpenAI SDK v1+ - "opentelemetry.instrumentation.anthropic", - "opentelemetry.instrumentation.google_genai", - "opentelemetry.instrumentation.google_generativeai", - "opentelemetry.instrumentation.bedrock", - "opentelemetry.instrumentation.vertexai", - "opentelemetry.instrumentation.ollama", - "opentelemetry.instrumentation.mistralai", - "opentelemetry.instrumentation.langchain", - "opentelemetry.instrumentation.llamaindex", - "opentelemetry.instrumentation.chromadb", - "opentelemetry.instrumentation.pinecone", - "opentelemetry.instrumentation.qdrant", -}) +KNOWN_AUTO_INSTRUMENTATION_SCOPES: Final[frozenset[str]] = frozenset( + { + "opentelemetry.instrumentation.openai", + "opentelemetry.instrumentation.openai.v1", # OpenAI SDK v1+ + "opentelemetry.instrumentation.anthropic", + "opentelemetry.instrumentation.google_genai", + "opentelemetry.instrumentation.google_generativeai", + "opentelemetry.instrumentation.bedrock", + "opentelemetry.instrumentation.vertexai", + "opentelemetry.instrumentation.ollama", + "opentelemetry.instrumentation.mistralai", + "opentelemetry.instrumentation.langchain", + "opentelemetry.instrumentation.llamaindex", + "opentelemetry.instrumentation.chromadb", + "opentelemetry.instrumentation.pinecone", + "opentelemetry.instrumentation.qdrant", + } +) # Mapping of instrumentation scope names to Basalt span kinds @@ -298,7 +290,7 @@ class BasaltAutoInstrumentationProcessor(SpanProcessor): after being applied to ensure single-use semantics. """ - def on_start(self, span: Span, parent_context: Any | None = None) -> None: # type: ignore[override] + def on_start(self, span: Span, parent_context: Context | None = None) -> None: """ Called when a span starts. @@ -331,13 +323,19 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty # Read and apply input input_payload = otel_context.get_value(PENDING_INJECT_INPUT_KEY, ctx) if input_payload is not None: - serialized = json.dumps(input_payload) if not isinstance(input_payload, str) else input_payload + serialized = ( + json.dumps(input_payload) if not isinstance(input_payload, str) else input_payload + ) span.set_attribute(semconv.BasaltSpan.INPUT, serialized) # Read and apply output output_payload = otel_context.get_value(PENDING_INJECT_OUTPUT_KEY, ctx) if output_payload is not None: - serialized = json.dumps(output_payload) if not isinstance(output_payload, str) else output_payload + serialized = ( + json.dumps(output_payload) + if not isinstance(output_payload, str) + else output_payload + ) span.set_attribute(semconv.BasaltSpan.OUTPUT, serialized) # Read and apply variables @@ -349,46 +347,62 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty metadata = otel_context.get_value(PENDING_INJECT_METADATA_KEY, ctx) if metadata: from .utils import apply_span_metadata + if isinstance(metadata, dict): - apply_span_metadata(span, metadata) + metadata_map = {str(key): value for key, value in metadata.items()} + apply_span_metadata(span, metadata_map) else: # Non-dict metadata: store at basalt.metadata - span.set_attribute(semconv.BasaltSpan.METADATA, json.dumps(metadata) if not isinstance(metadata, str) else metadata) + span.set_attribute( + semconv.BasaltSpan.METADATA, + json.dumps(metadata) if not isinstance(metadata, str) else metadata, + ) # Read and apply prompt metadata prompt_data = otel_context.get_value(PENDING_INJECT_PROMPT_KEY, ctx) - if prompt_data: + if isinstance(prompt_data, dict): # Apply prompt attributes (slug, version, provider, model) for key, value in prompt_data.items(): - span.set_attribute(f"basalt.prompt.{key}", value) + if value is None: + continue + if isinstance(value, (str, int, float)): + safe_value = value + elif isinstance(value, (list, tuple)) and all( + isinstance(item, (str, int, float)) for item in value + ): + safe_value = list(value) + else: + safe_value = json.dumps(value) + span.set_attribute(f"basalt.prompt.{key}", safe_value) else: # If no explicit injection, try to read from ContextVar # This allows auto-instrumented spans to inherit prompt context # from the parent prompt context manager try: from basalt.prompts.models import _current_prompt_context + prompt_ctx = _current_prompt_context.get() if prompt_ctx: import logging + logger = logging.getLogger(__name__) - logger.debug(f"✓ Injecting prompt context from ContextVar for span '{scope.name}': slug='{prompt_ctx['slug']}'") - # Inject prompt attributes from ContextVar - span.set_attribute("basalt.prompt.slug", prompt_ctx["slug"]) - if prompt_ctx.get("version"): - span.set_attribute("basalt.prompt.version", prompt_ctx["version"]) - if prompt_ctx.get("tag"): - span.set_attribute("basalt.prompt.tag", prompt_ctx["tag"]) - span.set_attribute("basalt.prompt.model.provider", prompt_ctx["provider"]) - span.set_attribute("basalt.prompt.model.model", prompt_ctx["model"]) - if prompt_ctx.get("variables"): - span.set_attribute("basalt.prompt.variables", json.dumps(prompt_ctx["variables"])) - span.set_attribute("basalt.prompt.from_cache", prompt_ctx["from_cache"]) + logger.debug( + f"✓ Injecting prompt context from ContextVar for span '{scope.name}': " + f"slug='{prompt_ctx['slug']}'" + ) + from .utils import apply_prompt_context_attributes + + apply_prompt_context_attributes(span, prompt_ctx) else: import logging + logger = logging.getLogger(__name__) - logger.debug(f"✗ No prompt context found in ContextVar for auto-instrumented span '{scope.name}'") + logger.debug( + f"✗ No prompt context found in ContextVar for auto-instrumented span '{scope.name}'" + ) except (ImportError, LookupError) as e: import logging + logger = logging.getLogger(__name__) logger.debug(f"✗ Failed to read prompt context for span '{scope.name}': {e}") # Prompts module not available or no context set - skip injection @@ -404,11 +418,11 @@ def on_start(self, span: Span, parent_context: Any | None = None) -> None: # ty new_ctx = otel_context.set_value(PENDING_INJECT_PROMPT_KEY, None, new_ctx) attach(new_ctx) - def on_end(self, span: ReadableSpan) -> None: # type: ignore[override] + def on_end(self, span: ReadableSpan) -> None: return - def shutdown(self) -> None: # type: ignore[override] + def shutdown(self) -> None: return - def force_flush(self, timeout_millis: int = 30000) -> bool: # type: ignore[override] + def force_flush(self, timeout_millis: int = 30000) -> bool: return True diff --git a/basalt/observability/request_tracing.py b/basalt/observability/request_tracing.py index 93410d3..1a3f9b2 100644 --- a/basalt/observability/request_tracing.py +++ b/basalt/observability/request_tracing.py @@ -4,7 +4,9 @@ import time from collections.abc import Awaitable, Callable -from typing import TypeVar +from typing import Any, TypeVar, cast + +from basalt.types.common import JSONValue from .api import observe from .spans import BasaltRequestSpan @@ -12,6 +14,17 @@ T = TypeVar("T") +def _build_request_output(span_data: BasaltRequestSpan, result: object) -> dict[str, JSONValue]: + # Type-safe output formatting for PromptRequestSpan + from basalt.prompts.client import PromptRequestSpan + + if isinstance(span_data, PromptRequestSpan) and isinstance(result, dict): + response_data = cast(dict[str, Any], result) + return span_data.format_output(response_data) + status_code: Any = getattr(result, "status_code", None) + return {"status_code": status_code} + + async def trace_async_request( span_data: BasaltRequestSpan, request_callable: Callable[[], Awaitable[T]], @@ -27,7 +40,7 @@ async def trace_async_request( Result of ``request_callable``. """ start = time.perf_counter() - input_payload = { + input_payload: dict[str, JSONValue] = { "method": span_data.method, "url": span_data.url, } @@ -41,8 +54,9 @@ async def trace_async_request( result = await request_callable() except Exception as exc: # pragma: no cover - passthrough # If the exception carries an HTTP status_code (BasaltAPIError), include it - status_code = getattr(exc, "status_code", None) - observe.set_output({"error": str(exc), "status_code": status_code}) + status_code: Any = getattr(exc, "status_code", None) + error_output: dict[str, JSONValue] = {"error": str(exc), "status_code": status_code} + observe.set_output(error_output) span_data.finalize( span, duration_s=time.perf_counter() - start, @@ -51,13 +65,7 @@ async def trace_async_request( ) raise - # Type-safe output formatting for PromptRequestSpan - from basalt.prompts.client import PromptRequestSpan - if isinstance(span_data, PromptRequestSpan): - output = span_data.format_output(result) - else: - status_code = getattr(result, "status_code", None) - output = {"status_code": status_code} + output = _build_request_output(span_data, result) observe.set_output(output) status_code = getattr(result, "status_code", None) @@ -85,7 +93,7 @@ def trace_sync_request( Result of ``request_callable``. """ start = time.perf_counter() - input_payload = { + input_payload: dict[str, JSONValue] = { "method": span_data.method, "url": span_data.url, } @@ -99,8 +107,9 @@ def trace_sync_request( result = request_callable() except Exception as exc: # pragma: no cover - passthrough # If the exception carries an HTTP status_code (BasaltAPIError), include it - status_code = getattr(exc, "status_code", None) - observe.set_output({"error": str(exc), "status_code": status_code}) + status_code: Any = getattr(exc, "status_code", None) + error_output: dict[str, JSONValue] = {"error": str(exc), "status_code": status_code} + observe.set_output(error_output) span_data.finalize( span, duration_s=time.perf_counter() - start, @@ -109,13 +118,7 @@ def trace_sync_request( ) raise - # Type-safe output formatting for PromptRequestSpan - from basalt.prompts.client import PromptRequestSpan - if isinstance(span_data, PromptRequestSpan): - output = span_data.format_output(result) - else: - status_code = getattr(result, "status_code", None) - output = {"status_code": status_code} + output = _build_request_output(span_data, result) observe.set_output(output) status_code = getattr(result, "status_code", None) diff --git a/basalt/observability/trace.py b/basalt/observability/trace.py index edc8f7d..6c03069 100644 --- a/basalt/observability/trace.py +++ b/basalt/observability/trace.py @@ -5,6 +5,7 @@ from opentelemetry.trace import Span, Tracer +from ..types.common import SpanAttributeValue from . import semconv from .context_managers import ( get_current_otel_span, @@ -36,10 +37,10 @@ def add_event(name: str, attributes: Mapping[str, Any] | None = None) -> None: span.add_event(name, attributes=attributes) @staticmethod - def set_attribute(key: str, value: Any) -> None: + def set_attribute(key: str, value: SpanAttributeValue) -> None: """Set a raw attribute on the current span.""" span = get_current_otel_span() - if span: + if span and value is not None: span.set_attribute(key, value) @staticmethod @@ -81,20 +82,22 @@ def identify( current_user, current_org = _get_current_identity_from_span(span) if user is not None: - new_user = _parse_identity_input(user) - merged_user = _merge_identity(current_user, new_user) - if 'id' in merged_user: - span.set_attribute(semconv.BasaltUser.ID, merged_user['id']) - if 'name' in merged_user: - span.set_attribute(semconv.BasaltUser.NAME, merged_user['name']) + _apply_identity_attributes( + span=span, + current=current_user, + incoming=user, + id_attr=semconv.BasaltUser.ID, + name_attr=semconv.BasaltUser.NAME, + ) if organization is not None: - new_org = _parse_identity_input(organization) - merged_org = _merge_identity(current_org, new_org) - if 'id' in merged_org: - span.set_attribute(semconv.BasaltOrganization.ID, merged_org['id']) - if 'name' in merged_org: - span.set_attribute(semconv.BasaltOrganization.NAME, merged_org['name']) + _apply_identity_attributes( + span=span, + current=current_org, + incoming=organization, + id_attr=semconv.BasaltOrganization.ID, + name_attr=semconv.BasaltOrganization.NAME, + ) def _get_current_identity_from_span(span: Span) -> tuple[dict[str, str], dict[str, str]]: @@ -102,15 +105,16 @@ def _get_current_identity_from_span(span: Span) -> tuple[dict[str, str], dict[st user_dict = {} org_dict = {} - if hasattr(span, 'attributes') and span.attributes: - if semconv.BasaltUser.ID in span.attributes: - user_dict['id'] = str(span.attributes[semconv.BasaltUser.ID]) - if semconv.BasaltUser.NAME in span.attributes: - user_dict['name'] = str(span.attributes[semconv.BasaltUser.NAME]) - if semconv.BasaltOrganization.ID in span.attributes: - org_dict['id'] = str(span.attributes[semconv.BasaltOrganization.ID]) - if semconv.BasaltOrganization.NAME in span.attributes: - org_dict['name'] = str(span.attributes[semconv.BasaltOrganization.NAME]) + attributes = getattr(span, "attributes", None) + if isinstance(attributes, Mapping): + if semconv.BasaltUser.ID in attributes: + user_dict["id"] = str(attributes[semconv.BasaltUser.ID]) + if semconv.BasaltUser.NAME in attributes: + user_dict["name"] = str(attributes[semconv.BasaltUser.NAME]) + if semconv.BasaltOrganization.ID in attributes: + org_dict["id"] = str(attributes[semconv.BasaltOrganization.ID]) + if semconv.BasaltOrganization.NAME in attributes: + org_dict["name"] = str(attributes[semconv.BasaltOrganization.NAME]) return user_dict, org_dict @@ -125,13 +129,13 @@ def _parse_identity_input(value: str | dict[str, Any] | None) -> dict[str, str]: if value is None: return {} if isinstance(value, str): - return {'id': value} + return {"id": value} if isinstance(value, dict): result = {} - if 'id' in value: - result['id'] = str(value['id']) if value['id'] is not None else '' - if 'name' in value: - result['name'] = str(value['name']) if value['name'] is not None else '' + if "id" in value: + result["id"] = str(value["id"]) if value["id"] is not None else "" + if "name" in value: + result["name"] = str(value["name"]) if value["name"] is not None else "" return result return {} @@ -141,5 +145,21 @@ def _merge_identity(existing: dict[str, str], new: dict[str, str]) -> dict[str, return {**existing, **new} +def _apply_identity_attributes( + *, + span: Span, + current: dict[str, str], + incoming: str | dict[str, Any] | None, + id_attr: str, + name_attr: str, +) -> None: + new_identity = _parse_identity_input(incoming) + merged_identity = _merge_identity(current, new_identity) + if "id" in merged_identity: + span.set_attribute(id_attr, merged_identity["id"]) + if "name" in merged_identity: + span.set_attribute(name_attr, merged_identity["name"]) + + # Singleton instance trace_api = Trace diff --git a/basalt/observability/trace_context.py b/basalt/observability/trace_context.py index 4050d24..ff9da62 100644 --- a/basalt/observability/trace_context.py +++ b/basalt/observability/trace_context.py @@ -5,7 +5,7 @@ from collections.abc import Mapping from dataclasses import dataclass from threading import RLock -from typing import Any, Final +from typing import Any, Final, cast from opentelemetry import context as otel_context from opentelemetry.trace import Span @@ -57,7 +57,9 @@ def clone(self) -> _TraceContextConfig: """Return a defensive copy of the configuration.""" return _TraceContextConfig( experiment=self.experiment, - observe_metadata=dict(self.observe_metadata) if self.observe_metadata is not None else {}, + observe_metadata=dict(self.observe_metadata) + if self.observe_metadata is not None + else {}, sample_rate=self.sample_rate, ) @@ -133,6 +135,8 @@ def set_global_sample_rate(sample_rate: float) -> None: sample_rate=sample_rate, ) _set_trace_defaults(new_config) + + def configure_global_metadata(metadata: dict[str, Any] | None) -> None: """ Configure global observability metadata applied to all traces. @@ -157,7 +161,9 @@ def apply_trace_defaults(span: Span, defaults: _TraceContextConfig | None = None if context.experiment.name: span.set_attribute(semconv.BasaltExperiment.NAME, context.experiment.name) if context.experiment.feature_slug: - span.set_attribute(semconv.BasaltExperiment.FEATURE_SLUG, context.experiment.feature_slug) + span.set_attribute( + semconv.BasaltExperiment.FEATURE_SLUG, context.experiment.feature_slug + ) if context.observe_metadata: for key, value in context.observe_metadata.items(): @@ -166,15 +172,17 @@ def apply_trace_defaults(span: Span, defaults: _TraceContextConfig | None = None def get_context_user() -> TraceIdentity | None: """Retrieve user identity from the current OpenTelemetry context.""" - return otel_context.get_value(USER_CONTEXT_KEY) + return cast(TraceIdentity | None, otel_context.get_value(USER_CONTEXT_KEY)) def get_context_organization() -> TraceIdentity | None: """Retrieve organization identity from the current OpenTelemetry context.""" - return otel_context.get_value(ORGANIZATION_CONTEXT_KEY) + return cast(TraceIdentity | None, otel_context.get_value(ORGANIZATION_CONTEXT_KEY)) -def apply_user_from_context(span: Span, user: TraceIdentity | Mapping[str, Any] | None = None) -> None: +def apply_user_from_context( + span: Span, user: TraceIdentity | Mapping[str, Any] | None = None +) -> None: """ Apply user identity to a span from the provided value or OpenTelemetry context. diff --git a/basalt/observability/utils.py b/basalt/observability/utils.py index ecd74af..242a79f 100644 --- a/basalt/observability/utils.py +++ b/basalt/observability/utils.py @@ -7,6 +7,8 @@ from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any +from opentelemetry.trace import Span + if TYPE_CHECKING: from .context_managers import LLMSpanHandle @@ -15,7 +17,7 @@ from .trace_context import TraceIdentity -def apply_span_metadata(span: Any, metadata: Mapping[str, Any] | None) -> None: +def apply_span_metadata(span: Span, metadata: Mapping[str, Any] | None) -> None: """Apply metadata to a span as an aggregated JSON object at basalt.metadata. Behavior: @@ -48,18 +50,33 @@ def apply_span_metadata(span: Any, metadata: Mapping[str, Any] | None) -> None: except Exception: # Fallback: try to serialize with str() for non-serializable values try: - safe_merged = {k: v if isinstance(v, (str, bool, int, float, type(None))) else str(v) - for k, v in merged.items()} + safe_merged = { + k: v if isinstance(v, (str, bool, int, float, type(None))) else str(v) + for k, v in merged.items() + } span.set_attribute(semconv.BasaltSpan.METADATA, json.dumps(safe_merged)) except Exception: pass +def apply_prompt_context_attributes(span: Span, prompt_ctx: Mapping[str, Any]) -> None: + span.set_attribute("basalt.prompt.slug", prompt_ctx["slug"]) + if prompt_ctx.get("version"): + span.set_attribute("basalt.prompt.version", prompt_ctx["version"]) + if prompt_ctx.get("tag"): + span.set_attribute("basalt.prompt.tag", prompt_ctx["tag"]) + span.set_attribute("basalt.prompt.model.provider", prompt_ctx["provider"]) + span.set_attribute("basalt.prompt.model.model", prompt_ctx["model"]) + if prompt_ctx.get("variables"): + span.set_attribute("basalt.prompt.variables", json.dumps(prompt_ctx["variables"])) + span.set_attribute("basalt.prompt.from_cache", prompt_ctx["from_cache"]) + + def resolve_attributes( - attributes: Any, + attributes: object, args: tuple[Any, ...], kwargs: dict[str, Any], -) -> Callable | None | Any: +) -> object: """Resolve attributes into a dictionary.""" if attributes is None: return None @@ -85,9 +102,9 @@ def resolve_bound_arguments( def resolve_payload_from_bound( - resolver: Any, + resolver: object, bound: inspect.BoundArguments | None, -) -> Any: +) -> object: """Resolve input payload from bound arguments.""" if resolver is None: if not bound: @@ -111,7 +128,11 @@ def resolve_payload_from_bound( def resolve_variables_payload( - resolver: dict[str, Any] | Callable[[inspect.BoundArguments | None], Mapping[str, Any]] | Sequence[str] | Mapping[str, Any] | None, + resolver: dict[str, Any] + | Callable[[inspect.BoundArguments | None], Mapping[str, Any] | None] + | Sequence[str] + | Mapping[str, Any] + | None, bound: inspect.BoundArguments | None, ) -> Mapping[str, Any] | None: """Resolve variables payload.""" @@ -125,7 +146,7 @@ def resolve_variables_payload( return None return payload if isinstance(resolver, Mapping): - return resolver + return {str(key): value for key, value in resolver.items()} if isinstance(resolver, Sequence) and not isinstance(resolver, (str, bytes)): if not bound: return None @@ -138,9 +159,9 @@ def resolve_variables_payload( def resolve_evaluators_payload( - resolver: Any, + resolver: object, bound: inspect.BoundArguments | None, - result: Any | None = None, + result: object | None = None, ) -> list[Any] | None: """Resolve evaluator specifications.""" if resolver is None: @@ -162,7 +183,7 @@ def resolve_evaluators_payload( def _normalize_identity_value( - value: Any, + value: object, ) -> TraceIdentity | dict[str, Any] | None: """Normalize a user/org identity specification.""" if value is None: @@ -184,7 +205,7 @@ def _normalize_identity_value( def resolve_identity_payload( - resolver: Any, + resolver: object, bound: inspect.BoundArguments | None, ) -> tuple[TraceIdentity | dict[str, Any] | None, TraceIdentity | dict[str, Any] | None]: """ @@ -205,13 +226,15 @@ def resolve_identity_payload( if payload is None: return None, None - def _from_mapping(mapping: Mapping[str, Any]) -> tuple[ - TraceIdentity | dict[str, Any] | None, TraceIdentity | dict[str, Any] | None - ]: - lowered = {str(key).lower(): value for key, value in mapping.items() if isinstance(key, str)} + def _from_mapping( + mapping: Mapping[str, Any], + ) -> tuple[TraceIdentity | dict[str, Any] | None, TraceIdentity | dict[str, Any] | None]: + lowered = { + str(key).lower(): value for key, value in mapping.items() if isinstance(key, str) + } - user_spec: Any | None = None - org_spec: Any | None = None + user_spec: object | None = None + org_spec: object | None = None if "user" in lowered: user_spec = lowered["user"] @@ -249,7 +272,7 @@ def _from_mapping(mapping: Mapping[str, Any]) -> tuple[ return _normalize_identity_value(payload), None -def _extract_first(bound, keys: tuple[str, ...]) -> Any | None: +def _extract_first(bound, keys: tuple[str, ...]) -> object: if not bound: return None for key in keys: @@ -258,7 +281,7 @@ def _extract_first(bound, keys: tuple[str, ...]) -> Any | None: return None -def default_generation_input(bound: inspect.BoundArguments | None) -> Any: +def default_generation_input(bound: inspect.BoundArguments | None) -> object: value = _extract_first(bound, ("prompt", "input", "inputs", "messages", "question")) if value is not None: return value @@ -270,7 +293,7 @@ def default_generation_variables(bound: inspect.BoundArguments | None) -> Mappin return value if isinstance(value, Mapping) else None -def default_retrieval_input(bound: inspect.BoundArguments | None) -> Any: +def default_retrieval_input(bound: inspect.BoundArguments | None) -> object: value = _extract_first(bound, ("query", "question", "text", "search")) if value is not None: return value @@ -282,7 +305,7 @@ def default_retrieval_variables(bound: inspect.BoundArguments | None) -> Mapping return value if isinstance(value, Mapping) else None -def serialize_prompt(value: Any) -> str | None: +def serialize_prompt(value: object) -> str | None: if value is None: return None if isinstance(value, str): @@ -293,7 +316,7 @@ def serialize_prompt(value: Any) -> str | None: return str(value) -def extract_completion(result: Any) -> str | None: +def extract_completion(result: object) -> str | None: if result is None: return None if isinstance(result, str): @@ -304,12 +327,16 @@ def extract_completion(result: Any) -> str | None: data = result elif hasattr(result, "model_dump"): try: - data = result.model_dump() + # Using getattr for type-safe dynamic attribute access + model_dump = result.model_dump # type: ignore[attr-defined] + data = model_dump() except Exception: data = None elif hasattr(result, "dict"): try: - data = result.dict() + # Using getattr for type-safe dynamic attribute access + dict_method = result.dict # type: ignore[attr-defined] + data = dict_method() except Exception: data = None elif hasattr(result, "__dict__"): @@ -337,18 +364,21 @@ def extract_completion(result: Any) -> str | None: return None -def extract_usage(result: Any) -> tuple[int | None, int | None]: - usage_section: Any | None = None +def extract_usage(result: object) -> tuple[int | None, int | None]: + usage_section: object | None = None if isinstance(result, dict): usage_section = result.get("usage") elif hasattr(result, "usage"): - usage_section = result.usage - elif hasattr(result, "model_dump"): - try: - dumped = result.model_dump() - usage_section = dumped.get("usage") - except Exception: - usage_section = None + # Using getattr for type-safe dynamic attribute access + usage_section = getattr(result, "usage", None) + else: + model_dump = getattr(result, "model_dump", None) + if callable(model_dump): + try: + dumped = model_dump() + usage_section = dumped.get("usage") if isinstance(dumped, dict) else None + except Exception: + usage_section = None if not isinstance(usage_section, dict): return None, None input_tokens = usage_section.get("prompt_tokens") or usage_section.get("input_tokens") @@ -358,7 +388,7 @@ def extract_usage(result: Any) -> tuple[int | None, int | None]: return input_tokens, output_tokens -def apply_llm_request_metadata(span: LLMSpanHandle, bound) -> None: +def apply_llm_request_metadata(span: LLMSpanHandle, bound: inspect.BoundArguments | None) -> None: if not bound: return model = _extract_first(bound, ("model", "model_name")) @@ -370,7 +400,7 @@ def apply_llm_request_metadata(span: LLMSpanHandle, bound) -> None: span.set_prompt(serialized) -def apply_llm_response_metadata(span: LLMSpanHandle, result: Any) -> None: +def apply_llm_response_metadata(span: LLMSpanHandle, result: object) -> None: completion = extract_completion(result) if completion and trace_content_enabled(): span.set_completion(completion) diff --git a/basalt/prompts/__init__.py b/basalt/prompts/__init__.py index a119056..1b5d257 100644 --- a/basalt/prompts/__init__.py +++ b/basalt/prompts/__init__.py @@ -5,6 +5,7 @@ with the Basalt Prompts API. The client is lazily imported to avoid circular imports during initialization. """ + from typing import TYPE_CHECKING, Any from .models import ( @@ -38,7 +39,7 @@ ] -def __getattr__(name: str) -> Any: +def __getattr__(name: str) -> object: if name == "PromptsClient": from .client import PromptsClient diff --git a/basalt/prompts/client.py b/basalt/prompts/client.py index b8c8610..5938206 100644 --- a/basalt/prompts/client.py +++ b/basalt/prompts/client.py @@ -3,12 +3,20 @@ This module provides the PromptsClient for interacting with the Basalt Prompts API. """ + from __future__ import annotations +import builtins +from collections.abc import Mapping from typing import Any, cast -from .._internal.base_client import BaseServiceClient -from .._internal.http import HTTPClient +try: + from typing import Unpack +except ImportError: # pragma: no cover - fallback for Python < 3.11 + from typing_extensions import Unpack + +from .._internal.base_client import BaseServiceClient, HTTPRequestKwargs +from .._internal.http import HTTPClient, HTTPResponse from ..config import config from ..observability.spans import BasaltRequestSpan from ..types.cache import CacheProtocol @@ -36,10 +44,7 @@ def format_output(self, response_data: dict[str, Any]) -> dict[str, Any]: return {"status_code": status_code} # Return the full prompt object as JSON - return { - "prompt": prompt_data, - "from_cache": response_data.get("from_cache", False) - } + return {"prompt": prompt_data, "from_cache": response_data.get("from_cache", False)} class PromptsClient(BaseServiceClient): @@ -58,7 +63,7 @@ def __init__( base_url: str | None = None, http_client: HTTPClient | None = None, log_level: str | None = None, - ): + ) -> None: """ Initialize the PromptsClient. @@ -78,17 +83,37 @@ def __init__( # Cache responses for 5 minutes self._cache_duration = 5 * 60 + @staticmethod + def _prompt_response_from_api(response: HTTPResponse) -> PromptResponse: + if response is None or response.body is None: + raise BasaltAPIError("Empty response from get prompt API") + prompt_data = response.get("prompt", {}) + if not isinstance(prompt_data, dict): + raise BasaltAPIError("Invalid prompt data in get prompt response") + return PromptResponse.from_dict(prompt_data) + + @staticmethod + def _publish_response_from_api(response: HTTPResponse | None) -> PublishPromptResponse: + if response is None: + raise BasaltAPIError("Empty response from publish prompt API") + payload = response.json() or {} + if not payload: + raise BasaltAPIError("Empty response from publish prompt API") + if not isinstance(payload, Mapping): + raise BasaltAPIError("Invalid publish prompt response") + return PublishPromptResponse.from_dict(payload) + async def _request_async( self, operation: str, *, method: str, url: str, - span_attributes: dict[str, Any] | None = None, - span_variables: dict[str, Any] | None = None, + span_attributes: Mapping[str, Any] | None = None, + span_variables: Mapping[str, Any] | None = None, cache_hit: bool | None = None, - **request_kwargs: Any, - ): + **request_kwargs: Unpack[HTTPRequestKwargs], + ) -> HTTPResponse: """Override to use PromptRequestSpan for custom output formatting.""" # Lazy import to avoid circular dependency import functools @@ -110,7 +135,10 @@ async def _request_async( method=method, **request_kwargs, ) - return await trace_async_request(span, call) + result = await trace_async_request(span, call) + if result is None: + raise BasaltAPIError("Empty response from async prompt API") + return result def _request_sync( self, @@ -118,11 +146,11 @@ def _request_sync( *, method: str, url: str, - span_attributes: dict[str, Any] | None = None, - span_variables: dict[str, Any] | None = None, + span_attributes: Mapping[str, Any] | None = None, + span_variables: Mapping[str, Any] | None = None, cache_hit: bool | None = None, - **request_kwargs: Any, - ): + **request_kwargs: Unpack[HTTPRequestKwargs], + ) -> HTTPResponse: """Override to use PromptRequestSpan for custom output formatting.""" # Lazy import to avoid circular dependency import functools @@ -144,7 +172,10 @@ def _request_sync( method=method, **request_kwargs, ) - return trace_sync_request(span, call) + result = trace_sync_request(span, call) + if result is None: + raise BasaltAPIError("Empty response from sync prompt API") + return result async def get( self, @@ -211,10 +242,7 @@ async def get( span_variables=variables, ) - if response is None or response.body is None: - raise BasaltAPIError("Empty response from get prompt API") - prompt_data = response.get("prompt", {}) - prompt_response = PromptResponse.from_dict(prompt_data) + prompt_response = self._prompt_response_from_api(response) # Store in both caches if cache_enabled: @@ -315,10 +343,7 @@ def get_sync( span_variables=variables, ) - if response is None or response.body is None: - raise BasaltAPIError("Empty response from get prompt API") - prompt_data = response.get("prompt", {}) - prompt_response = PromptResponse.from_dict(prompt_data) + prompt_response = self._prompt_response_from_api(response) # Store in both caches if cache_enabled: @@ -398,6 +423,8 @@ async def describe( if response is None or response.body is None: raise BasaltAPIError("Empty response from describe prompt API") prompt_data = response.get("prompt", {}) + if not isinstance(prompt_data, dict): + raise BasaltAPIError("Invalid prompt data in describe response") return DescribePromptResponse.from_dict(prompt_data) def describe_sync( @@ -444,9 +471,11 @@ def describe_sync( if response is None or response.body is None: raise BasaltAPIError("Empty response from describe prompt API") prompt_data = response.get("prompt", {}) + if not isinstance(prompt_data, dict): + raise BasaltAPIError("Invalid prompt data in describe response") return DescribePromptResponse.from_dict(prompt_data) - async def list(self, feature_slug: str | None = None) -> list[PromptListResponse]: + async def list(self, feature_slug: str | None = None) -> builtins.list[PromptListResponse]: """ List prompts, optionally filtering by feature_slug. @@ -480,13 +509,11 @@ async def list(self, feature_slug: str | None = None) -> list[PromptListResponse return [] prompts_data = response.get("prompts", []) - return [ - PromptListResponse.from_dict(p) - for p in prompts_data - if isinstance(p, dict) - ] + if not isinstance(prompts_data, list): + return [] + return [PromptListResponse.from_dict(p) for p in prompts_data if isinstance(p, dict)] - def list_sync(self, feature_slug: str | None = None) -> list[PromptListResponse]: + def list_sync(self, feature_slug: str | None = None) -> builtins.list[PromptListResponse]: """ Synchronously list prompts, optionally filtering by feature_slug. @@ -520,11 +547,9 @@ def list_sync(self, feature_slug: str | None = None) -> list[PromptListResponse] return [] prompts_data = response.get("prompts", []) - return [ - PromptListResponse.from_dict(p) - for p in prompts_data - if isinstance(p, dict) - ] + if not isinstance(prompts_data, list): + return [] + return [PromptListResponse.from_dict(p) for p in prompts_data if isinstance(p, dict)] async def publish( self, @@ -571,13 +596,7 @@ async def publish( }, ) - if response is None: - raise BasaltAPIError("Empty response from publish prompt API") - payload = response.json() or {} - if not payload: - raise BasaltAPIError("Empty response from publish prompt API") - - return PublishPromptResponse.from_dict(payload) + return self._publish_response_from_api(response) def publish_sync( self, @@ -624,13 +643,7 @@ def publish_sync( }, ) - if response is None: - raise BasaltAPIError("Empty response from publish prompt API") - payload = response.json() or {} - if not payload: - raise BasaltAPIError("Empty response from publish prompt API") - - return PublishPromptResponse.from_dict(payload) + return self._publish_response_from_api(response) @staticmethod def _create_prompt_instance( diff --git a/basalt/prompts/models.py b/basalt/prompts/models.py index 7306d20..87f5d47 100644 --- a/basalt/prompts/models.py +++ b/basalt/prompts/models.py @@ -4,18 +4,19 @@ This module contains all data models and data transfer objects used by the PromptsClient. """ + from __future__ import annotations import json from collections.abc import Mapping -from contextvars import ContextVar +from contextvars import ContextVar, Token from dataclasses import dataclass +from types import TracebackType from typing import Any # Context variable for prompt data injection into child spans _current_prompt_context: ContextVar[dict[str, Any] | None] = ContextVar( - '_current_prompt_context', - default=None + "_current_prompt_context", default=None ) @@ -25,6 +26,7 @@ class PromptModelParameters: Immutable and uses slots to reduce per-instance memory overhead. """ + temperature: float max_length: int response_format: str @@ -60,10 +62,14 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModelParameters: top_p = float(top_p) if isinstance(top_p, (int, float)) else None frequency_penalty = data.get("frequencyPenalty") - frequency_penalty = float(frequency_penalty) if isinstance(frequency_penalty, (int, float)) else None + frequency_penalty = ( + float(frequency_penalty) if isinstance(frequency_penalty, (int, float)) else None + ) presence_penalty = data.get("presencePenalty") - presence_penalty = float(presence_penalty) if isinstance(presence_penalty, (int, float)) else None + presence_penalty = ( + float(presence_penalty) if isinstance(presence_penalty, (int, float)) else None + ) json_object = data.get("jsonObject") json_object = dict(json_object) if isinstance(json_object, Mapping) else None @@ -86,6 +92,7 @@ class PromptModel: Immutable and uses slots to reduce per-instance memory overhead. """ + provider: str model: str version: str @@ -106,7 +113,9 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModel: version = data.get("version") if isinstance(data.get("version"), str) else "" parameters_data = data.get("parameters") - parameters = PromptModelParameters.from_dict(parameters_data if isinstance(parameters_data, Mapping) else None) + parameters = PromptModelParameters.from_dict( + parameters_data if isinstance(parameters_data, Mapping) else None + ) return cls( provider=str(provider), @@ -119,6 +128,7 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptModel: @dataclass class PromptParams: """Parameters for creating a new prompt instance.""" + slug: str text: str model: PromptModel @@ -141,7 +151,7 @@ class Prompt: prompt = basalt.prompts.get( slug="qa-prompt", version="2.1.0", - variables={"context": "Paris is the capital of France"} + variables={"context": "Paris is the capital of France"}, ) # Access prompt properties @@ -149,6 +159,7 @@ class Prompt: print(prompt.model.provider) ``` """ + slug: str text: str raw_text: str @@ -190,6 +201,7 @@ class _PromptContextMixin: _tag: str | None _variables: dict[str, Any] | None _from_cache: bool + _context_token: Token[dict[str, Any] | None] | None def _set_span_attributes(self) -> None: from basalt.observability import semconv @@ -209,6 +221,23 @@ def _set_span_attributes(self) -> None: span.set_attribute("basalt.prompt.variables", json.dumps(self._variables)) span.set_attribute("basalt.prompt.from_cache", self._from_cache) + def _enter_prompt_context(self) -> None: + # Store prompt metadata in context for child spans. + prompt_ctx = { + "slug": self._slug, + "version": self._version, + "tag": self._tag, + "provider": self._prompt.model.provider, + "model": self._prompt.model.model, + "variables": self._variables, + "from_cache": self._from_cache, + } + self._context_token = _current_prompt_context.set(prompt_ctx) + + def _exit_prompt_context(self) -> None: + if self._context_token is not None: + _current_prompt_context.reset(self._context_token) + class PromptContextManager(_PromptContextMixin): """ @@ -265,9 +294,9 @@ def __init__( self._tag: str | None = tag self._variables: dict[str, Any] | None = variables self._from_cache: bool = from_cache - self._context_token: Any = None + self._context_token: Token[dict[str, Any] | None] | None = None - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name: str) -> object: """Forward all attribute access to the wrapped Prompt.""" return getattr(self._prompt, name) @@ -277,27 +306,21 @@ def __enter__(self) -> PromptContextManager: This allows child spans to automatically receive prompt attributes. """ - # Store prompt metadata in context - prompt_ctx = { - "slug": self._slug, - "version": self._version, - "tag": self._tag, - "provider": self._prompt.model.provider, - "model": self._prompt.model.model, - "variables": self._variables, - "from_cache": self._from_cache, - } - self._context_token = _current_prompt_context.set(prompt_ctx) + self._enter_prompt_context() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: """Exit context manager mode - clear prompt context.""" try: # Any cleanup logic pass finally: - if self._context_token is not None: - _current_prompt_context.reset(self._context_token) + self._exit_prompt_context() # Don't suppress exceptions return False @@ -359,9 +382,9 @@ def __init__( self._tag: str | None = tag self._variables: dict[str, Any] | None = variables self._from_cache: bool = from_cache - self._context_token: Any = None + self._context_token: Token[dict[str, Any] | None] | None = None - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name: str) -> object: """Forward all attribute access to the wrapped Prompt.""" return getattr(self._prompt, name) @@ -371,27 +394,21 @@ async def __aenter__(self) -> AsyncPromptContextManager: This allows child spans to automatically receive prompt attributes. """ - # Store prompt metadata in context - prompt_ctx = { - "slug": self._slug, - "version": self._version, - "tag": self._tag, - "provider": self._prompt.model.provider, - "model": self._prompt.model.model, - "variables": self._variables, - "from_cache": self._from_cache, - } - self._context_token = _current_prompt_context.set(prompt_ctx) + self._enter_prompt_context() return self - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: """Exit async context manager mode - clear prompt context.""" try: # Any cleanup logic pass finally: - if self._context_token is not None: - _current_prompt_context.reset(self._context_token) + self._exit_prompt_context() # Don't suppress exceptions return False @@ -415,6 +432,7 @@ class PromptResponse: Immutable and uses slots to reduce per-instance memory overhead. """ + text: str slug: str version: str @@ -450,12 +468,42 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptResponse: ) +def _parse_prompt_descriptor_fields( + data: Mapping[str, Any] | None, +) -> tuple[str, str, str, str, list[str], list[str]]: + if data is None: + data = {} + + slug = data.get("slug") if isinstance(data.get("slug"), str) else "" + status = data.get("status") if isinstance(data.get("status"), str) else "" + name = data.get("name") if isinstance(data.get("name"), str) else "" + description = data.get("description") if isinstance(data.get("description"), str) else "" + + available_versions_raw = data.get("availableVersions") + available_versions = ( + list(available_versions_raw) if isinstance(available_versions_raw, list) else [] + ) + + available_tags_raw = data.get("availableTags") + available_tags = list(available_tags_raw) if isinstance(available_tags_raw, list) else [] + + return ( + str(slug), + str(status), + str(name), + str(description), + available_versions, + available_tags, + ) + + @dataclass(slots=True, frozen=True) class DescribePromptResponse: """Response from the Describe Prompt API endpoint. Immutable and uses slots to reduce per-instance memory overhead. """ + slug: str status: str name: str @@ -470,28 +518,23 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> DescribePromptResponse: Robust against missing keys or wrong types. Copies mutable inputs. """ - if data is None: - data = {} - - slug = data.get("slug") if isinstance(data.get("slug"), str) else "" - status = data.get("status") if isinstance(data.get("status"), str) else "" - name = data.get("name") if isinstance(data.get("name"), str) else "" - description = data.get("description") if isinstance(data.get("description"), str) else "" - - available_versions_raw = data.get("availableVersions") - available_versions = list(available_versions_raw) if isinstance(available_versions_raw, list) else [] - - available_tags_raw = data.get("availableTags") - available_tags = list(available_tags_raw) if isinstance(available_tags_raw, list) else [] - - variables_raw = data.get("variables") + ( + slug, + status, + name, + description, + available_versions, + available_tags, + ) = _parse_prompt_descriptor_fields(data) + + variables_raw = data.get("variables") if data else None variables = list(variables_raw) if isinstance(variables_raw, list) else [] return cls( - slug=str(slug), - status=str(status), - name=str(name), - description=str(description), + slug=slug, + status=status, + name=name, + description=description, available_versions=available_versions, available_tags=available_tags, variables=variables, @@ -504,6 +547,7 @@ class PromptListResponse: Immutable and uses slots to reduce per-instance memory overhead. """ + slug: str status: str name: str @@ -517,25 +561,20 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PromptListResponse: Robust against missing keys or wrong types. Copies mutable inputs. """ - if data is None: - data = {} - - slug = data.get("slug") if isinstance(data.get("slug"), str) else "" - status = data.get("status") if isinstance(data.get("status"), str) else "" - name = data.get("name") if isinstance(data.get("name"), str) else "" - description = data.get("description") if isinstance(data.get("description"), str) else "" - - available_versions_raw = data.get("availableVersions") - available_versions = list(available_versions_raw) if isinstance(available_versions_raw, list) else [] - - available_tags_raw = data.get("availableTags") - available_tags = list(available_tags_raw) if isinstance(available_tags_raw, list) else [] + ( + slug, + status, + name, + description, + available_versions, + available_tags, + ) = _parse_prompt_descriptor_fields(data) return cls( - slug=str(slug), - status=str(status), - name=str(name), - description=str(description), + slug=slug, + status=status, + name=name, + description=description, available_versions=available_versions, available_tags=available_tags, ) @@ -547,6 +586,7 @@ class PublishPromptResponse: Immutable and uses slots to reduce per-instance memory overhead. """ + id: str label: str @@ -563,7 +603,9 @@ def from_dict(cls, data: Mapping[str, Any] | None) -> PublishPromptResponse: deployment_tag = data.get("deploymentTag") if isinstance(deployment_tag, Mapping): id_val = deployment_tag.get("id") if isinstance(deployment_tag.get("id"), str) else "" - label_val = deployment_tag.get("label") if isinstance(deployment_tag.get("label"), str) else "" + label_val = ( + deployment_tag.get("label") if isinstance(deployment_tag.get("label"), str) else "" + ) else: id_val = data.get("id") if isinstance(data.get("id"), str) else "" label_val = data.get("label") if isinstance(data.get("label"), str) else "" diff --git a/basalt/types/__init__.py b/basalt/types/__init__.py index ac8a045..841a239 100644 --- a/basalt/types/__init__.py +++ b/basalt/types/__init__.py @@ -9,6 +9,7 @@ from basalt.types import Prompt, Dataset, PromptModel ``` """ + from ..datasets.models import Dataset, DatasetRow from ..prompts.models import ( DescribePromptResponse, @@ -19,6 +20,7 @@ PromptParams, PromptResponse, ) +from .common import JSONDict, JSONList, JSONPrimitive, JSONValue, SpanAttributeValue __all__ = [ # Prompt types @@ -32,4 +34,10 @@ # Dataset types "Dataset", "DatasetRow", + # Common types + "JSONValue", + "JSONDict", + "JSONList", + "JSONPrimitive", + "SpanAttributeValue", ] diff --git a/basalt/types/cache.py b/basalt/types/cache.py index fe746ad..8861562 100644 --- a/basalt/types/cache.py +++ b/basalt/types/cache.py @@ -1,15 +1,16 @@ """Cache protocol used throughout the Basalt SDK.""" + from __future__ import annotations from collections.abc import Hashable -from typing import Any, Protocol +from typing import Protocol class CacheProtocol(Protocol): """Minimal protocol implemented by cache backends.""" - def get(self, key: Hashable) -> Any | None: + def get(self, key: Hashable) -> object | None: """Return a cached value for *key* or ``None`` when missing.""" - def put(self, key: Hashable, value: Any, ttl: float = float("inf")) -> None: + def put(self, key: Hashable, value: object, ttl: float = float("inf")) -> None: """Store *value* for *key* with a time-to-live in seconds.""" diff --git a/basalt/types/common.py b/basalt/types/common.py new file mode 100644 index 0000000..55a99ef --- /dev/null +++ b/basalt/types/common.py @@ -0,0 +1,13 @@ +"""Common type aliases used throughout the Basalt SDK.""" + +from typing import TypeAlias + +# JSON-serializable types +# These represent data that can be safely serialized to/from JSON +JSONPrimitive: TypeAlias = str | int | float | bool | None +JSONValue: TypeAlias = JSONPrimitive | dict[str, "JSONValue"] | list["JSONValue"] +JSONDict: TypeAlias = dict[str, JSONValue] +JSONList: TypeAlias = list[JSONValue] + +# OpenTelemetry span attributes can only be primitive types +SpanAttributeValue: TypeAlias = str | int | float | bool | None diff --git a/basalt/types/exceptions.py b/basalt/types/exceptions.py index 4457e9f..a50041d 100644 --- a/basalt/types/exceptions.py +++ b/basalt/types/exceptions.py @@ -4,7 +4,7 @@ class BasaltError(Exception): """Base exception for all Basalt SDK errors.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message super().__init__(message) @@ -12,7 +12,7 @@ def __init__(self, message: str): class BasaltAPIError(BasaltError): """Base exception for API-related errors.""" - def __init__(self, message: str, status_code: int | None = None): + def __init__(self, message: str, status_code: int | None = None) -> None: self.status_code = status_code super().__init__(message) @@ -20,35 +20,35 @@ def __init__(self, message: str, status_code: int | None = None): class BadRequestError(BasaltAPIError): """Raised when the API returns a 400 Bad Request error.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: super().__init__(message, status_code=400) class UnauthorizedError(BasaltAPIError): """Raised when the API returns a 401 Unauthorized error.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: super().__init__(message, status_code=401) class ForbiddenError(BasaltAPIError): """Raised when the API returns a 403 Forbidden error.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: super().__init__(message, status_code=403) class NotFoundError(BasaltAPIError): """Raised when the API returns a 404 Not Found error.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: super().__init__(message, status_code=404) class NetworkError(BasaltError): """Raised when a network error occurs.""" - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message super().__init__(message) @@ -56,7 +56,7 @@ def __init__(self, message: str): class FileUploadError(BasaltError): """Raised when file upload to S3 fails.""" - def __init__(self, message: str, file_key: str | None = None): + def __init__(self, message: str, file_key: str | None = None) -> None: self.file_key = file_key super().__init__(message) diff --git a/basalt/utils/memcache.py b/basalt/utils/memcache.py index 85e5079..74f1bb7 100644 --- a/basalt/utils/memcache.py +++ b/basalt/utils/memcache.py @@ -1,9 +1,11 @@ import time from collections.abc import Hashable -from typing import Any +from typing import TypeVar from ..types.cache import CacheProtocol +T = TypeVar("T") + class MemoryCache(CacheProtocol): """ @@ -11,11 +13,11 @@ class MemoryCache(CacheProtocol): It implements the ICache protocol. """ - def __init__(self): - self._mem: dict[Hashable, Any] = {} + def __init__(self) -> None: + self._mem: dict[Hashable, object] = {} self._timeouts: dict[Hashable, float] = {} - def get(self, key: Hashable): + def get(self, key: Hashable) -> object | None: """ Retrieves the value associated with the given key if it has not expired. @@ -33,7 +35,7 @@ def get(self, key: Hashable): return None - def put(self, key: Hashable, value: Any, ttl: float = float('inf')) -> None: + def put(self, key: Hashable, value: object, ttl: float = float("inf")) -> None: """ Stores a value in the cache with an associated time-to-live (TTL). diff --git a/examples/async_observe_example.py b/examples/async_observe_example.py index c4213e9..a3587b9 100644 --- a/examples/async_observe_example.py +++ b/examples/async_observe_example.py @@ -24,11 +24,7 @@ async def process_item(item: dict) -> dict: # Simulate processing await asyncio.sleep(0.05) - result = { - "id": item["id"], - "processed": True, - "result": f"Processed: {item['name']}" - } + result = {"id": item["id"], "processed": True, "result": f"Processed: {item['name']}"} span.set_output(result) return result @@ -56,7 +52,9 @@ async def main(): """Main async workflow demonstrating async_start_observe.""" logging.basicConfig(level=logging.INFO) - async with async_start_observe(name="async_workflow_example", feature_slug="async_demo") as root_span: + async with async_start_observe( + name="async_workflow_example", feature_slug="async_demo" + ) as root_span: logging.info("Starting async workflow...") # Set some metadata on the root span @@ -69,10 +67,7 @@ async def main(): results = await asyncio.gather(*tasks) # Set the final output - root_span.set_output({ - "processed_count": len(results), - "items": results - }) + root_span.set_output({"processed_count": len(results), "items": results}) logging.info(f"Completed processing {len(results)} items") diff --git a/examples/dataset_api_example.py b/examples/dataset_api_example.py index 2881fb5..a0d4f5e 100644 --- a/examples/dataset_api_example.py +++ b/examples/dataset_api_example.py @@ -147,20 +147,16 @@ def example_4_add_dataset_row(client: DatasetsClient) -> None: values[column] = f"test_value_{column}" # Add row with optional metadata - row, warning = client.add_row_sync( + row = client.add_row_sync( slug=dataset_slug, values=values, name="Example Row", ideal_output="expected_output", - metadata={"source": "example_script"} + metadata={"source": "example_script"}, ) logging.info("Row added successfully") - logging.info(f"Row values: {row.values}") - if warning: - logging.warning(f"Warning: {warning}\n") - else: - logging.info("") + logging.info(f"Row values: {row.values}\n") except NotFoundError: logging.error("Dataset not found", exc_info=True) except UnauthorizedError: @@ -183,7 +179,7 @@ def example_5_get_dataset_metadata(client: DatasetsClient) -> None: logging.info(f"Dataset Name: {dataset.name}") logging.info(f"Dataset Slug: {dataset.slug}") - logging.info(f"Columns: {', '.join(dataset.columns)}") + logging.info(f"Columns: {', '.join(col.name for col in dataset.columns)}") logging.info(f"Total Rows: {len(dataset.rows)}\n") except NotFoundError: logging.error("Dataset not found", exc_info=True) @@ -249,18 +245,14 @@ async def example_8_async_add_row(client: DatasetsClient) -> None: values[column] = f"async_test_{column}" # Add row asynchronously - row, warning = await client.add_row( + row = await client.add_row( slug=dataset_slug, values=values, name="Async Example Row", - metadata={"source": "async_example"} + metadata={"source": "async_example"}, ) - logging.info("Row added asynchronously") - if warning: - logging.warning(f"Warning: {warning}\n") - else: - logging.info("") + logging.info("Row added asynchronously\n") except NotFoundError: logging.error("Dataset not found", exc_info=True) except BasaltAPIError: @@ -312,7 +304,9 @@ async def example_9_concurrent_operations(client: DatasetsClient) -> None: except BasaltAPIError as e: logging.error(f"Basalt API error during concurrent operations: {e}", exc_info=True) except Exception as e: - logging.error(f"Unexpected error during concurrent operations: {type(e).__name__}: {e}", exc_info=True) + logging.error( + f"Unexpected error during concurrent operations: {type(e).__name__}: {e}", exc_info=True + ) async def run_async_examples(client: DatasetsClient) -> None: diff --git a/examples/dataset_sdk_demo.ipynb b/examples/dataset_sdk_demo.ipynb index 92cc0dc..e71f62b 100644 --- a/examples/dataset_sdk_demo.ipynb +++ b/examples/dataset_sdk_demo.ipynb @@ -14,7 +14,9 @@ "import os\n", "import sys\n", "\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "sys.path.append(\n", + " os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + ") # Needed to make notebook work in VSCode\n", "\n", "os.environ[\"BASALT_BUILD\"] = \"development\"\n", "\n", @@ -63,6 +65,7 @@ " except Exception:\n", " return []\n", "\n", + "\n", "# Run the async function\n", "datasets = await list_datasets()" ] @@ -100,6 +103,7 @@ " else:\n", " return None, None\n", "\n", + "\n", "# Run the async function\n", "sample_dataset, dataset = await get_dataset(datasets)" ] @@ -139,7 +143,11 @@ " values=values,\n", " name=\"Async Sample Row\",\n", " ideal_output=\"Expected output for this row\",\n", - " metadata={\"source\": \"async_example\", \"type\": \"demo\", \"notebook\": \"dataset_sdk_demo\"}\n", + " metadata={\n", + " \"source\": \"async_example\",\n", + " \"type\": \"demo\",\n", + " \"notebook\": \"dataset_sdk_demo\",\n", + " },\n", " )\n", "\n", " if row.name:\n", @@ -153,6 +161,7 @@ " else:\n", " return None, None\n", "\n", + "\n", "# Run the async function\n", "row_result, warning = await add_row(sample_dataset)\n", "\n", diff --git a/examples/gemini_random_data_example.py b/examples/gemini_random_data_example.py index 8ed6eb9..5de2a82 100644 --- a/examples/gemini_random_data_example.py +++ b/examples/gemini_random_data_example.py @@ -6,9 +6,11 @@ Demonstrates both decorator-based and manual context manager instrumentation. """ + import asyncio import logging import os +from typing import TypedDict import httpx from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter @@ -19,7 +21,14 @@ genai = None from basalt import Basalt, TelemetryConfig -from basalt.observability import AsyncObserve, EvaluationConfig, ObserveKind, evaluate, observe, start_observe +from basalt.observability import ( + AsyncObserve, + EvaluationConfig, + ObserveKind, + evaluate, + observe, + start_observe, +) # --- Constants --- # specific model version to ensure consistency across execution and telemetry @@ -27,6 +36,14 @@ GEMINI_EMBEDDING_MODEL = "gemini-embedding-001" +class EmbeddingResult(TypedDict): + """Result of embedding generation.""" + + dimension: int + sample_values: list[float] + status: str + + # --- 1. Build Basalt client with custom OTLP exporter --- def build_custom_exporter_client() -> Basalt: """ @@ -49,7 +66,10 @@ def build_custom_exporter_client() -> Basalt: # Note: insecure=True is used here for local/demo purposes. # Use secure credentials for production. exporter = OTLPSpanExporter( - endpoint=otlp_endpoint, headers={"authorization": f"Bearer {api_key}"}, insecure=True, timeout=10 + endpoint=otlp_endpoint, + headers={"authorization": f"Bearer {api_key}"}, + insecure=True, + timeout=10, ) telemetry = TelemetryConfig( @@ -100,7 +120,7 @@ async def summarize_joke_with_gemini(joke: str) -> str | None: return response.text -async def embed_joke_summary(summary: str) -> dict | None: +async def embed_joke_summary(summary: str) -> EmbeddingResult: """ Generate embeddings for the joke summary using Gemini Embedding model. @@ -131,7 +151,11 @@ async def embed_joke_summary(summary: str) -> dict | None: ) # Extract embedding data + if result.embeddings is None or len(result.embeddings) == 0: + raise ValueError("No embeddings returned from API response") embedding_vector = result.embeddings[0].values + if embedding_vector is None: + raise ValueError("Embedding vector is None from API response") dimension_count = len(embedding_vector) # Token usage: prefer API metadata, fall back to a simple estimate @@ -177,19 +201,23 @@ async def embed_joke_summary(summary: str) -> dict | None: span.set_attribute("gen_ai.embeddings.dimension.count", dimension_count) # Set metadata for Basalt tracking - span.set_metadata({ - "embedding.dimension": dimension_count, - "embedding.status": "success", - "embedding.model": GEMINI_EMBEDDING_MODEL, - }) + span.set_metadata( + { + "embedding.dimension": dimension_count, + "embedding.status": "success", + "embedding.model": GEMINI_EMBEDDING_MODEL, + } + ) # Set output (partial vector only - don't log full 768-dim vector) - output_data = { + output_data: EmbeddingResult = { "dimension": dimension_count, - "sample_values": embedding_vector[:5], # First 5 values only + "sample_values": embedding_vector[:5] + if embedding_vector + else [], # First 5 values only "status": "success", } - span.set_output(output_data) + span.set_output(str(output_data)) logging.info( "Generated %s-dimensional embedding vector", @@ -199,14 +227,18 @@ async def embed_joke_summary(summary: str) -> dict | None: except Exception as exc: logging.error(f"Embedding error: {exc}") - span.set_metadata({ - "embedding.status": "error", - "embedding.error": str(exc), - }) - span.set_output({ - "status": "error", - "error": str(exc), - }) + span.set_metadata( + { + "embedding.status": "error", + "embedding.error": str(exc), + } + ) + span.set_output( + { + "status": "error", + "error": str(exc), + } + ) raise @@ -243,7 +275,10 @@ async def start_workflow() -> None: logging.info(f"Gemini summary: {gemini_result}") observe.set_metadata( - {"gemini.status": "success", "gemini.response_length": len(gemini_result) if gemini_result else 0} + { + "gemini.status": "success", + "gemini.response_length": len(gemini_result) if gemini_result else 0, + } ) # Use the constant to ensure attributes match the actual model used @@ -265,16 +300,20 @@ async def start_workflow() -> None: f"sample: {embedding_data['sample_values'][:3]}" ) - observe.set_metadata({ - "embedding.dimension": embedding_data["dimension"], - "embedding.generated": True, - }) + observe.set_metadata( + { + "embedding.dimension": embedding_data["dimension"], + "embedding.generated": True, + } + ) except Exception as exc: logging.error(f"Failed to generate embeddings: {exc}") - observe.set_metadata({ - "embedding.generated": False, - "embedding.error": str(exc), - }) + observe.set_metadata( + { + "embedding.generated": False, + "embedding.error": str(exc), + } + ) # Don't re-raise - allow workflow to continue even if embeddings fail except Exception as exc: diff --git a/examples/multi_exporter_example.py b/examples/multi_exporter_example.py index b9a0310..4e6c09a 100644 --- a/examples/multi_exporter_example.py +++ b/examples/multi_exporter_example.py @@ -1,6 +1,8 @@ import os -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter as GRPCSpanExporter, +) from opentelemetry.sdk.trace.export import ConsoleSpanExporter from basalt import Basalt, TelemetryConfig @@ -33,8 +35,8 @@ environment="production", exporter=[ basalt_exporter, # Export to Basalt for advanced features - local_exporter, # Export to local collector - console_exporter, # Export to console for debugging + local_exporter, # Export to local collector + console_exporter, # Export to console for debugging ], ) diff --git a/examples/openai_example.py b/examples/openai_example.py index 4c00519..68c24b8 100644 --- a/examples/openai_example.py +++ b/examples/openai_example.py @@ -45,7 +45,11 @@ def build_basalt_client(): # 4. Return Client return Basalt( api_key=basalt_key, - observability_metadata={"env": "development", "provider": "openai", "example": "auto-instrumentation"}, + observability_metadata={ + "env": "development", + "provider": "openai", + "example": "auto-instrumentation", + }, telemetry_config=telemetry, ) @@ -77,7 +81,9 @@ def run_weather_assistant(user_query: str): with start_observe( name="weather_assistant", feature_slug="weather-assistant", - identity=Identity(organization={"id": "123", "name": "Demo Corp"}, user={"id": "456", "name": "Alice"}), + identity=Identity( + organization={"id": "123", "name": "Demo Corp"}, user={"id": "456", "name": "Alice"} + ), ) as span: span.set_input({"query": user_query}) @@ -99,7 +105,10 @@ def run_weather_assistant(user_query: str): model=OPENAI_MODEL_NAME, messages=[ {"role": "system", "content": "You are a helpful weather assistant."}, - {"role": "user", "content": f"Context: {weather_data}\n\nQuery: {user_query}"}, + { + "role": "user", + "content": f"Context: {weather_data}\n\nQuery: {user_query}", + }, ], ) content = response.choices[0].message.content diff --git a/examples/prompt_sdk_demo.ipynb b/examples/prompt_sdk_demo.ipynb index 045860b..c7a3655 100644 --- a/examples/prompt_sdk_demo.ipynb +++ b/examples/prompt_sdk_demo.ipynb @@ -14,7 +14,9 @@ "import os\n", "import sys\n", "\n", - "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Needed to make notebook work in VSCode\n", + "sys.path.append(\n", + " os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + ") # Needed to make notebook work in VSCode\n", "\n", "os.environ[\"BASALT_BUILD\"] = \"development\"\n", "\n", @@ -53,7 +55,7 @@ "# Get a prompt by slug (default is production version)\n", "# For real usage, replace 'prompt-slug' with your actual prompt slug\n", "try:\n", - " result = basalt.prompts.get_sync('prompt-slug')\n", + " result = basalt.prompts.get_sync(\"prompt-slug\")\n", "except Exception:\n", " pass\n", " # In production, handle the error appropriately" @@ -76,13 +78,13 @@ "source": [ "# Get a prompt with a specific tag\n", "try:\n", - " result_tag = basalt.prompts.get_sync(slug='prompt-slug', tag='latest')\n", + " result_tag = basalt.prompts.get_sync(slug=\"prompt-slug\", tag=\"latest\")\n", "except Exception:\n", " pass\n", "\n", "# Get a prompt with a specific version\n", "try:\n", - " result_version = basalt.prompts.get_sync(slug='prompt-slug', version='1.0.0')\n", + " result_version = basalt.prompts.get_sync(slug=\"prompt-slug\", version=\"1.0.0\")\n", "except Exception:\n", " pass" ] @@ -106,12 +108,8 @@ "# Variables allow you to customize prompts with dynamic values\n", "try:\n", " result_vars = basalt.prompts.get_sync(\n", - " slug='prompt-slug-with-vars',\n", - " variables={\n", - " 'name': 'John Doe',\n", - " 'role': 'Developer',\n", - " 'company': 'Acme Inc'\n", - " }\n", + " slug=\"prompt-slug-with-vars\",\n", + " variables={\"name\": \"John Doe\", \"role\": \"Developer\", \"company\": \"Acme Inc\"},\n", " )\n", "except Exception:\n", " pass" @@ -142,7 +140,7 @@ "\n", " # Get a prompt from Basalt\n", " try:\n", - " result = basalt.prompts.get_sync('prompt-slug')\n", + " result = basalt.prompts.get_sync(\"prompt-slug\")\n", "\n", " # Use the prompt with OpenAI\n", " # Note: This will fail with demo keys, but shows the pattern\n", @@ -150,8 +148,8 @@ " model=\"gpt-4\",\n", " messages=[\n", " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", - " {\"role\": \"user\", \"content\": result.text}\n", - " ]\n", + " {\"role\": \"user\", \"content\": result.text},\n", + " ],\n", " )\n", " except Exception:\n", " pass\n", @@ -183,16 +181,12 @@ " Returns the prompt text on success, or None on failure.\n", " \"\"\"\n", " try:\n", - " result = basalt.prompts.get_sync(\n", - " slug=slug,\n", - " tag=tag,\n", - " version=version,\n", - " variables=variables\n", - " )\n", + " result = basalt.prompts.get_sync(slug=slug, tag=tag, version=version, variables=variables)\n", " return result.text\n", " except Exception:\n", " return None\n", "\n", + "\n", "# Test with a non-existent prompt (will fail)\n", "prompt_text = get_prompt_safely(\"non-existent-prompt\")\n", "if prompt_text:\n", @@ -227,10 +221,7 @@ "@evaluator(\n", " slugs=[\"prompt-quality\", \"response-accuracy\"],\n", " sample_rate=1.0,\n", - " metadata=lambda prompt_slug, **kwargs: {\n", - " \"prompt_slug\": prompt_slug,\n", - " \"source\": \"basalt_prompts\"\n", - " }\n", + " metadata=lambda prompt_slug, **kwargs: {\"prompt_slug\": prompt_slug, \"source\": \"basalt_prompts\"},\n", ")\n", "def generate_response_with_prompt(prompt_slug: str, variables: dict = None) -> str:\n", " \"\"\"\n", @@ -267,11 +258,11 @@ " span.set_output({\"status\": \"error\", \"error\": str(e)})\n", " return f\"Error: {e}\"\n", "\n", + "\n", "# Execute the workflow\n", "try:\n", " response = generate_response_with_prompt(\n", - " prompt_slug=\"example-prompt\",\n", - " variables={\"context\": \"demo\"}\n", + " prompt_slug=\"example-prompt\", variables={\"context\": \"demo\"}\n", " )\n", "except Exception:\n", " pass\n", diff --git a/project.json b/project.json index f80075a..269368d 100644 --- a/project.json +++ b/project.json @@ -9,7 +9,7 @@ "cwd": "packages/py-sdk", "commands": [ "rm -rf ../../dist/packages/py-sdk", - "BASALT_BUILD=production python3 -m build --outdir ../../dist/packages/py-sdk" + "BASALT_BUILD=production hatch build --outdir ../../dist/packages/py-sdk" ] } }, @@ -18,8 +18,8 @@ "executor": "nx:run-commands", "dependsOn": ["build"], "options": { - "cwd": "dist/packages/py-sdk", - "command": "python3 -m twine upload ./*" + "cwd": "packages/py-sdk", + "command": "hatch publish ../../dist/packages/py-sdk" } } } diff --git a/pyproject.toml b/pyproject.toml index eba2acb..9d7998f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ -# This section defines the build system. It's standard for setuptools. +# This section defines the build system using Hatchling [build-system] -requires = ["setuptools>=80.9.0", "wheel"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "basalt_sdk" -version = "1.1.0" +dynamic = ["version"] description = "Basalt SDK for python" readme = "README.md" # Replaces long_description license = "MIT" @@ -29,6 +29,8 @@ dependencies = [ "opentelemetry-exporter-otlp~=1.39.1", "wrapt~=1.17.3", "httpx>=0.28.1", + "black>=26.1.0", + "typing_extensions>=4.10.0", ] # This section lists URLs, replacing the `url` parameter @@ -48,9 +50,9 @@ google-generativeai = [ "opentelemetry-instrumentation-google-generativeai~=0.51.0", ] # Note: google-genai instrumentation not yet available in openllmetry - google-genai = [ - "opentelemetry-instrumentation-google-genai~=0.5b0", - ] +# google-genai = [ +# "opentelemetry-instrumentation-google-genai~=0.5b0", +# ] bedrock = [ "opentelemetry-instrumentation-bedrock~=0.51.0", @@ -95,7 +97,7 @@ vector-all = [ "basalt_sdk[chromadb,pinecone,qdrant]", ] framework-all = [ - "basalt_sdk[langchain,llamaindex,haystack]", + "basalt_sdk[langchain,llamaindex]", ] all = [ "basalt_sdk[llm-all,vector-all,framework-all]", @@ -131,11 +133,83 @@ dev = [ # A recommended configuration for Ruff, the modern linter/formatter [tool.ruff] -line-length = 120 -target-version = "py310" +line-length = 160 +target-version = "py312" +# Enable lint + format in one tool +lint.select = ["E", "F", "I", "B", "UP", "ANN"] +lint.ignore = ["ANN101", "ANN102"] # example: ignore self/cls annotations +lint.fixable = ["ALL"] +lint.unfixable = [] -[tool.ruff.lint] -select = ["E", "F", "W", "I", "N", "UP", "B", "C4", "T20"] +[tool.ruff.lint.isort] +known-first-party = ["basalt"] [tool.ruff.format] quote-style = "double" +indent-style = "space" +line-ending = "lf" +docstring-code-format = true + +# Hatch configuration +[tool.hatch.version] +path = "basalt/_version.py" + +[tool.hatch.build.targets.wheel] +packages = ["basalt"] + +[tool.hatch.build.targets.sdist] +include = [ + "/basalt", + "/tests", + "/docs", + "/README.md", + "/pyproject.toml", +] + +# Hatch environments for development and testing +[tool.hatch.envs.default] +dependencies = [ + "pytest>=9.0.2", + "pytest-cov>=4.1.0", + "pytest-asyncio>=0.23.0", + "anyio>=4.12.1", + "ruff>=0.14.13", + "mypy>=1.8.0", +] + +[tool.hatch.envs.default.scripts] +test = "pytest tests/ --cov=basalt --cov-report=term --cov-report=xml" +test-verbose = "pytest tests/ -v --cov=basalt --cov-report=term --cov-report=xml" +lint = "ruff check ." +lint-fix = "ruff check --fix ." +fmt = "ruff format ." +fmt-check = "ruff format --check ." +typecheck = "mypy basalt tests" +all = [ + "fmt", + "lint-fix", + "typecheck", + "test", +] + +# Matrix environments for testing across Python versions +[tool.hatch.envs.test] +dependencies = [ + "pytest>=8.0.0", + "pytest-cov>=4.1.0", + "pytest-asyncio>=0.23.0", + "anyio>=4.0.0", +] + +[[tool.hatch.envs.test.matrix]] +python = ["3.10", "3.11", "3.12", "3.13", "3.14"] + +[tool.hatch.envs.test.scripts] +run = "pytest tests/ --cov=basalt --cov-report=term" + +# Environment with all optional dependencies for comprehensive testing +[tool.hatch.envs.full] +features = ["all", "dev"] + +[tool.hatch.envs.full.scripts] +test = "pytest tests/ --cov=basalt --cov-report=term --cov-report=xml" diff --git a/tests/observability/conftest.py b/tests/conftest.py similarity index 100% rename from tests/observability/conftest.py rename to tests/conftest.py diff --git a/tests/datasets/test_client.py b/tests/datasets/test_client.py index eefe299..a005fbe 100644 --- a/tests/datasets/test_client.py +++ b/tests/datasets/test_client.py @@ -3,6 +3,7 @@ These tests were converted from unittest to pytest. They keep the same behaviour but use pytest fixtures, parametrization and asyncio support. """ + from unittest.mock import patch import pytest @@ -33,23 +34,25 @@ def test_list_sync_success(common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"datasets": [ - { - "slug": "dataset-1", - "name": "Dataset 1", - "columns": [ - {"name": "col1", "type": "text"}, - {"name": "col2", "type": "text"} - ], - }, + mock_fetch.return_value = make_response( { - "slug": "dataset-2", - "name": "Dataset 2", - "columns": [ - {"name": "col3", "type": "number"} - ], - }, - ]}) + "datasets": [ + { + "slug": "dataset-1", + "name": "Dataset 1", + "columns": [ + {"name": "col1", "type": "text"}, + {"name": "col2", "type": "text"}, + ], + }, + { + "slug": "dataset-2", + "name": "Dataset 2", + "columns": [{"name": "col3", "type": "number"}], + }, + ] + } + ) datasets = client.list_sync() @@ -83,24 +86,26 @@ def test_get_sync_success(common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "warning": "Some rows contained columns that do not exist in the dataset and were omitted.", - "dataset": { - "slug": "test-dataset", - "name": "Test Dataset", - "columns": [ - {"name": "input", "type": "text"}, - {"name": "output", "type": "text"} - ], - "rows": [ - { - "values": {"input": "hello", "output": "world"}, - "idealOutput": "This is the expected output", - "metadata": {"source": "user"}, - } - ], - }, - }) + mock_fetch.return_value = make_response( + { + "warning": "Some rows contained columns that do not exist in the dataset and were omitted.", + "dataset": { + "slug": "test-dataset", + "name": "Test Dataset", + "columns": [ + {"name": "input", "type": "text"}, + {"name": "output", "type": "text"}, + ], + "rows": [ + { + "values": {"input": "hello", "output": "world"}, + "idealOutput": "This is the expected output", + "metadata": {"source": "user"}, + } + ], + }, + } + ) dataset = client.get_sync("test-dataset") @@ -146,15 +151,17 @@ def test_add_row_sync_success(common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "datasetRow": { - "values": {"col1": "value1", "col2": "value2"}, - "name": "test-row", - "idealOutput": "expected", - "metadata": {"key": "value"}, - }, - "warning": None, - }) + mock_fetch.return_value = make_response( + { + "datasetRow": { + "values": {"col1": "value1", "col2": "value2"}, + "name": "test-row", + "idealOutput": "expected", + "metadata": {"key": "value"}, + }, + "warning": None, + } + ) row = client.add_row_sync( slug="test-dataset", @@ -186,12 +193,14 @@ def test_add_row_sync_with_warning(common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "datasetRow": {"values": {"col1": "value1"}}, - "warning": "Some warning message", - }) + mock_fetch.return_value = make_response( + { + "datasetRow": {"values": {"col1": "value1"}}, + "warning": "Some warning message", + } + ) - with patch.object(client._logger, 'warning') as mock_logger: + with patch.object(client._logger, "warning") as mock_logger: client.add_row_sync( slug="test-dataset", values={"col1": "value1"}, @@ -259,16 +268,20 @@ async def test_list_async_success(self, common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({"datasets": [ + mock_fetch.return_value = make_response( { - "slug": "ds1", - "name": "Dataset 1", - "columns": [ - {"name": "a", "type": "text"}, - {"name": "b", "type": "text"} - ], - }, - ]}) + "datasets": [ + { + "slug": "ds1", + "name": "Dataset 1", + "columns": [ + {"name": "a", "type": "text"}, + {"name": "b", "type": "text"}, + ], + }, + ] + } + ) datasets = await client.list() @@ -285,17 +298,17 @@ async def test_get_async_success(self, common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({ - "warning": None, - "dataset": { - "slug": "test", - "name": "Test", - "columns": [ - {"name": "col1", "type": "text"} - ], - "rows": [], - }, - }) + mock_fetch.return_value = make_response( + { + "warning": None, + "dataset": { + "slug": "test", + "name": "Test", + "columns": [{"name": "col1", "type": "text"}], + "rows": [], + }, + } + ) dataset = await client.get("test") @@ -312,10 +325,9 @@ async def test_add_row_async_success(self, common_client): client = common_client["client"] with patch("basalt.datasets.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({ - "datasetRow": {"values": {"col1": "val1"}}, - "warning": None - }) + mock_fetch.return_value = make_response( + {"datasetRow": {"values": {"col1": "val1"}}, "warning": None} + ) row = await client.add_row("test", {"col1": "val1"}) diff --git a/tests/datasets/test_file_upload.py b/tests/datasets/test_file_upload.py index 7b95c8c..7fe4131 100644 --- a/tests/datasets/test_file_upload.py +++ b/tests/datasets/test_file_upload.py @@ -90,9 +90,7 @@ def test_from_file_like_object(self, tmp_path): def test_from_bytesio_without_name(self): """Test creating FileAttachment from BytesIO without name attribute.""" bio = io.BytesIO(b"content") - with pytest.raises( - FileValidationError, match="filename is required for file-like objects" - ): + with pytest.raises(FileValidationError, match="filename is required for file-like objects"): FileAttachment(source=bio) def test_explicit_content_type(self, temp_file): @@ -150,17 +148,13 @@ def test_detect_markdown(self): def test_explicit_content_type_valid(self): """Test explicit valid content type.""" - attachment = FileAttachment( - source=b"content", filename="test", content_type="image/png" - ) + attachment = FileAttachment(source=b"content", filename="test", content_type="image/png") content_type = _detect_content_type(attachment) assert content_type == "image/png" def test_explicit_content_type_invalid(self): """Test explicit invalid content type raises error.""" - attachment = FileAttachment( - source=b"content", filename="test", content_type="text/plain" - ) + attachment = FileAttachment(source=b"content", filename="test", content_type="text/plain") with pytest.raises(FileValidationError, match="not allowed"): _detect_content_type(attachment) @@ -292,9 +286,7 @@ def test_validate_file_success(self, file_upload_handler, temp_file): """Test successful file validation.""" attachment = FileAttachment(source=temp_file, content_type="image/png") - file_bytes, content_type, filename = file_upload_handler.validate_file( - attachment - ) + file_bytes, content_type, filename = file_upload_handler.validate_file(attachment) assert file_bytes == b"fake png content" assert content_type == "image/png" @@ -334,9 +326,7 @@ async def test_request_presigned_url_async(self, file_upload_handler, http_clien ) http_client.fetch = AsyncMock(return_value=mock_response) - result = await file_upload_handler.request_presigned_url( - "test.jpg", "image/jpeg" - ) + result = await file_upload_handler.request_presigned_url("test.jpg", "image/jpeg") assert isinstance(result, PresignedUploadResponse) assert result.file_key == "datasets/ws/uuid.jpg" @@ -356,9 +346,7 @@ def test_request_presigned_url_sync(self, file_upload_handler, http_client): ) http_client.fetch_sync = Mock(return_value=mock_response) - result = file_upload_handler.request_presigned_url_sync( - "test.jpg", "image/jpeg" - ) + result = file_upload_handler.request_presigned_url_sync("test.jpg", "image/jpeg") assert isinstance(result, PresignedUploadResponse) assert result.file_key == "datasets/ws/uuid.jpg" @@ -425,9 +413,7 @@ async def test_upload_to_s3_timeout(self, file_upload_handler, http_client): ) @pytest.mark.asyncio - async def test_upload_file_complete_workflow( - self, file_upload_handler, http_client, temp_file - ): + async def test_upload_file_complete_workflow(self, file_upload_handler, http_client, temp_file): """Test complete file upload workflow.""" # Mock presigned URL request presigned_response = HTTPResponse( @@ -477,9 +463,7 @@ async def test_add_row_with_file_async(self, client, temp_file): ) as mock_upload: mock_upload.return_value = "datasets/ws/uuid.png" - with patch.object( - client._http_client, "fetch", new_callable=AsyncMock - ) as mock_fetch: + with patch.object(client._http_client, "fetch", new_callable=AsyncMock) as mock_fetch: mock_fetch.return_value = HTTPResponse( status_code=200, data={ @@ -507,9 +491,7 @@ async def test_add_row_with_file_async(self, client, temp_file): def test_add_row_with_file_sync(self, client, temp_file): """Test add_row_sync with FileAttachment.""" - with patch.object( - client._file_upload_handler, "upload_file_sync" - ) as mock_upload: + with patch.object(client._file_upload_handler, "upload_file_sync") as mock_upload: mock_upload.return_value = "datasets/ws/uuid.png" with patch.object(client._http_client, "fetch_sync") as mock_fetch: @@ -536,9 +518,7 @@ def test_add_row_with_file_sync(self, client, temp_file): def test_add_row_with_mixed_values(self, client, temp_file): """Test add_row with both string and file values.""" - with patch.object( - client._file_upload_handler, "upload_file_sync" - ) as mock_upload: + with patch.object(client._file_upload_handler, "upload_file_sync") as mock_upload: mock_upload.return_value = "datasets/ws/uuid.png" with patch.object(client._http_client, "fetch_sync") as mock_fetch: @@ -577,9 +557,7 @@ def test_add_row_with_multiple_files(self, client, tmp_path): file2 = tmp_path / "file2.jpg" file2.write_bytes(b"content2") - with patch.object( - client._file_upload_handler, "upload_file_sync" - ) as mock_upload: + with patch.object(client._file_upload_handler, "upload_file_sync") as mock_upload: mock_upload.side_effect = ["datasets/ws/uuid1.png", "datasets/ws/uuid2.jpg"] with patch.object(client._http_client, "fetch_sync") as mock_fetch: @@ -611,9 +589,7 @@ def test_add_row_with_multiple_files(self, client, tmp_path): def test_add_row_file_upload_error_propagates(self, client, temp_file): """Test that file upload errors are propagated.""" - with patch.object( - client._file_upload_handler, "upload_file_sync" - ) as mock_upload: + with patch.object(client._file_upload_handler, "upload_file_sync") as mock_upload: mock_upload.side_effect = FileUploadError("S3 upload failed") with pytest.raises(FileUploadError, match="S3 upload failed"): diff --git a/tests/experiments/test_client.py b/tests/experiments/test_client.py index 917dcf7..b3b5e62 100644 --- a/tests/experiments/test_client.py +++ b/tests/experiments/test_client.py @@ -2,6 +2,7 @@ These tests follow the same pattern as the prompts and datasets tests. """ + from unittest.mock import patch import pytest @@ -40,12 +41,14 @@ def test_create_sync_success(common_client): client: ExperimentsClient = common_client["client"] with patch("basalt.experiments.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "id": "123", - "name": "My Experiment", - "featureSlug": "my-feature", - "createdAt": "2024-03-20T12:00:00Z", - }) + mock_fetch.return_value = make_response( + { + "id": "123", + "name": "My Experiment", + "featureSlug": "my-feature", + "createdAt": "2024-03-20T12:00:00Z", + } + ) experiment = client.create_sync( feature_slug="my-feature", @@ -120,12 +123,14 @@ async def test_create_async_success(common_client): client: ExperimentsClient = common_client["client"] with patch("basalt.experiments.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({ - "id": "456", - "name": "Async Experiment", - "featureSlug": "async-feature", - "createdAt": "2024-03-21T10:30:00Z", - }) + mock_fetch.return_value = make_response( + { + "id": "456", + "name": "Async Experiment", + "featureSlug": "async-feature", + "createdAt": "2024-03-21T10:30:00Z", + } + ) experiment = await client.create( feature_slug="async-feature", @@ -222,12 +227,14 @@ def test_create_sync_parameter_combinations(common_client, feature_slug, name): client: ExperimentsClient = common_client["client"] with patch("basalt.experiments.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "id": "123", - "name": name, - "featureSlug": feature_slug, - "createdAt": "2024-03-20T12:00:00Z", - }) + mock_fetch.return_value = make_response( + { + "id": "123", + "name": name, + "featureSlug": feature_slug, + "createdAt": "2024-03-20T12:00:00Z", + } + ) experiment = client.create_sync(feature_slug=feature_slug, name=name) @@ -263,10 +270,12 @@ def test_experiment_model_from_dict_with_empty_dict(): def test_experiment_model_from_dict_with_partial_data(): """Test Experiment.from_dict with partial data.""" - experiment = Experiment.from_dict({ - "id": "123", - "name": "Test", - }) + experiment = Experiment.from_dict( + { + "id": "123", + "name": "Test", + } + ) assert experiment.id == "123" assert experiment.name == "Test" @@ -276,12 +285,14 @@ def test_experiment_model_from_dict_with_partial_data(): def test_experiment_model_from_dict_with_wrong_types(): """Test Experiment.from_dict handles wrong types gracefully.""" - experiment = Experiment.from_dict({ - "id": 123, # Should be string - "name": None, # Should be string - "featureSlug": ["not", "a", "string"], # Should be string - "createdAt": True, # Should be string - }) + experiment = Experiment.from_dict( + { + "id": 123, # Should be string + "name": None, # Should be string + "featureSlug": ["not", "a", "string"], # Should be string + "createdAt": True, # Should be string + } + ) # All values should be converted to empty strings due to type checking assert experiment.id == "" diff --git a/tests/internal/test_http.py b/tests/internal/test_http.py index 2d51d1a..a286eac 100644 --- a/tests/internal/test_http.py +++ b/tests/internal/test_http.py @@ -1,4 +1,5 @@ """Tests for the HTTP client.""" + from unittest.mock import Mock, patch import httpx @@ -25,220 +26,223 @@ def create_mock_session(): class TestHTTPClient: """Test cases for the HTTPClient class.""" - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_uses_httpx_to_make_http_calls(self, request_mock): """Test that the client uses httpx library for HTTP calls.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json.return_value = {} - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") # Verify the request was called with expected parameters assert request_mock.call_count == 1 call_args = request_mock.call_args - assert call_args[0][0] == 'GET' - assert call_args[0][1] == 'http://test/abc' + assert call_args[0][0] == "GET" + assert call_args[0][1] == "http://test/abc" - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_captures_httpx_exceptions(self, request_mock): """Test that the client captures and wraps httpx exceptions.""" client = HTTPClient() - request_mock.side_effect = httpx.HTTPError('Some unknown error') + request_mock.side_effect = httpx.HTTPError("Some unknown error") with pytest.raises(NetworkError) as exc_info: - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") - assert 'Some unknown error' in str(exc_info.value.message) + assert "Some unknown error" in str(exc_info.value.message) - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_rejects_non_json_responses(self, request_mock): """Test that the client handles non-JSON responses properly.""" client = HTTPClient() request_mock.return_value = Mock() - request_mock.return_value.json.side_effect = Exception('No JSON object could be decoded') + request_mock.return_value.json.side_effect = Exception("No JSON object could be decoded") request_mock.return_value.headers = {} request_mock.return_value.status_code = 200 - request_mock.return_value.content = b'plain text' - request_mock.return_value.text = 'plain text' + request_mock.return_value.content = b"plain text" + request_mock.return_value.text = "plain text" with pytest.raises(NetworkError): - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_returns_valid_json_as_dict(self, request_mock): """Test that the client returns valid JSON responses.""" client = HTTPClient() mock_response = Mock() mock_response.json.return_value = {"some": "data"} - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.status_code = 200 mock_response.content = b'{"some":"data"}' request_mock.return_value = mock_response - result = client.fetch_sync('http://test/abc', 'GET') + result = client.fetch_sync("http://test/abc", "GET") assert isinstance(result, HTTPResponse) assert result.json() == {"some": "data"} - @pytest.mark.parametrize("response_code,error_class", [ - (400, BadRequestError), - (401, UnauthorizedError), - (403, ForbiddenError), - (404, NotFoundError), - ]) - @patch('basalt._internal.http.httpx.Client.request') + @pytest.mark.parametrize( + "response_code,error_class", + [ + (400, BadRequestError), + (401, UnauthorizedError), + (403, ForbiddenError), + (404, NotFoundError), + ], + ) + @patch("basalt._internal.http.httpx.Client.request") def test_raises_custom_errors(self, request_mock, response_code, error_class): """Test that the client raises appropriate custom errors for HTTP error codes.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = response_code - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json.return_value = {} mock_response.text = "" - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response with pytest.raises(error_class): - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_includes_error_message_from_api(self, request_mock): """Test that the client includes error messages from the API response.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 400 - mock_response.headers = {'Content-Type': 'application/json'} - mock_response.json.return_value = {'error': 'Invalid request format'} + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = {"error": "Invalid request format"} mock_response.text = "" - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response with pytest.raises(BadRequestError) as exc_info: - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") - assert exc_info.value.args[0] == 'Invalid request format' + assert exc_info.value.args[0] == "Invalid request format" assert exc_info.value.status_code == 400 - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_handles_errors_field_for_bad_request(self, request_mock): """Test that the client handles 'errors' field in bad request responses.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 400 - mock_response.headers = {'Content-Type': 'application/json'} - mock_response.json.return_value = {'errors': 'Validation failed'} + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = {"errors": "Validation failed"} mock_response.text = "" - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response with pytest.raises(BadRequestError) as exc_info: - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") - assert exc_info.value.args[0] == 'Validation failed' + assert exc_info.value.args[0] == "Validation failed" - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_handles_202_no_content(self, request_mock): """Test that the client handles 202 Accepted responses with no content.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 202 mock_response.headers = {} - mock_response.json.side_effect = Exception('No content') + mock_response.json.side_effect = Exception("No content") mock_response.content = b"" request_mock.return_value = mock_response - result = client.fetch_sync('http://test/abc', 'POST') + result = client.fetch_sync("http://test/abc", "POST") assert isinstance(result, HTTPResponse) assert result.status_code == 202 assert result.body is None - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_handles_204_no_content(self, request_mock): """Test that the client handles 204 No Content responses.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 204 mock_response.headers = {} - mock_response.json.side_effect = Exception('No content') + mock_response.json.side_effect = Exception("No content") mock_response.content = b"" request_mock.return_value = mock_response - result = client.fetch_sync('http://test/abc', 'DELETE') + result = client.fetch_sync("http://test/abc", "DELETE") assert isinstance(result, HTTPResponse) assert result.status_code == 204 assert result.body is None @pytest.mark.parametrize("method", ["GET", "POST", "PUT", "DELETE"]) - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_supports_http_methods(self, request_mock, method): """Test that the client supports various HTTP methods.""" client = HTTPClient() mock_response = Mock() mock_response.json.return_value = {} - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.status_code = 200 - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response - client.fetch_sync('http://test/abc', method) + client.fetch_sync("http://test/abc", method) call_args = request_mock.call_args[0] assert call_args[0] == method - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_passes_body_to_request(self, request_mock): """Test that the client passes request body correctly.""" client = HTTPClient() mock_response = Mock() mock_response.json.return_value = {} - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.status_code = 200 - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response - client.fetch_sync('http://test/abc', 'POST', body={"test": "data"}) + client.fetch_sync("http://test/abc", "POST", body={"test": "data"}) call_kwargs = request_mock.call_args.kwargs - assert call_kwargs['json'] == {"test": "data"} + assert call_kwargs["json"] == {"test": "data"} - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_passes_params_to_request(self, request_mock): """Test that the client passes query parameters correctly.""" client = HTTPClient() mock_response = Mock() mock_response.json.return_value = {} - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.status_code = 200 - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response - client.fetch_sync('http://test/abc', 'GET', params={"tag": "test"}) + client.fetch_sync("http://test/abc", "GET", params={"tag": "test"}) call_kwargs = request_mock.call_args.kwargs - assert call_kwargs['params'] == {"tag": "test"} + assert call_kwargs["params"] == {"tag": "test"} - @patch('basalt._internal.http.httpx.Client.request') + @patch("basalt._internal.http.httpx.Client.request") def test_passes_headers_to_request(self, request_mock): """Test that the client passes headers correctly.""" client = HTTPClient() mock_response = Mock() mock_response.json.return_value = {} - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.status_code = 200 - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response - client.fetch_sync('http://test/abc', 'GET', headers={"Authorization": "Bearer token"}) + client.fetch_sync("http://test/abc", "GET", headers={"Authorization": "Bearer token"}) call_kwargs = request_mock.call_args.kwargs - assert call_kwargs['headers'] == {"Authorization": "Bearer token"} + assert call_kwargs["headers"] == {"Authorization": "Bearer token"} def test_context_manager_sync(self): """Test that the client supports sync context manager protocol.""" @@ -268,101 +272,103 @@ def test_custom_retries(self): client = HTTPClient(max_retries=5) assert client.max_retries == 5 - @pytest.mark.parametrize("response_code,error_class", [ - (401, UnauthorizedError), - (403, ForbiddenError), - (404, NotFoundError), - (500, NetworkError), - ]) - @patch('basalt._internal.http.httpx.Client.request') - def test_extracts_error_field_from_non_200_responses(self, request_mock, response_code, error_class): + @pytest.mark.parametrize( + "response_code,error_class", + [ + (401, UnauthorizedError), + (403, ForbiddenError), + (404, NotFoundError), + (500, NetworkError), + ], + ) + @patch("basalt._internal.http.httpx.Client.request") + def test_extracts_error_field_from_non_200_responses( + self, request_mock, response_code, error_class + ): """Test that error field is extracted from all non-2xx JSON responses.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = response_code - mock_response.headers = {'Content-Type': 'application/json'} - mock_response.json.return_value = {'error': 'Custom error message from API'} + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = {"error": "Custom error message from API"} mock_response.text = "" - mock_response.content = b'{}' + mock_response.content = b"{}" request_mock.return_value = mock_response with pytest.raises(error_class) as exc_info: - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") # The error message should be extracted from the "error" field - assert exc_info.value.args[0] == 'Custom error message from API' + assert exc_info.value.args[0] == "Custom error message from API" - @patch('basalt._internal.http.httpx.Client.request') - @patch('basalt._internal.http.logger') + @patch("basalt._internal.http.httpx.Client.request") + @patch("basalt._internal.http.logger") def test_logs_warning_field_in_200_response(self, mock_logger, request_mock): """Test that warning field in 2xx responses is returned but not logged by HTTPClient.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 200 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json.return_value = { - 'data': 'some data', - 'warning': 'Deprecated API endpoint, please migrate to v2' + "data": "some data", + "warning": "Deprecated API endpoint, please migrate to v2", } mock_response.content = b'{"data":"some data"}' request_mock.return_value = mock_response - result = client.fetch_sync('http://test/abc', 'GET') + result = client.fetch_sync("http://test/abc", "GET") # Verify the warning was NOT logged by HTTPClient (it's logged by the API clients) mock_logger.warning.assert_not_called() # Result should still contain the warning in the data - assert result.json()['data'] == 'some data' - assert result.json()['warning'] == 'Deprecated API endpoint, please migrate to v2' + assert result.json()["data"] == "some data" + assert result.json()["warning"] == "Deprecated API endpoint, please migrate to v2" - @patch('basalt._internal.http.httpx.Client.request') - @patch('basalt._internal.http.logger') + @patch("basalt._internal.http.httpx.Client.request") + @patch("basalt._internal.http.logger") def test_logs_warning_field_in_201_response(self, mock_logger, request_mock): """Test that warning field in 201 Created responses is returned but not logged by HTTPClient.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 201 - mock_response.headers = {'Content-Type': 'application/json'} + mock_response.headers = {"Content-Type": "application/json"} mock_response.json.return_value = { - 'id': '123', - 'warning': 'Resource created with default values' + "id": "123", + "warning": "Resource created with default values", } mock_response.content = b'{"id":"123"}' request_mock.return_value = mock_response - result = client.fetch_sync('http://test/abc', 'POST') + result = client.fetch_sync("http://test/abc", "POST") # Verify the warning was NOT logged by HTTPClient (it's logged by the API clients) mock_logger.warning.assert_not_called() # Result should still contain the warning in the data - assert result.json()['id'] == '123' - assert result.json()['warning'] == 'Resource created with default values' + assert result.json()["id"] == "123" + assert result.json()["warning"] == "Resource created with default values" - @patch('basalt._internal.http.httpx.Client.request') - @patch('basalt._internal.http.logger') + @patch("basalt._internal.http.httpx.Client.request") + @patch("basalt._internal.http.logger") def test_no_warning_logged_when_field_absent(self, mock_logger, request_mock): """Test that no warning is logged when warning field is absent.""" client = HTTPClient() mock_response = Mock() mock_response.status_code = 200 - mock_response.headers = {'Content-Type': 'application/json'} - mock_response.json.return_value = {'data': 'some data'} + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json.return_value = {"data": "some data"} mock_response.content = b'{"data":"some data"}' request_mock.return_value = mock_response - client.fetch_sync('http://test/abc', 'GET') + client.fetch_sync("http://test/abc", "GET") # Verify no warning was logged mock_logger.warning.assert_not_called() - def test_handle_response_200_with_json(self): """Test handling of a successful 200 response with valid JSON.""" response_data = {"key": "value"} mock_response = Response( - status_code=200, - json=response_data, - headers={"Content-Type": "application/json"} + status_code=200, json=response_data, headers={"Content-Type": "application/json"} ) result = HTTPClient._handle_response(mock_response) assert isinstance(result, HTTPResponse) @@ -372,9 +378,7 @@ def test_handle_response_200_with_json(self): def test_handle_response_204_no_content(self): """Test handling of a 204 No Content response.""" mock_response = Response( - status_code=204, - content=b"", - headers={"Content-Type": "application/json"} + status_code=204, content=b"", headers={"Content-Type": "application/json"} ) result = HTTPClient._handle_response(mock_response) assert result is not None @@ -385,9 +389,7 @@ def test_handle_response_400_with_json(self): """Test handling of a 400 Bad Request response with JSON error.""" response_data = {"error": "Invalid data"} mock_response = Response( - status_code=400, - json=response_data, - headers={"Content-Type": "application/json"} + status_code=400, json=response_data, headers={"Content-Type": "application/json"} ) with pytest.raises(BadRequestError) as exc_info: HTTPClient._handle_response(mock_response) @@ -396,9 +398,7 @@ def test_handle_response_400_with_json(self): def test_handle_response_500_with_plain_text(self): """Test handling of a 500 Internal Server Error response with plain text.""" mock_response = Response( - status_code=500, - content=b"Internal Error", - headers={"Content-Type": "text/plain"} + status_code=500, content=b"Internal Error", headers={"Content-Type": "text/plain"} ) with pytest.raises(NetworkError) as exc_info: HTTPClient._handle_response(mock_response) @@ -407,9 +407,7 @@ def test_handle_response_500_with_plain_text(self): def test_handle_response_invalid_json(self): """Test handling of a response with invalid JSON when JSON is expected.""" mock_response = Response( - status_code=200, - content=b"invalid-json", - headers={"Content-Type": "application/json"} + status_code=200, content=b"invalid-json", headers={"Content-Type": "application/json"} ) with pytest.raises(NetworkError) as exc_info: HTTPClient._handle_response(mock_response) diff --git a/tests/observability/test_api.py b/tests/observability/test_api.py index 1c307fa..2233e6c 100644 --- a/tests/observability/test_api.py +++ b/tests/observability/test_api.py @@ -76,8 +76,10 @@ def test_get_config_for_kind_default(): assert result[2] is None assert result[3] is None + def test_call_decorator_sync_function(): """Test using __call__ as a decorator on a synchronous function.""" + @Observe(name="test_function", kind=ObserveKind.FUNCTION, metadata={"key": "value"}) def sample_function(x, y): return x + y @@ -89,6 +91,7 @@ def sample_function(x, y): @pytest.mark.asyncio async def test_call_decorator_async_function(): """Test using __call__ as a decorator on an asynchronous function.""" + @Observe(name="async_test_function", kind=ObserveKind.FUNCTION, metadata={"key": "value"}) async def async_sample_function(x, y): return x * y @@ -99,6 +102,7 @@ async def async_sample_function(x, y): def test_call_decorator_handling_exceptions(): """Test that __call__ as a decorator appropriately handles exceptions.""" + @Observe(name="exception_test_function", kind=ObserveKind.FUNCTION) def error_function(x): raise ValueError("Intentional error") @@ -106,6 +110,7 @@ def error_function(x): with pytest.raises(ValueError, match="Intentional error"): error_function(10) + def test_observe_as_decorator(): """Test Observe when used as a decorator for a synchronous function.""" @@ -172,7 +177,6 @@ def test_observe_static_metadata(): assert parsed_metadata["key2"] == "value2" - def test_observe_decorator_sync_function(): """Test Observe decorator on a synchronous function.""" observed_data = {} @@ -186,6 +190,7 @@ def test_function(x, y): assert result == 7 assert observed_data["input"] == (3, 4) + def test_observe_context_manager(): """Test Observe used as a context manager.""" with Observe(name="Test Context Manager", metadata={"key": "value"}) as span: @@ -193,6 +198,7 @@ def test_observe_context_manager(): span.set_output("test_output") assert span is not None + def test_observe_with_metadata(): """Test Observe with metadata added.""" observed_metadata = {} @@ -207,9 +213,11 @@ def test_function(): assert result is True assert observed_metadata["user"] == "test_user" + @pytest.mark.asyncio async def test_observe_decorator_async_function(): """Test Observe decorator on an async function.""" + @Observe(name="Test Async Function", kind=ObserveKind.FUNCTION) async def test_async_function(x, y): return x * y @@ -221,6 +229,7 @@ async def test_async_function(x, y): @pytest.mark.asyncio async def test_start_observe_decorator_async_function(setup_tracing): """Test StartObserve decorator on an async function.""" + @StartObserve(name="test_async_decorated_root", feature_slug="test_decorator") async def async_root_function(x: int, y: int) -> int: """An async function that adds two numbers.""" @@ -290,7 +299,7 @@ class MockPrompt: frequency_penalty=None, presence_penalty=None, json_object=None, - ) + ), ) mock_prompt = MockPrompt( @@ -299,7 +308,7 @@ class MockPrompt: raw_text="Hello, {{name}}!", version="1.0.0", model=mock_model, - variables={"name": "world"} + variables={"name": "world"}, ) @Observe(kind=ObserveKind.GENERATION, name="test_with_prompt", prompt=mock_prompt) @@ -320,6 +329,7 @@ def generate_text(): assert span.attributes.get("basalt.prompt.model.model") == "gpt-4" # Variables are stored as JSON string for OpenTelemetry compatibility import json + assert json.loads(span.attributes.get("basalt.prompt.variables")) == {"name": "world"} @@ -360,7 +370,7 @@ class MockPrompt: frequency_penalty=None, presence_penalty=None, json_object=None, - ) + ), ) mock_prompt = MockPrompt( @@ -369,10 +379,12 @@ class MockPrompt: raw_text="Context manager {{test}}", version="2.0.0", model=mock_model, - variables={"test": "value"} + variables={"test": "value"}, ) - with Observe(kind=ObserveKind.GENERATION, name="test_context_with_prompt", prompt=mock_prompt) as span: + with Observe( + kind=ObserveKind.GENERATION, name="test_context_with_prompt", prompt=mock_prompt + ) as span: pass # Verify span attributes contain prompt metadata @@ -386,6 +398,7 @@ class MockPrompt: assert span.attributes.get("basalt.prompt.model.model") == "claude-3-opus" # Variables are stored as JSON string for OpenTelemetry compatibility import json + assert json.loads(span.attributes.get("basalt.prompt.variables")) == {"test": "value"} @@ -426,7 +439,7 @@ class MockPrompt: frequency_penalty=None, presence_penalty=None, json_object=None, - ) + ), ) mock_prompt = MockPrompt( @@ -435,7 +448,7 @@ class MockPrompt: raw_text="Simple prompt", version="1.0.0", model=mock_model, - variables=None + variables=None, ) @Observe(kind=ObserveKind.GENERATION, name="test_no_vars", prompt=mock_prompt) diff --git a/tests/observability/test_config.py b/tests/observability/test_config.py index 386e5dd..5a997ee 100644 --- a/tests/observability/test_config.py +++ b/tests/observability/test_config.py @@ -2,43 +2,34 @@ from __future__ import annotations -import os -import unittest -from unittest import mock - from basalt.observability.config import TelemetryConfig -class TestTelemetryConfig(unittest.TestCase): - def test_env_overrides_respect_environment_variables(self): - with mock.patch.dict( - os.environ, - { - "BASALT_TELEMETRY_ENABLED": "0", - "BASALT_SERVICE_NAME": "env-service", - "BASALT_ENVIRONMENT": "staging", - }, - clear=False, - ): - config = TelemetryConfig(service_name="sdk").with_env_overrides() - - self.assertFalse(config.enabled) - self.assertEqual(config.service_name, "env-service") - self.assertEqual(config.environment, "staging") - - def test_clone_returns_independent_copy_of_provider_lists(self): - """Test that cloning creates independent copies of provider lists.""" - original = TelemetryConfig( - enabled_providers=["openai", "anthropic"], - disabled_providers=["langchain"], - ) - - clone = original.clone() - if clone.enabled_providers: - clone.enabled_providers.append("cohere") - if clone.disabled_providers: - clone.disabled_providers.append("llamaindex") - - # Original should be unchanged - self.assertEqual(original.enabled_providers, ["openai", "anthropic"]) - self.assertEqual(original.disabled_providers, ["langchain"]) +def test_env_overrides_respect_environment_variables(monkeypatch): + monkeypatch.setenv("BASALT_TELEMETRY_ENABLED", "0") + monkeypatch.setenv("BASALT_SERVICE_NAME", "env-service") + monkeypatch.setenv("BASALT_ENVIRONMENT", "staging") + + config = TelemetryConfig(service_name="sdk").with_env_overrides() + + assert not config.enabled + assert config.service_name == "env-service" + assert config.environment == "staging" + + +def test_clone_returns_independent_copy_of_provider_lists(): + """Test that cloning creates independent copies of provider lists.""" + original = TelemetryConfig( + enabled_providers=["openai", "anthropic"], + disabled_providers=["langchain"], + ) + + clone = original.clone() + if clone.enabled_providers: + clone.enabled_providers.append("cohere") + if clone.disabled_providers: + clone.disabled_providers.append("llamaindex") + + # Original should be unchanged + assert original.enabled_providers == ["openai", "anthropic"] + assert original.disabled_providers == ["langchain"] diff --git a/tests/observability/test_context_managers.py b/tests/observability/test_context_managers.py index 9d35b38..a0b88fd 100644 --- a/tests/observability/test_context_managers.py +++ b/tests/observability/test_context_managers.py @@ -56,7 +56,6 @@ def test_normalize_evaluator_entry_with_existing_evaluator_attachment(): assert result == entry - def test_with_evaluators_no_values(): """Test with_evaluators does not set any contexts when given no values.""" token = otel_context.attach(otel_context.set_value("test_key", "initial_value")) @@ -156,7 +155,8 @@ def test_set_io_with_no_arguments(): class SimpleSpanMock: """Minimal mock for testing span attribute setting.""" - def __init__(self): + + def __init__(self) -> None: self.attributes = {} def set_attribute(self, key, value): @@ -203,6 +203,7 @@ def mock_span(): span = MagicMock(spec=Span) return span + def test_set_attribute(mock_span): """Test SpanHandle.set_attribute sets attributes on the span.""" span_handle = SpanHandle(span=mock_span) @@ -214,8 +215,7 @@ def test_set_input(mock_span): """Test SpanHandle.set_input sets input payload and serializes it if tracing is enabled.""" with pytest.MonkeyPatch().context() as monkeypatch: monkeypatch.setattr( - "basalt.observability.context_managers.trace_content_enabled", - lambda: True + "basalt.observability.context_managers.trace_content_enabled", lambda: True ) span_handle = SpanHandle(span=mock_span) payload = {"key": "value"} @@ -223,12 +223,12 @@ def test_set_input(mock_span): assert span_handle._io_payload["input"] == payload mock_span.set_attribute.assert_called_once() + def test_set_output(mock_span): """Test SpanHandle.set_output sets output payload and serializes it if tracing is enabled.""" with pytest.MonkeyPatch().context() as monkeypatch: monkeypatch.setattr( - "basalt.observability.context_managers.trace_content_enabled", - lambda: True + "basalt.observability.context_managers.trace_content_enabled", lambda: True ) span_handle = SpanHandle(span=mock_span) payload = {"result": "success"} @@ -236,26 +236,25 @@ def test_set_output(mock_span): assert span_handle._io_payload["output"] == payload mock_span.set_attribute.assert_called_once() + def test_set_io(mock_span): """Test SpanHandle.set_io sets all I/O payloads correctly.""" with pytest.MonkeyPatch().context() as monkeypatch: monkeypatch.setattr( - "basalt.observability.context_managers.trace_content_enabled", - lambda: False + "basalt.observability.context_managers.trace_content_enabled", lambda: False ) span_handle = SpanHandle(span=mock_span) input_payload = {"input": "data"} output_payload = {"output": "data"} variables = {"key": "value"} span_handle.set_io( - input_payload=input_payload, - output_payload=output_payload, - variables=variables + input_payload=input_payload, output_payload=output_payload, variables=variables ) assert span_handle._io_payload["input"] == input_payload assert span_handle._io_payload["output"] == output_payload assert span_handle._io_payload["variables"] == variables + def test_io_snapshot(mock_span): """Test SpanHandle.io_snapshot returns a copy of the I/O payload.""" span_handle = SpanHandle(span=mock_span) @@ -307,10 +306,12 @@ def test_identify_with_both_user_and_organization(): mock_span = SimpleSpanMock() span_handle = SpanHandle(span=mock_span) - span_handle.set_identity({ - "user": {"id": "user-123", "name": "John Doe"}, - "organization": {"id": "org-456", "name": "Acme Corp"} - }) + span_handle.set_identity( + { + "user": {"id": "user-123", "name": "John Doe"}, + "organization": {"id": "org-456", "name": "Acme Corp"}, + } + ) assert mock_span.attributes[semconv.BasaltUser.ID] == "user-123" assert mock_span.attributes[semconv.BasaltUser.NAME] == "John Doe" @@ -324,10 +325,7 @@ def test_identify_with_ids_only(): mock_span = SimpleSpanMock() span_handle = SpanHandle(span=mock_span) - span_handle.set_identity({ - "user": {"id": "user-789"}, - "organization": {"id": "org-101"} - }) + span_handle.set_identity({"user": {"id": "user-789"}, "organization": {"id": "org-101"}}) assert mock_span.attributes[semconv.BasaltUser.ID] == "user-789" assert mock_span.attributes[semconv.BasaltOrganization.ID] == "org-101" @@ -424,7 +422,9 @@ async def test_async_start_observe_with_metadata(): exporter.clear() metadata = {"custom_key": "custom_value", "test_id": 42} - async with AsyncStartObserve(name="test_with_metadata", feature_slug="test_metadata", metadata=metadata) as span: + async with AsyncStartObserve( + name="test_with_metadata", feature_slug="test_metadata", metadata=metadata + ) as span: # Verify span was created successfully assert span is not None assert isinstance(span._span, object) # Span exists @@ -449,9 +449,7 @@ async def test_async_observe_with_evaluators(): async with AsyncStartObserve(name="test_root", feature_slug="test_evaluators"): evaluator_slugs = ["test_evaluator_1", "test_evaluator_2"] async with AsyncObserve( - name="test_with_evaluators", - kind=ObserveKind.GENERATION, - evaluators=evaluator_slugs + name="test_with_evaluators", kind=ObserveKind.GENERATION, evaluators=evaluator_slugs ) as span: # The evaluators should be attached to the span # We can verify by checking the context or span attributes @@ -512,10 +510,12 @@ async def test_async_start_observe_with_identity(): identity = { "user": {"id": "user-123", "name": "Test User"}, - "organization": {"id": "org-456", "name": "Test Org"} + "organization": {"id": "org-456", "name": "Test Org"}, } - async with AsyncStartObserve(name="test_identity", feature_slug="test_identity", identity=identity) as span: + async with AsyncStartObserve( + name="test_identity", feature_slug="test_identity", identity=identity + ) as span: # Verify span and identity setup completed successfully assert span is not None @@ -572,4 +572,3 @@ async def test_async_observe_has_in_trace_attribute(setup_tracing): assert root_span._span.attributes.get("basalt.in_trace") is True async with AsyncObserve(kind=ObserveKind.GENERATION, name="async_child") as child_span: assert child_span._span.attributes.get("basalt.in_trace") is True - diff --git a/tests/observability/test_decorators.py b/tests/observability/test_decorators.py index 0023d5c..b945ee9 100644 --- a/tests/observability/test_decorators.py +++ b/tests/observability/test_decorators.py @@ -9,6 +9,7 @@ def test_evaluate_decorator_single_slug(): """Test the evaluate decorator with a single evaluator slug.""" + @evaluate("test-slug") def test_function(): return "executed" @@ -19,6 +20,7 @@ def test_function(): def test_evaluate_decorator_multiple_slugs(): """Test the evaluate decorator with multiple evaluator slugs.""" + @evaluate(["slug1", "slug2"]) def another_function(): return "success" @@ -38,6 +40,7 @@ def empty_slugs_function(): def test_evaluate_with_metadata_callable(): """Test the evaluate decorator with callable metadata.""" + def metadata_resolver(param): return {"key": param} @@ -45,15 +48,20 @@ def metadata_resolver(param): def function_with_metadata(param): return f"Metadata resolved for {param}" - with patch("basalt.observability.decorators.with_evaluators", wraps=with_evaluators) as mock_with_evaluators: + with patch( + "basalt.observability.decorators.with_evaluators", wraps=with_evaluators + ) as mock_with_evaluators: result = function_with_metadata("test-param") - assert result == "Metadata resolved for test-param", "Function should return a correctly formatted string." + assert result == "Metadata resolved for test-param", ( + "Function should return a correctly formatted string." + ) mock_with_evaluators.assert_called_once() def test_evaluate_asynchronous_function(): """Test that the evaluate decorator works for async functions.""" + @evaluate("async-slug") async def async_function(): return "async executed" diff --git a/tests/observability/test_evaluators.py b/tests/observability/test_evaluators.py index bdedf7d..3d8cba3 100644 --- a/tests/observability/test_evaluators.py +++ b/tests/observability/test_evaluators.py @@ -4,7 +4,7 @@ class DummyAttachment: - def __init__(self, slug): + def __init__(self, slug) -> None: self.slug = slug diff --git a/tests/observability/test_instrumentation.py b/tests/observability/test_instrumentation.py index 554725e..f6eab4b 100644 --- a/tests/observability/test_instrumentation.py +++ b/tests/observability/test_instrumentation.py @@ -3,7 +3,6 @@ from __future__ import annotations import os -import unittest from unittest import mock from basalt.observability.config import TelemetryConfig @@ -11,245 +10,257 @@ from basalt.observability.resilient_exporters import ResilientSpanExporter -class TestInstrumentationManager(unittest.TestCase): - @mock.patch("basalt.observability.instrumentation.setup_tracing") - def test_initialize_disabled_skips_tracing(self, mock_setup): - manager = InstrumentationManager() +@mock.patch("basalt.observability.instrumentation.setup_tracing") +def test_initialize_disabled_skips_tracing(mock_setup): + manager = InstrumentationManager() - manager.initialize(TelemetryConfig(enabled=False)) + manager.initialize(TelemetryConfig(enabled=False)) - mock_setup.assert_not_called() + mock_setup.assert_not_called() - @mock.patch("basalt.observability.instrumentation.setup_tracing") - def test_initialize_enables_tracing(self, mock_setup): - mock_setup.return_value = mock.Mock() - manager = InstrumentationManager() - manager.initialize(TelemetryConfig(service_name="svc")) +@mock.patch("basalt.observability.instrumentation.setup_tracing") +def test_initialize_enables_tracing(mock_setup): + mock_setup.return_value = mock.Mock() + manager = InstrumentationManager() - mock_setup.assert_called_once() + manager.initialize(TelemetryConfig(service_name="svc")) - @mock.patch.object(InstrumentationManager, "_uninstrument_providers") - def test_shutdown_flushes_provider(self, mock_uninstrument): - manager = InstrumentationManager() - manager._initialized = True - provider = mock.Mock() - manager._tracer_provider = provider + mock_setup.assert_called_once() - manager.shutdown() - mock_uninstrument.assert_called_once() - provider.force_flush.assert_called_once() - provider.shutdown.assert_called_once() - self.assertFalse(manager._initialized) +@mock.patch.object(InstrumentationManager, "_uninstrument_providers") +def test_shutdown_flushes_provider(mock_uninstrument): + manager = InstrumentationManager() + manager._initialized = True + provider = mock.Mock() + manager._tracer_provider = provider - @mock.patch.object(InstrumentationManager, "_instrument_providers") - def test_initialize_instrumentation_sets_env_and_instruments_providers(self, mock_providers): - config = TelemetryConfig( - trace_content=False, - enabled_providers=["openai", "anthropic"], - ) - manager = InstrumentationManager() + manager.shutdown() - with mock.patch.dict(os.environ, {}, clear=False): - manager._initialize_instrumentation(config) - self.assertEqual(os.environ["TRACELOOP_TRACE_CONTENT"], "false") + mock_uninstrument.assert_called_once() + provider.force_flush.assert_called_once() + provider.shutdown.assert_called_once() + assert not manager._initialized - mock_providers.assert_called_once_with(config) - @mock.patch.dict( - os.environ, - {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4317"}, - clear=False, +@mock.patch.object(InstrumentationManager, "_instrument_providers") +def test_initialize_instrumentation_sets_env_and_instruments_providers(mock_providers): + config = TelemetryConfig( + trace_content=False, + enabled_providers=["openai", "anthropic"], ) - @mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") - def test_build_exporter_from_env_adds_bearer_for_grpc(self, mock_grpc_exporter): - mock_grpc_exporter.return_value = mock.Mock() - manager = InstrumentationManager() - manager._resolve_api_key("test-key") - - exporter = manager._build_exporter_from_env() - - self.assertIs(exporter, mock_grpc_exporter.return_value) - mock_grpc_exporter.assert_called_once() - headers = mock_grpc_exporter.call_args.kwargs["headers"] - self.assertEqual(headers["authorization"], "Bearer test-key") - - @mock.patch.dict( - os.environ, - { - "BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "https://collector/v1/traces", - "BASALT_API_KEY": "env-key", - }, - clear=False, + manager = InstrumentationManager() + + with mock.patch.dict(os.environ, {}, clear=False): + manager._initialize_instrumentation(config) + assert os.environ["TRACELOOP_TRACE_CONTENT"] == "false" + + mock_providers.assert_called_once_with(config) + + +@mock.patch.dict( + os.environ, + {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4317"}, + clear=False, +) +@mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") +def test_build_exporter_from_env_adds_bearer_for_grpc(mock_grpc_exporter): + mock_grpc_exporter.return_value = mock.Mock() + manager = InstrumentationManager() + manager._resolve_api_key("test-key") + + exporter = manager._build_exporter_from_env() + + assert exporter is mock_grpc_exporter.return_value + mock_grpc_exporter.assert_called_once() + headers = mock_grpc_exporter.call_args.kwargs["headers"] + assert headers["authorization"] == "Bearer test-key" + + +@mock.patch.dict( + os.environ, + { + "BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "https://collector/v1/traces", + "BASALT_API_KEY": "env-key", + }, + clear=False, +) +@mock.patch("basalt.observability.instrumentation.OTLPHTTPSpanExporter") +@mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") +def test_build_exporter_from_env_adds_bearer_for_http( + mock_grpc_exporter, + mock_http_exporter, +): + mock_http_exporter.return_value = mock.Mock() + manager = InstrumentationManager() + + exporter = manager._build_exporter_from_env() + + # HTTP exporter is now wrapped in ResilientSpanExporter + assert isinstance(exporter, ResilientSpanExporter) + assert exporter._exporter is mock_http_exporter.return_value + mock_http_exporter.assert_called_once() + mock_grpc_exporter.assert_not_called() + headers = mock_http_exporter.call_args.kwargs["headers"] + assert headers["authorization"] == "Bearer env-key" + + +def test_should_instrument_provider_default(): + """Test that by default all providers are instrumented.""" + config = TelemetryConfig() + assert config.should_instrument_provider("openai") + assert config.should_instrument_provider("anthropic") + assert config.should_instrument_provider("langchain") + + +def test_should_instrument_provider_with_enabled_list(): + """Test that only enabled providers are instrumented when specified.""" + config = TelemetryConfig(enabled_providers=["openai", "anthropic"]) + assert config.should_instrument_provider("openai") + assert config.should_instrument_provider("anthropic") + assert not config.should_instrument_provider("langchain") + assert not config.should_instrument_provider("llamaindex") + + +def test_should_instrument_provider_with_disabled_list(): + """Test that disabled providers are not instrumented.""" + config = TelemetryConfig(disabled_providers=["langchain", "llamaindex"]) + assert config.should_instrument_provider("openai") + assert config.should_instrument_provider("anthropic") + assert not config.should_instrument_provider("langchain") + assert not config.should_instrument_provider("llamaindex") + + +def test_should_instrument_provider_disabled_takes_precedence(): + """Test that disabled list takes precedence over enabled list.""" + config = TelemetryConfig( + enabled_providers=["openai", "anthropic"], + disabled_providers=["anthropic"], ) - @mock.patch("basalt.observability.instrumentation.OTLPHTTPSpanExporter") - @mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") - def test_build_exporter_from_env_adds_bearer_for_http( - self, - mock_grpc_exporter, - mock_http_exporter, - ): - mock_http_exporter.return_value = mock.Mock() - manager = InstrumentationManager() - - exporter = manager._build_exporter_from_env() - - # HTTP exporter is now wrapped in ResilientSpanExporter - self.assertIsInstance(exporter, ResilientSpanExporter) - self.assertIs(exporter._exporter, mock_http_exporter.return_value) - mock_http_exporter.assert_called_once() - mock_grpc_exporter.assert_not_called() - headers = mock_http_exporter.call_args.kwargs["headers"] - self.assertEqual(headers["authorization"], "Bearer env-key") - - def test_should_instrument_provider_default(self): - """Test that by default all providers are instrumented.""" - config = TelemetryConfig() - self.assertTrue(config.should_instrument_provider("openai")) - self.assertTrue(config.should_instrument_provider("anthropic")) - self.assertTrue(config.should_instrument_provider("langchain")) - - def test_should_instrument_provider_with_enabled_list(self): - """Test that only enabled providers are instrumented when specified.""" - config = TelemetryConfig(enabled_providers=["openai", "anthropic"]) - self.assertTrue(config.should_instrument_provider("openai")) - self.assertTrue(config.should_instrument_provider("anthropic")) - self.assertFalse(config.should_instrument_provider("langchain")) - self.assertFalse(config.should_instrument_provider("llamaindex")) - - def test_should_instrument_provider_with_disabled_list(self): - """Test that disabled providers are not instrumented.""" - config = TelemetryConfig(disabled_providers=["langchain", "llamaindex"]) - self.assertTrue(config.should_instrument_provider("openai")) - self.assertTrue(config.should_instrument_provider("anthropic")) - self.assertFalse(config.should_instrument_provider("langchain")) - self.assertFalse(config.should_instrument_provider("llamaindex")) - - def test_should_instrument_provider_disabled_takes_precedence(self): - """Test that disabled list takes precedence over enabled list.""" - config = TelemetryConfig( - enabled_providers=["openai", "anthropic"], - disabled_providers=["anthropic"], - ) - self.assertTrue(config.should_instrument_provider("openai")) - self.assertFalse(config.should_instrument_provider("anthropic")) - - @mock.patch("basalt.observability.instrumentation._safe_import") - def test_instrument_providers_respects_config(self, mock_import): - """Test that _instrument_providers respects the configuration.""" - # Mock instrumentor with is_instrumented_by_opentelemetry property - mock_instrumentor = mock.Mock() - mock_instrumentor.is_instrumented_by_opentelemetry = False # Not yet instrumented - mock_instrumentor_cls = mock.Mock(return_value=mock_instrumentor) - - # Only return instrumentor class for openai and anthropic - def safe_import_side_effect(module, name): - if "openai" in module: - return mock_instrumentor_cls - elif "anthropic" in module: - return mock_instrumentor_cls - return None - - mock_import.side_effect = safe_import_side_effect - - config = TelemetryConfig(enabled_providers=["openai", "anthropic"]) - manager = InstrumentationManager() - manager._instrument_providers(config) - - # Should have called instrument() for both providers - self.assertEqual(mock_instrumentor.instrument.call_count, 2) - self.assertIn("openai", manager._provider_instrumentors) - self.assertIn("anthropic", manager._provider_instrumentors) - - @mock.patch.dict( - os.environ, - {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "https://bad-endpoint.invalid/v1/traces"}, - clear=False, - ) - @mock.patch("basalt.observability.instrumentation.OTLPHTTPSpanExporter") - def test_http_exporter_wrapped_in_resilient_wrapper(self, mock_http_exporter): - """Verify HTTP exporters are wrapped for error resilience.""" - mock_http_exporter_instance = mock.Mock() - mock_http_exporter.return_value = mock_http_exporter_instance - manager = InstrumentationManager() - - exporter = manager._build_exporter_from_env() - - # Should be wrapped in ResilientSpanExporter - self.assertIsInstance(exporter, ResilientSpanExporter) - # Underlying exporter should be the HTTP exporter instance - self.assertIs(exporter._exporter, mock_http_exporter_instance) - - @mock.patch.dict( - os.environ, - {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4317"}, - clear=False, - ) - @mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") - def test_grpc_exporter_not_wrapped(self, mock_grpc_exporter): - """Verify gRPC exporters are NOT wrapped (they handle errors internally).""" - mock_grpc_exporter_instance = mock.Mock() - mock_grpc_exporter.return_value = mock_grpc_exporter_instance - manager = InstrumentationManager() + assert config.should_instrument_provider("openai") + assert not config.should_instrument_provider("anthropic") + + +@mock.patch("basalt.observability.instrumentation._safe_import") +def test_instrument_providers_respects_config(mock_import): + """Test that _instrument_providers respects the configuration.""" + # Mock instrumentor with is_instrumented_by_opentelemetry property + mock_instrumentor = mock.Mock() + mock_instrumentor.is_instrumented_by_opentelemetry = False # Not yet instrumented + mock_instrumentor_cls = mock.Mock(return_value=mock_instrumentor) + + # Only return instrumentor class for openai and anthropic + def safe_import_side_effect(module, name): + if "openai" in module: + return mock_instrumentor_cls + if "anthropic" in module: + return mock_instrumentor_cls + return None + + mock_import.side_effect = safe_import_side_effect + + config = TelemetryConfig(enabled_providers=["openai", "anthropic"]) + manager = InstrumentationManager() + manager._instrument_providers(config) + + # Should have called instrument() for both providers + assert mock_instrumentor.instrument.call_count == 2 + assert "openai" in manager._provider_instrumentors + assert "anthropic" in manager._provider_instrumentors + + +@mock.patch.dict( + os.environ, + {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "https://bad-endpoint.invalid/v1/traces"}, + clear=False, +) +@mock.patch("basalt.observability.instrumentation.OTLPHTTPSpanExporter") +def test_http_exporter_wrapped_in_resilient_wrapper(mock_http_exporter): + """Verify HTTP exporters are wrapped for error resilience.""" + mock_http_exporter_instance = mock.Mock() + mock_http_exporter.return_value = mock_http_exporter_instance + manager = InstrumentationManager() + + exporter = manager._build_exporter_from_env() + + # Should be wrapped in ResilientSpanExporter + assert isinstance(exporter, ResilientSpanExporter) + # Underlying exporter should be the HTTP exporter instance + assert exporter._exporter is mock_http_exporter_instance + + +@mock.patch.dict( + os.environ, + {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4317"}, + clear=False, +) +@mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") +def test_grpc_exporter_not_wrapped(mock_grpc_exporter): + """Verify gRPC exporters are NOT wrapped (they handle errors internally).""" + mock_grpc_exporter_instance = mock.Mock() + mock_grpc_exporter.return_value = mock_grpc_exporter_instance + manager = InstrumentationManager() + + exporter = manager._build_exporter_from_env() - exporter = manager._build_exporter_from_env() + # Should NOT be wrapped, should be the gRPC exporter directly + assert not isinstance(exporter, ResilientSpanExporter) + assert exporter is mock_grpc_exporter_instance - # Should NOT be wrapped, should be the gRPC exporter directly - self.assertNotIsInstance(exporter, ResilientSpanExporter) - self.assertIs(exporter, mock_grpc_exporter_instance) - @mock.patch("basalt.observability.instrumentation.trace") - def test_install_processors_on_existing_provider(self, mock_trace): - """Test that Basalt processors are installed on an existing TracerProvider (e.g., Datadog).""" - from opentelemetry.sdk.trace import TracerProvider +@mock.patch("basalt.observability.instrumentation.trace") +def test_install_processors_on_existing_provider(mock_trace): + """Test that Basalt processors are installed on an existing TracerProvider (e.g., Datadog).""" + from opentelemetry.sdk.trace import TracerProvider - # Simulate an external tool (like Datadog) creating a provider first - external_provider = TracerProvider() - mock_trace.get_tracer_provider.return_value = external_provider - mock_trace.set_tracer_provider = mock.Mock() # Should not be called + # Simulate an external tool (like Datadog) creating a provider first + external_provider = TracerProvider() + mock_trace.get_tracer_provider.return_value = external_provider + mock_trace.set_tracer_provider = mock.Mock() # Should not be called - manager = InstrumentationManager() - config = TelemetryConfig(service_name="test", enabled=True) + manager = InstrumentationManager() + config = TelemetryConfig(service_name="test", enabled=True) - manager.initialize(config) + manager.initialize(config) - # Verify that setup_tracing reused the existing provider - mock_trace.set_tracer_provider.assert_not_called() + # Verify that setup_tracing reused the existing provider + mock_trace.set_tracer_provider.assert_not_called() - # Verify that Basalt processors were installed on the external provider - self.assertTrue(hasattr(external_provider, "_basalt_processors_installed")) - self.assertTrue(external_provider._basalt_processors_installed) + # Verify that Basalt processors were installed on the external provider + assert hasattr(external_provider, "_basalt_processors_installed") + assert external_provider._basalt_processors_installed - # Verify that the manager has references to the processors - # 4 processors: BasaltContextProcessor, BasaltCallEvaluatorProcessor, - # BasaltShouldEvaluateProcessor, BasaltAutoInstrumentationProcessor - self.assertEqual(len(manager._span_processors), 4) + # Verify that the manager has references to the processors + # 4 processors: BasaltContextProcessor, BasaltCallEvaluatorProcessor, + # BasaltShouldEvaluateProcessor, BasaltAutoInstrumentationProcessor + assert len(manager._span_processors) == 4 - # Verify that the manager stored the external provider - self.assertIs(manager._tracer_provider, external_provider) + # Verify that the manager stored the external provider + assert manager._tracer_provider is external_provider - @mock.patch("basalt.observability.instrumentation.trace") - def test_processors_not_installed_twice_on_same_provider(self, mock_trace): - """Test that Basalt processors are not installed twice on the same provider.""" - from opentelemetry.sdk.trace import TracerProvider - external_provider = TracerProvider() - mock_trace.get_tracer_provider.return_value = external_provider +@mock.patch("basalt.observability.instrumentation.trace") +def test_processors_not_installed_twice_on_same_provider(mock_trace): + """Test that Basalt processors are not installed twice on the same provider.""" + from opentelemetry.sdk.trace import TracerProvider + + external_provider = TracerProvider() + mock_trace.get_tracer_provider.return_value = external_provider - manager1 = InstrumentationManager() - manager2 = InstrumentationManager() + manager1 = InstrumentationManager() + manager2 = InstrumentationManager() - config = TelemetryConfig(service_name="test", enabled=True) + config = TelemetryConfig(service_name="test", enabled=True) - # First initialization should install processors - manager1.initialize(config) - processor_count_after_first = len(external_provider._active_span_processor._span_processors) + # First initialization should install processors + manager1.initialize(config) + processor_count_after_first = len(external_provider._active_span_processor._span_processors) - # Second initialization should NOT add processors again (idempotent) - manager2.initialize(config) - processor_count_after_second = len(external_provider._active_span_processor._span_processors) + # Second initialization should NOT add processors again (idempotent) + manager2.initialize(config) + processor_count_after_second = len(external_provider._active_span_processor._span_processors) - # Verify processors were only added once - self.assertEqual(processor_count_after_first, processor_count_after_second) - self.assertTrue(external_provider._basalt_processors_installed) + # Verify processors were only added once + assert processor_count_after_first == processor_count_after_second + assert external_provider._basalt_processors_installed diff --git a/tests/observability/test_multi_exporters.py b/tests/observability/test_multi_exporters.py index fb84cb8..c724987 100644 --- a/tests/observability/test_multi_exporters.py +++ b/tests/observability/test_multi_exporters.py @@ -1,8 +1,10 @@ """Tests for multiple span exporters functionality.""" -import unittest +from __future__ import annotations + from unittest import mock +import pytest from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ConsoleSpanExporter @@ -16,286 +18,281 @@ ) -class TestMultipleExporters(unittest.TestCase): - """Test cases for multiple span exporters support.""" - - @classmethod - def setUpClass(cls): - """Save original global state before any tests run.""" - cls._original_provider = trace.get_tracer_provider() - cls._original_once = trace._TRACER_PROVIDER_SET_ONCE - - @classmethod - def tearDownClass(cls): - """Restore original global state after all tests complete.""" - # Restore the Once flag so other test files can set providers - trace._TRACER_PROVIDER_SET_ONCE = cls._original_once - # If there was a real provider originally, try to restore it - if cls._original_provider and not isinstance(cls._original_provider, trace.ProxyTracerProvider): - trace._TRACER_PROVIDER = None - trace._TRACER_PROVIDER_SET_ONCE = trace.Once() - try: - trace.set_tracer_provider(cls._original_provider) - except Exception: - pass - - def setUp(self): - """Reset provider state before each test so we can set new ones.""" - # Allow each test to set its own provider by resetting the Once flag +@pytest.fixture(scope="module", autouse=True) +def restore_trace_provider(): + """Restore the original global tracer provider after this module.""" + original_provider = trace.get_tracer_provider() + original_once = trace._TRACER_PROVIDER_SET_ONCE + + yield + + trace._TRACER_PROVIDER_SET_ONCE = original_once + if original_provider and not isinstance(original_provider, trace.ProxyTracerProvider): trace._TRACER_PROVIDER = None trace._TRACER_PROVIDER_SET_ONCE = trace.Once() + try: + trace.set_tracer_provider(original_provider) + except Exception: + pass + + +@pytest.fixture(autouse=True) +def reset_trace_provider(): + """Reset provider state before each test so we can set new ones.""" + trace._TRACER_PROVIDER = None + trace._TRACER_PROVIDER_SET_ONCE = trace.Once() + yield + + +def test_single_exporter_backward_compatibility(): + """Test that single exporter still works (backward compatibility).""" + exporter = InMemorySpanExporter() + config = BasaltConfig(service_name="test-service") + + provider = create_tracer_provider(config, exporter=exporter) + + # Verify provider was created + assert isinstance(provider, TracerProvider) + # Verify exporter was added (check _active_span_processor has processors) + assert len(provider._active_span_processor._span_processors) > 0 + + +def test_multiple_exporters_list(): + """Test configuring with list of 2 exporters.""" + exporter1 = InMemorySpanExporter() + exporter2 = InMemorySpanExporter() + config = BasaltConfig(service_name="test-service") + + provider = create_tracer_provider(config, exporter=[exporter1, exporter2]) + + # Verify provider was created + assert isinstance(provider, TracerProvider) + # Verify both exporters were added (2 processors) + assert len(provider._active_span_processor._span_processors) == 2 - def tearDown(self): - """Don't clean up - let setUpClass/tearDownClass handle global state.""" + # Test that both exporters receive spans + trace.set_tracer_provider(provider) + tracer = trace.get_tracer("test") + + with tracer.start_as_current_span("test-span"): pass - def test_single_exporter_backward_compatibility(self): - """Test that single exporter still works (backward compatibility).""" - exporter = InMemorySpanExporter() - config = BasaltConfig(service_name="test-service") + # Force flush to ensure spans are exported + provider.force_flush() - provider = create_tracer_provider(config, exporter=exporter) + # Both exporters should have received the span + assert len(exporter1.get_finished_spans()) == 1 + assert len(exporter2.get_finished_spans()) == 1 - # Verify provider was created - self.assertIsInstance(provider, TracerProvider) - # Verify exporter was added (check _active_span_processor has processors) - self.assertGreater(len(provider._active_span_processor._span_processors), 0) + # Verify span content is identical + span1 = exporter1.get_finished_spans()[0] + span2 = exporter2.get_finished_spans()[0] + assert span1.name == span2.name + assert span1.context.trace_id == span2.context.trace_id + assert span1.context.span_id == span2.context.span_id - def test_multiple_exporters_list(self): - """Test configuring with list of 2 exporters.""" - exporter1 = InMemorySpanExporter() - exporter2 = InMemorySpanExporter() - config = BasaltConfig(service_name="test-service") - provider = create_tracer_provider(config, exporter=[exporter1, exporter2]) +def test_empty_list_uses_console_exporter(): + """Test that empty list falls back to ConsoleSpanExporter with warning.""" + config = BasaltConfig(service_name="test-service") - # Verify provider was created - self.assertIsInstance(provider, TracerProvider) - # Verify both exporters were added (2 processors) - self.assertEqual(len(provider._active_span_processor._span_processors), 2) + with pytest.warns(UserWarning) as warning: + provider = create_tracer_provider(config, exporter=[]) - # Test that both exporters receive spans - trace.set_tracer_provider(provider) - tracer = trace.get_tracer("test") + # Verify warning message + assert "Empty exporter list" in str(warning[0].message) - with tracer.start_as_current_span("test-span"): - pass + # Verify ConsoleSpanExporter was used + assert isinstance(provider, TracerProvider) + # Check that a processor was added + assert len(provider._active_span_processor._span_processors) > 0 - # Force flush to ensure spans are exported - provider.force_flush() - # Both exporters should have received the span - self.assertEqual(len(exporter1.get_finished_spans()), 1) - self.assertEqual(len(exporter2.get_finished_spans()), 1) - - # Verify span content is identical - span1 = exporter1.get_finished_spans()[0] - span2 = exporter2.get_finished_spans()[0] - self.assertEqual(span1.name, span2.name) - self.assertEqual(span1.context.trace_id, span2.context.trace_id) - self.assertEqual(span1.context.span_id, span2.context.span_id) - - def test_empty_list_uses_console_exporter(self): - """Test that empty list falls back to ConsoleSpanExporter with warning.""" - config = BasaltConfig(service_name="test-service") - - with self.assertWarns(UserWarning) as cm: - provider = create_tracer_provider(config, exporter=[]) - - # Verify warning message - self.assertIn("Empty exporter list", str(cm.warning)) - - # Verify ConsoleSpanExporter was used - self.assertIsInstance(provider, TracerProvider) - # Check that a processor was added - self.assertGreater(len(provider._active_span_processor._span_processors), 0) - - @mock.patch.dict("os.environ", {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4318"}, clear=False) - @mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") - def test_user_exporters_plus_env_exporter(self, mock_otlp_exporter): - """Test that user exporters are used instead of environment exporter.""" - # Mock the OTLP exporter creation - mock_env_exporter = mock.Mock() - mock_otlp_exporter.return_value = mock_env_exporter - - user_exporter = InMemorySpanExporter() - config = TelemetryConfig( - service_name="test-service", - exporter=user_exporter, - ) - - manager = InstrumentationManager() - manager.initialize(config) - - # Verify user exporter was used - provider = manager._tracer_provider - self.assertIsInstance(provider, TracerProvider) - # Should have 1 exporter + 4 Basalt processors = 5 total processors - # Basalt processors: Context, CallEvaluator, ShouldEvaluate, AutoInstrumentation - # Note: Environment exporter is only used if no user exporter is provided - self.assertEqual(len(provider._active_span_processor._span_processors), 5) - - def test_mixed_console_and_otlp_exporters(self): - """Test mix of ConsoleSpanExporter and regular exporters.""" - console_exporter = ConsoleSpanExporter() - memory_exporter = InMemorySpanExporter() - config = BasaltConfig(service_name="test-service") - - provider = create_tracer_provider( - config, exporter=[console_exporter, memory_exporter] - ) - - # Verify both exporters were added - self.assertIsInstance(provider, TracerProvider) - self.assertEqual(len(provider._active_span_processor._span_processors), 2) - - # Test span export - trace.set_tracer_provider(provider) - tracer = trace.get_tracer("test") - - with tracer.start_as_current_span("test-span"): - pass +@mock.patch.dict( + "os.environ", {"BASALT_OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4318"}, clear=False +) +@mock.patch("basalt.observability.instrumentation.OTLPSpanExporter") +def test_user_exporters_plus_env_exporter(mock_otlp_exporter): + """Test that user exporters are used instead of environment exporter.""" + # Mock the OTLP exporter creation + mock_env_exporter = mock.Mock() + mock_otlp_exporter.return_value = mock_env_exporter + + user_exporter = InMemorySpanExporter() + config = TelemetryConfig( + service_name="test-service", + exporter=user_exporter, + ) + + manager = InstrumentationManager() + manager.initialize(config) + + # Verify user exporter was used + provider = manager._tracer_provider + assert isinstance(provider, TracerProvider) + # Should have 1 exporter + 4 Basalt processors = 5 total processors + # Basalt processors: Context, CallEvaluator, ShouldEvaluate, AutoInstrumentation + # Note: Environment exporter is only used if no user exporter is provided + assert len(provider._active_span_processor._span_processors) == 5 + + +def test_mixed_console_and_otlp_exporters(): + """Test mix of ConsoleSpanExporter and regular exporters.""" + console_exporter = ConsoleSpanExporter() + memory_exporter = InMemorySpanExporter() + config = BasaltConfig(service_name="test-service") + + provider = create_tracer_provider(config, exporter=[console_exporter, memory_exporter]) + + # Verify both exporters were added + assert isinstance(provider, TracerProvider) + assert len(provider._active_span_processor._span_processors) == 2 + + # Test span export + trace.set_tracer_provider(provider) + tracer = trace.get_tracer("test") + + with tracer.start_as_current_span("test-span"): + pass - provider.force_flush() + provider.force_flush() - # Memory exporter should have received the span - self.assertEqual(len(memory_exporter.get_finished_spans()), 1) + # Memory exporter should have received the span + assert len(memory_exporter.get_finished_spans()) == 1 - def test_exporter_isolation_on_error(self): - """Test that one failing exporter doesn't affect others.""" - failing_exporter = mock.Mock() - failing_exporter.export.side_effect = Exception("Export failed") - working_exporter = InMemorySpanExporter() - config = BasaltConfig(service_name="test-service") +def test_exporter_isolation_on_error(): + """Test that one failing exporter doesn't affect others.""" + failing_exporter = mock.Mock() + failing_exporter.export.side_effect = Exception("Export failed") - provider = create_tracer_provider( - config, exporter=[failing_exporter, working_exporter] - ) + working_exporter = InMemorySpanExporter() + config = BasaltConfig(service_name="test-service") - trace.set_tracer_provider(provider) - tracer = trace.get_tracer("test") + provider = create_tracer_provider(config, exporter=[failing_exporter, working_exporter]) - with tracer.start_as_current_span("test-span"): - pass + trace.set_tracer_provider(provider) + tracer = trace.get_tracer("test") - # Force flush (failing exporter will raise but shouldn't stop working exporter) - try: - provider.force_flush() - except Exception: - pass # Expected from failing exporter + with tracer.start_as_current_span("test-span"): + pass + + # Force flush (failing exporter will raise but shouldn't stop working exporter) + try: + provider.force_flush() + except Exception: + pass # Expected from failing exporter + + # Working exporter should still have received the span + assert len(working_exporter.get_finished_spans()) == 1 + + +def test_duplicate_exporters_allowed(): + """Test that duplicate exporters in list are allowed (user responsibility).""" + exporter = InMemorySpanExporter() + config = BasaltConfig(service_name="test-service") - # Working exporter should still have received the span - self.assertEqual(len(working_exporter.get_finished_spans()), 1) + # Same exporter instance twice + provider = create_tracer_provider(config, exporter=[exporter, exporter]) - def test_duplicate_exporters_allowed(self): - """Test that duplicate exporters in list are allowed (user responsibility).""" - exporter = InMemorySpanExporter() - config = BasaltConfig(service_name="test-service") + # Should have 2 processors (both using same exporter) + assert len(provider._active_span_processor._span_processors) == 2 - # Same exporter instance twice - provider = create_tracer_provider(config, exporter=[exporter, exporter]) - # Should have 2 processors (both using same exporter) - self.assertEqual(len(provider._active_span_processor._span_processors), 2) +def test_none_exporter_uses_console_with_warning(): + """Test that None exporter defaults to ConsoleSpanExporter with warning.""" + config = BasaltConfig(service_name="test-service") - def test_none_exporter_uses_console_with_warning(self): - """Test that None exporter defaults to ConsoleSpanExporter with warning.""" - config = BasaltConfig(service_name="test-service") + with pytest.warns(UserWarning) as warning: + provider = create_tracer_provider(config, exporter=None) - with self.assertWarns(UserWarning) as cm: - provider = create_tracer_provider(config, exporter=None) + # Verify warning message + assert "No span exporter configured" in str(warning[0].message) - # Verify warning message - self.assertIn("No span exporter configured", str(cm.warning)) + # Verify provider was created + assert isinstance(provider, TracerProvider) - # Verify provider was created - self.assertIsInstance(provider, TracerProvider) +def test_config_accepts_exporter_list(): + """Test that TelemetryConfig accepts list of exporters.""" + exporter1 = InMemorySpanExporter() + exporter2 = InMemorySpanExporter() -class TestTelemetryConfigWithMultipleExporters(unittest.TestCase): - """Test TelemetryConfig with multiple exporters.""" + config = TelemetryConfig( + service_name="test-service", + exporter=[exporter1, exporter2], + ) - def test_config_accepts_exporter_list(self): - """Test that TelemetryConfig accepts list of exporters.""" - exporter1 = InMemorySpanExporter() - exporter2 = InMemorySpanExporter() + assert isinstance(config.exporter, list) + assert len(config.exporter) == 2 + assert config.exporter[0] is exporter1 + assert config.exporter[1] is exporter2 - config = TelemetryConfig( - service_name="test-service", - exporter=[exporter1, exporter2], - ) - self.assertIsInstance(config.exporter, list) - self.assertEqual(len(config.exporter), 2) - self.assertIs(config.exporter[0], exporter1) - self.assertIs(config.exporter[1], exporter2) +def test_config_accepts_single_exporter(): + """Test backward compatibility: single exporter still works.""" + exporter = InMemorySpanExporter() - def test_config_accepts_single_exporter(self): - """Test backward compatibility: single exporter still works.""" - exporter = InMemorySpanExporter() + config = TelemetryConfig( + service_name="test-service", + exporter=exporter, + ) - config = TelemetryConfig( - service_name="test-service", - exporter=exporter, - ) + # Should be the exporter itself, not wrapped in list + assert isinstance(config.exporter, InMemorySpanExporter) + assert config.exporter is exporter - # Should be the exporter itself, not wrapped in list - self.assertIsInstance(config.exporter, InMemorySpanExporter) - self.assertIs(config.exporter, exporter) - def test_clone_with_exporter_list(self): - """Test that clone() properly copies exporter lists.""" - exporter1 = InMemorySpanExporter() - exporter2 = InMemorySpanExporter() +def test_clone_with_exporter_list(): + """Test that clone() properly copies exporter lists.""" + exporter1 = InMemorySpanExporter() + exporter2 = InMemorySpanExporter() - original = TelemetryConfig( - service_name="test-service", - exporter=[exporter1, exporter2], - ) + original = TelemetryConfig( + service_name="test-service", + exporter=[exporter1, exporter2], + ) - cloned = original.clone() + cloned = original.clone() - # Verify it's a new list instance - self.assertIsNot(cloned.exporter, original.exporter) - # But contains same exporter objects - self.assertEqual(len(cloned.exporter), 2) - self.assertIs(cloned.exporter[0], exporter1) - self.assertIs(cloned.exporter[1], exporter2) + # Verify it's a new list instance + assert cloned.exporter is not original.exporter + # But contains same exporter objects + assert len(cloned.exporter) == 2 + assert cloned.exporter[0] is exporter1 + assert cloned.exporter[1] is exporter2 - def test_clone_list_independence(self): - """Test that modifying cloned exporter list doesn't affect original.""" - exporter1 = InMemorySpanExporter() - exporter2 = InMemorySpanExporter() - original = TelemetryConfig( - service_name="test-service", - exporter=[exporter1, exporter2], - ) +def test_clone_list_independence(): + """Test that modifying cloned exporter list doesn't affect original.""" + exporter1 = InMemorySpanExporter() + exporter2 = InMemorySpanExporter() - cloned = original.clone() + original = TelemetryConfig( + service_name="test-service", + exporter=[exporter1, exporter2], + ) - # Modify cloned list - if isinstance(cloned.exporter, list): - cloned.exporter.append(InMemorySpanExporter()) + cloned = original.clone() - # Original should be unchanged - self.assertEqual(len(original.exporter), 2) + # Modify cloned list + if isinstance(cloned.exporter, list): + cloned.exporter.append(InMemorySpanExporter()) - def test_clone_with_single_exporter(self): - """Test that clone() handles single exporter correctly.""" - exporter = InMemorySpanExporter() + # Original should be unchanged + assert len(original.exporter) == 2 - original = TelemetryConfig( - service_name="test-service", - exporter=exporter, - ) - cloned = original.clone() +def test_clone_with_single_exporter(): + """Test that clone() handles single exporter correctly.""" + exporter = InMemorySpanExporter() - # Should be same exporter object (not cloned) - self.assertIs(cloned.exporter, exporter) + original = TelemetryConfig( + service_name="test-service", + exporter=exporter, + ) + cloned = original.clone() -if __name__ == "__main__": - unittest.main() + # Should be same exporter object (not cloned) + assert cloned.exporter is exporter diff --git a/tests/observability/test_processors.py b/tests/observability/test_processors.py index 7c0303b..10589e8 100644 --- a/tests/observability/test_processors.py +++ b/tests/observability/test_processors.py @@ -7,7 +7,7 @@ class DummySpan: - def __init__(self, is_recording=True, attributes=None): + def __init__(self, is_recording=True, attributes=None) -> None: self._is_recording = is_recording self.attributes = attributes if attributes is not None else {} self.set_attributes = {} @@ -18,22 +18,26 @@ def is_recording(self): def set_attribute(self, key, value): self.set_attributes[key] = value + @pytest.fixture def mock_semconv(): with patch("basalt.observability.processors.semconv") as mock_semconv: mock_semconv.BasaltSpan.EVALUATORS = "basalt.evaluators" yield mock_semconv + def test_no_slugs(mock_semconv): span = DummySpan() processors._merge_evaluators(cast(processors.Span, span), []) assert span.set_attributes == {} + def test_span_not_recording(mock_semconv): span = DummySpan(is_recording=False) processors._merge_evaluators(cast(processors.Span, span), ["foo"]) assert span.set_attributes == {} + def test_merge_with_no_existing(mock_semconv): span = DummySpan(attributes={}) processors._merge_evaluators(cast(processors.Span, span), ["foo", "bar"]) @@ -41,6 +45,7 @@ def test_merge_with_no_existing(mock_semconv): assert key in span.set_attributes assert span.set_attributes[key] == ["foo", "bar"] + def test_merge_with_existing(mock_semconv): key = mock_semconv.BasaltSpan.EVALUATORS span = DummySpan(attributes={key: ["foo", "baz"]}) @@ -48,12 +53,14 @@ def test_merge_with_existing(mock_semconv): # Should merge and deduplicate: ["foo", "baz", "bar"] assert span.set_attributes[key] == ["foo", "baz", "bar"] + def test_merge_with_empty_and_whitespace_slugs(mock_semconv): key = mock_semconv.BasaltSpan.EVALUATORS span = DummySpan(attributes={key: ["", " ", "foo"]}) processors._merge_evaluators(cast(processors.Span, span), ["", "bar", " "]) assert span.set_attributes[key] == ["foo", "bar"] + def test_merge_with_non_dict_attributes(mock_semconv): key = mock_semconv.BasaltSpan.EVALUATORS span = DummySpan() @@ -70,7 +77,10 @@ def test_openai_v1_scope_recognized(): def test_openai_v1_scope_has_generation_kind(): """Test that opentelemetry.instrumentation.openai.v1 is mapped to GENERATION kind.""" assert "opentelemetry.instrumentation.openai.v1" in processors.INSTRUMENTATION_SCOPE_KINDS - assert processors.INSTRUMENTATION_SCOPE_KINDS["opentelemetry.instrumentation.openai.v1"] == "generation" + assert ( + processors.INSTRUMENTATION_SCOPE_KINDS["opentelemetry.instrumentation.openai.v1"] + == "generation" + ) def test_auto_instrumentation_processor_sets_in_trace_for_openai_v1(): @@ -132,7 +142,7 @@ def test_auto_instrumentation_processor_injects_prompt_from_contextvar(): "provider": "gemini", "model": "gemini-2.5-flash-lite", "variables": {"var1": "value1"}, - "from_cache": False + "from_cache": False, } cv_token = _current_prompt_context.set(prompt_ctx) @@ -146,11 +156,14 @@ def test_auto_instrumentation_processor_injects_prompt_from_contextvar(): mock_span.set_attribute.assert_any_call("basalt.prompt.version", "v1.0.0") mock_span.set_attribute.assert_any_call("basalt.prompt.tag", "latest") mock_span.set_attribute.assert_any_call("basalt.prompt.model.provider", "gemini") - mock_span.set_attribute.assert_any_call("basalt.prompt.model.model", "gemini-2.5-flash-lite") + mock_span.set_attribute.assert_any_call( + "basalt.prompt.model.model", "gemini-2.5-flash-lite" + ) mock_span.set_attribute.assert_any_call("basalt.prompt.from_cache", False) # Check that variables were serialized as JSON import json + calls = mock_span.set_attribute.call_args_list variables_call = [call for call in calls if call[0][0] == "basalt.prompt.variables"] assert len(variables_call) == 1 @@ -185,7 +198,7 @@ def test_auto_instrumentation_processor_explicit_injection_overrides_contextvar( "version": "v1.0.0", "provider": "openai", "model": "gpt-4", - "from_cache": False + "from_cache": False, } cv_token = _current_prompt_context.set(contextvar_prompt) @@ -277,7 +290,7 @@ def test_auto_instrumentation_processor_prompt_context_with_optional_fields(): "provider": "openai", "model": "gpt-4", "variables": None, - "from_cache": True + "from_cache": True, } cv_token = _current_prompt_context.set(prompt_ctx) diff --git a/tests/observability/test_request_tracing.py b/tests/observability/test_request_tracing.py index 1cff902..0e4fc7b 100644 --- a/tests/observability/test_request_tracing.py +++ b/tests/observability/test_request_tracing.py @@ -5,7 +5,7 @@ class DummySpan: - def __init__(self): + def __init__(self) -> None: self.variables = None self.attributes = {} self.exceptions = [] @@ -28,12 +28,12 @@ def set_status(self, status): class DummyObserve: - def __init__(self): + def __init__(self) -> None: self.inputs = [] self.outputs = [] self.entered = [] - def __call__(self, *, name, metadata): + def __call__(self, *, name, metadata) -> "_DummyContext": self.entered.append({"name": name, "metadata": metadata}) return _DummyContext(self) @@ -45,21 +45,21 @@ def output(self, payload): class _DummyContext: - def __init__(self, observe): + def __init__(self, observe) -> None: self.observe = observe self.span = DummySpan() - def __enter__(self): + def __enter__(self) -> DummySpan: return self.span - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: return False class RecordingSpan(BasaltRequestSpan): __slots__ = ("finalize_calls",) - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.finalize_calls: list[dict] = [] diff --git a/tests/observability/test_resilient_exporters.py b/tests/observability/test_resilient_exporters.py index 915bf55..631a76d 100644 --- a/tests/observability/test_resilient_exporters.py +++ b/tests/observability/test_resilient_exporters.py @@ -2,141 +2,148 @@ from __future__ import annotations -import unittest from unittest import mock +import pytest from opentelemetry.sdk.trace.export import SpanExportResult from basalt.observability.resilient_exporters import ResilientSpanExporter -class TestResilientSpanExporter(unittest.TestCase): - def test_successful_export_delegates_to_underlying_exporter(self): - """When export succeeds, return the result from underlying exporter.""" - mock_exporter = mock.Mock() - mock_exporter.export.return_value = SpanExportResult.SUCCESS +def test_successful_export_delegates_to_underlying_exporter(): + """When export succeeds, return the result from underlying exporter.""" + mock_exporter = mock.Mock() + mock_exporter.export.return_value = SpanExportResult.SUCCESS - wrapper = ResilientSpanExporter(mock_exporter) - spans = [mock.Mock()] + wrapper = ResilientSpanExporter(mock_exporter) + spans = [mock.Mock()] - result = wrapper.export(spans) + result = wrapper.export(spans) - self.assertEqual(result, SpanExportResult.SUCCESS) - mock_exporter.export.assert_called_once_with(spans) + assert result == SpanExportResult.SUCCESS + mock_exporter.export.assert_called_once_with(spans) - def test_connection_error_returns_failure_without_raising(self): - """When connection error occurs, log and return FAILURE.""" - mock_exporter = mock.Mock() - mock_exporter.export.side_effect = ConnectionError("DNS resolution failed") - wrapper = ResilientSpanExporter(mock_exporter) - spans = [mock.Mock()] +def test_connection_error_returns_failure_without_raising(): + """When connection error occurs, log and return FAILURE.""" + mock_exporter = mock.Mock() + mock_exporter.export.side_effect = ConnectionError("DNS resolution failed") - # Should not raise - result = wrapper.export(spans) + wrapper = ResilientSpanExporter(mock_exporter) + spans = [mock.Mock()] - self.assertEqual(result, SpanExportResult.FAILURE) + # Should not raise + result = wrapper.export(spans) - @mock.patch("basalt.observability.resilient_exporters.logger") - def test_exception_logged_at_warning_level(self, mock_logger): - """Verify exceptions are logged at warning level.""" - mock_exporter = mock.Mock() - error = ConnectionError("Connection refused") - mock_exporter.export.side_effect = error + assert result == SpanExportResult.FAILURE + + +@mock.patch("basalt.observability.resilient_exporters.logger") +def test_exception_logged_at_warning_level(mock_logger): + """Verify exceptions are logged at warning level.""" + mock_exporter = mock.Mock() + error = ConnectionError("Connection refused") + mock_exporter.export.side_effect = error + + wrapper = ResilientSpanExporter(mock_exporter) + wrapper.export([mock.Mock()]) + + # Should log at warning level so users see the error + mock_logger.warning.assert_called_once() + # Check that it was called with the format string and exception type/message + call_args = mock_logger.warning.call_args[0] + assert "Span export failed" in call_args[0] + assert call_args[1] == "ConnectionError" + assert call_args[2] == error - wrapper = ResilientSpanExporter(mock_exporter) - wrapper.export([mock.Mock()]) - # Should log at warning level so users see the error - mock_logger.warning.assert_called_once() - # Check that it was called with the format string and exception type/message - call_args = mock_logger.warning.call_args[0] - self.assertIn("Span export failed", call_args[0]) - self.assertEqual("ConnectionError", call_args[1]) - self.assertEqual(error, call_args[2]) +@pytest.mark.parametrize( + "exception", + [ + ConnectionError("Connection failed"), + OSError("Network unreachable"), + TimeoutError("Request timeout"), + RuntimeError("Unexpected error"), + ], +) +def test_various_exception_types_caught(exception): + """Verify different exception types are all caught.""" + mock_exporter = mock.Mock() + mock_exporter.export.side_effect = exception - def test_various_exception_types_caught(self): - """Verify different exception types are all caught.""" - exception_types = [ - ConnectionError("Connection failed"), - OSError("Network unreachable"), - TimeoutError("Request timeout"), - RuntimeError("Unexpected error"), - ] + wrapper = ResilientSpanExporter(mock_exporter) - for exception in exception_types: - with self.subTest(exception=exception): - mock_exporter = mock.Mock() - mock_exporter.export.side_effect = exception + # Should not raise + result = wrapper.export([mock.Mock()]) + assert result == SpanExportResult.FAILURE - wrapper = ResilientSpanExporter(mock_exporter) - # Should not raise - result = wrapper.export([mock.Mock()]) - self.assertEqual(result, SpanExportResult.FAILURE) +def test_shutdown_suppresses_exceptions(): + """Shutdown errors are suppressed and logged.""" + mock_exporter = mock.Mock() + mock_exporter.shutdown.side_effect = RuntimeError("Shutdown failed") - def test_shutdown_suppresses_exceptions(self): - """Shutdown errors are suppressed and logged.""" - mock_exporter = mock.Mock() - mock_exporter.shutdown.side_effect = RuntimeError("Shutdown failed") + wrapper = ResilientSpanExporter(mock_exporter) - wrapper = ResilientSpanExporter(mock_exporter) + # Should not raise + wrapper.shutdown() - # Should not raise - wrapper.shutdown() + mock_exporter.shutdown.assert_called_once() - mock_exporter.shutdown.assert_called_once() - def test_force_flush_suppresses_exceptions_returns_false(self): - """Force flush errors return False.""" - mock_exporter = mock.Mock() - mock_exporter.force_flush.side_effect = RuntimeError("Flush failed") +def test_force_flush_suppresses_exceptions_returns_false(): + """Force flush errors return False.""" + mock_exporter = mock.Mock() + mock_exporter.force_flush.side_effect = RuntimeError("Flush failed") - wrapper = ResilientSpanExporter(mock_exporter) + wrapper = ResilientSpanExporter(mock_exporter) - result = wrapper.force_flush(1000) + result = wrapper.force_flush(1000) - self.assertFalse(result) - mock_exporter.force_flush.assert_called_once_with(1000) + assert result is False + mock_exporter.force_flush.assert_called_once_with(1000) - def test_force_flush_success_returns_true(self): - """Force flush success returns True.""" - mock_exporter = mock.Mock() - mock_exporter.force_flush.return_value = True - wrapper = ResilientSpanExporter(mock_exporter) +def test_force_flush_success_returns_true(): + """Force flush success returns True.""" + mock_exporter = mock.Mock() + mock_exporter.force_flush.return_value = True - result = wrapper.force_flush(5000) + wrapper = ResilientSpanExporter(mock_exporter) - self.assertTrue(result) - mock_exporter.force_flush.assert_called_once_with(5000) + result = wrapper.force_flush(5000) - def test_custom_exception_types_can_be_specified(self): - """Can configure which exception types to suppress.""" - mock_exporter = mock.Mock() - mock_exporter.export.side_effect = ValueError("Not a connection error") + assert result is True + mock_exporter.force_flush.assert_called_once_with(5000) - # Only suppress ConnectionError - wrapper = ResilientSpanExporter( - mock_exporter, - suppress_exceptions=(ConnectionError,), - ) - # Should raise ValueError since it's not in suppress list - with self.assertRaises(ValueError): - wrapper.export([mock.Mock()]) +def test_custom_exception_types_can_be_specified(): + """Can configure which exception types to suppress.""" + mock_exporter = mock.Mock() + mock_exporter.export.side_effect = ValueError("Not a connection error") + + # Only suppress ConnectionError + wrapper = ResilientSpanExporter( + mock_exporter, + suppress_exceptions=(ConnectionError,), + ) + + # Should raise ValueError since it's not in suppress list + with pytest.raises(ValueError): + wrapper.export([mock.Mock()]) + - def test_custom_exception_types_suppress_configured_types(self): - """Custom exception types work correctly.""" - mock_exporter = mock.Mock() - mock_exporter.export.side_effect = ConnectionError("Connection failed") +def test_custom_exception_types_suppress_configured_types(): + """Custom exception types work correctly.""" + mock_exporter = mock.Mock() + mock_exporter.export.side_effect = ConnectionError("Connection failed") - # Only suppress ConnectionError - wrapper = ResilientSpanExporter( - mock_exporter, - suppress_exceptions=(ConnectionError,), - ) + # Only suppress ConnectionError + wrapper = ResilientSpanExporter( + mock_exporter, + suppress_exceptions=(ConnectionError,), + ) - # Should not raise ConnectionError - result = wrapper.export([mock.Mock()]) - self.assertEqual(result, SpanExportResult.FAILURE) + # Should not raise ConnectionError + result = wrapper.export([mock.Mock()]) + assert result == SpanExportResult.FAILURE diff --git a/tests/observability/test_should_evaluate_propagation.py b/tests/observability/test_should_evaluate_propagation.py index 2e3f6a8..7402fa6 100644 --- a/tests/observability/test_should_evaluate_propagation.py +++ b/tests/observability/test_should_evaluate_propagation.py @@ -17,7 +17,7 @@ class InMemorySpanExporter(SpanExporter): """Simple in-memory span exporter for testing.""" - def __init__(self): + def __init__(self) -> None: self.spans = [] def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: @@ -44,13 +44,13 @@ def setup_tracer(): provider = trace.get_tracer_provider() # If provider is a ProxyTracerProvider, create a real one - if type(provider).__name__ == 'ProxyTracerProvider': + if type(provider).__name__ == "ProxyTracerProvider": provider = TracerProvider() provider.add_span_processor(BasaltShouldEvaluateProcessor()) trace.set_tracer_provider(provider) # Ensure BasaltShouldEvaluateProcessor is installed - if not hasattr(provider, '_basalt_should_evaluate_installed'): + if not hasattr(provider, "_basalt_should_evaluate_installed"): processor = BasaltShouldEvaluateProcessor() provider.add_span_processor(processor) provider._basalt_should_evaluate_installed = True # type: ignore[attr-defined] @@ -72,9 +72,7 @@ def test_sample_rate_1_propagates_to_children(self, setup_tracer): exporter = setup_tracer with start_observe( - name="parent", - feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=1.0) + name="parent", feature_slug="test", evaluate_config=EvaluationConfig(sample_rate=1.0) ): with observe(name="child1", kind=ObserveKind.FUNCTION): with observe(name="grandchild", kind=ObserveKind.FUNCTION): @@ -88,19 +86,19 @@ def test_sample_rate_1_propagates_to_children(self, setup_tracer): # All spans should have should_evaluate=True for span in spans: - assert BasaltSpan.SHOULD_EVALUATE in span.attributes, \ + assert BasaltSpan.SHOULD_EVALUATE in span.attributes, ( f"Span {span.name} missing should_evaluate" - assert span.attributes[BasaltSpan.SHOULD_EVALUATE] is True, \ + ) + assert span.attributes[BasaltSpan.SHOULD_EVALUATE] is True, ( f"Span {span.name} has should_evaluate={span.attributes[BasaltSpan.SHOULD_EVALUATE]}, expected True" + ) def test_sample_rate_0_propagates_to_children(self, setup_tracer): """Test that should_evaluate=False propagates to all child spans.""" exporter = setup_tracer with start_observe( - name="parent", - feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=0.0) + name="parent", feature_slug="test", evaluate_config=EvaluationConfig(sample_rate=0.0) ): with observe(name="child1", kind=ObserveKind.FUNCTION): with observe(name="grandchild", kind=ObserveKind.FUNCTION): @@ -114,10 +112,12 @@ def test_sample_rate_0_propagates_to_children(self, setup_tracer): # All spans should have should_evaluate=False for span in spans: - assert BasaltSpan.SHOULD_EVALUATE in span.attributes, \ + assert BasaltSpan.SHOULD_EVALUATE in span.attributes, ( f"Span {span.name} missing should_evaluate" - assert span.attributes[BasaltSpan.SHOULD_EVALUATE] is False, \ + ) + assert span.attributes[BasaltSpan.SHOULD_EVALUATE] is False, ( f"Span {span.name} has should_evaluate={span.attributes[BasaltSpan.SHOULD_EVALUATE]}, expected False" + ) def test_experiment_forces_true_for_all_spans(self, setup_tracer): """Test that experiment forces should_evaluate=True for all spans.""" @@ -127,7 +127,7 @@ def test_experiment_forces_true_for_all_spans(self, setup_tracer): name="parent", feature_slug="test", experiment="exp_123", - evaluate_config=EvaluationConfig(sample_rate=0.0) # Would normally be False + evaluate_config=EvaluationConfig(sample_rate=0.0), # Would normally be False ): with observe(name="child1", kind=ObserveKind.FUNCTION): with observe(name="grandchild", kind=ObserveKind.FUNCTION): @@ -141,10 +141,12 @@ def test_experiment_forces_true_for_all_spans(self, setup_tracer): # All spans should have should_evaluate=True due to experiment for span in spans: - assert BasaltSpan.SHOULD_EVALUATE in span.attributes, \ + assert BasaltSpan.SHOULD_EVALUATE in span.attributes, ( f"Span {span.name} missing should_evaluate" - assert span.attributes[BasaltSpan.SHOULD_EVALUATE] is True, \ + ) + assert span.attributes[BasaltSpan.SHOULD_EVALUATE] is True, ( f"Span {span.name} has should_evaluate={span.attributes[BasaltSpan.SHOULD_EVALUATE]}, expected True due to experiment" + ) def test_experiment_overrides_sample_rate_for_all_spans(self, setup_tracer): """Test that experiment overrides sample_rate=0.0 for entire trace.""" @@ -153,7 +155,7 @@ def test_experiment_overrides_sample_rate_for_all_spans(self, setup_tracer): with start_observe( name="experiment_trace", feature_slug="test", - experiment="exp_456" + experiment="exp_456", # No evaluate_config, global default is 0.0 ): with observe(name="processing", kind=ObserveKind.FUNCTION): @@ -168,17 +170,16 @@ def test_experiment_overrides_sample_rate_for_all_spans(self, setup_tracer): # Verify all have should_evaluate=True for span in spans: - assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, \ + assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, ( f"Span {span.name} should have should_evaluate=True with experiment" + ) def test_deeply_nested_spans_propagate(self, setup_tracer): """Test propagation through deeply nested span hierarchy.""" exporter = setup_tracer with start_observe( - name="root", - feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=1.0) + name="root", feature_slug="test", evaluate_config=EvaluationConfig(sample_rate=1.0) ): with observe(name="level1", kind=ObserveKind.FUNCTION): with observe(name="level2", kind=ObserveKind.FUNCTION): @@ -191,21 +192,17 @@ def test_deeply_nested_spans_propagate(self, setup_tracer): assert len(spans) == 6, f"Expected 6 spans, got {len(spans)}" # All spans should have same should_evaluate value - should_evaluate_values = [ - span.attributes.get(BasaltSpan.SHOULD_EVALUATE) - for span in spans - ] - assert all(v is True for v in should_evaluate_values), \ + should_evaluate_values = [span.attributes.get(BasaltSpan.SHOULD_EVALUATE) for span in spans] + assert all(v is True for v in should_evaluate_values), ( f"All spans should have should_evaluate=True, got: {should_evaluate_values}" + ) def test_multiple_child_branches_propagate(self, setup_tracer): """Test propagation across multiple child branches.""" exporter = setup_tracer with start_observe( - name="root", - feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=0.0) + name="root", feature_slug="test", evaluate_config=EvaluationConfig(sample_rate=0.0) ): # Branch 1 with observe(name="branch1", kind=ObserveKind.FUNCTION): @@ -226,8 +223,9 @@ def test_multiple_child_branches_propagate(self, setup_tracer): # All should be False for span in spans: - assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is False, \ + assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is False, ( f"Span {span.name} should have should_evaluate=False" + ) def test_decorator_style_propagation(self, setup_tracer): """Test propagation works with decorator-style usage.""" @@ -236,7 +234,7 @@ def test_decorator_style_propagation(self, setup_tracer): @start_observe( name="decorated_root", feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=1.0) + evaluate_config=EvaluationConfig(sample_rate=1.0), ) def root_function(): with observe(name="child_in_decorator", kind=ObserveKind.FUNCTION): @@ -253,8 +251,9 @@ def nested_function(): # All should have should_evaluate=True for span in spans: - assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, \ + assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, ( f"Span {span.name} should have should_evaluate=True" + ) def test_experiment_with_decorator_propagation(self, setup_tracer): """Test experiment forces evaluation with decorator pattern.""" @@ -264,7 +263,7 @@ def test_experiment_with_decorator_propagation(self, setup_tracer): name="experiment_decorated", feature_slug="test", experiment="exp_789", - evaluate_config=EvaluationConfig(sample_rate=0.0) + evaluate_config=EvaluationConfig(sample_rate=0.0), ) def experiment_function(): with observe(name="child", kind=ObserveKind.FUNCTION): @@ -277,8 +276,9 @@ def experiment_function(): # Both should be True due to experiment for span in spans: - assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, \ + assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, ( f"Span {span.name} should have should_evaluate=True with experiment" + ) def test_mixed_span_kinds_propagation(self, setup_tracer): """Test propagation across different span kinds.""" @@ -287,7 +287,7 @@ def test_mixed_span_kinds_propagation(self, setup_tracer): with start_observe( name="mixed_trace", feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=1.0) + evaluate_config=EvaluationConfig(sample_rate=1.0), ): with observe(name="generation", kind=ObserveKind.GENERATION): pass @@ -309,8 +309,9 @@ def test_mixed_span_kinds_propagation(self, setup_tracer): # All different kinds should have same should_evaluate for span in spans: - assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, \ + assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is True, ( f"Span {span.name} (kind={span.attributes.get('basalt.span.kind')}) should have should_evaluate=True" + ) def test_no_evaluate_config_uses_global_default(self, setup_tracer): """Test that without evaluate_config, global default (0.0) is used for all spans.""" @@ -318,7 +319,7 @@ def test_no_evaluate_config_uses_global_default(self, setup_tracer): with start_observe( name="default_trace", - feature_slug="test" + feature_slug="test", # No evaluate_config, should use global default 0.0 ): with observe(name="child1", kind=ObserveKind.FUNCTION): @@ -332,8 +333,9 @@ def test_no_evaluate_config_uses_global_default(self, setup_tracer): # All should be False (global default is 0.0) for span in spans: - assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is False, \ + assert span.attributes.get(BasaltSpan.SHOULD_EVALUATE) is False, ( f"Span {span.name} should have should_evaluate=False with global default" + ) def test_trace_consistency(self, setup_tracer): """Test that all spans in a trace have the SAME should_evaluate value.""" @@ -343,21 +345,19 @@ def test_trace_consistency(self, setup_tracer): with start_observe( name="consistent_trace", feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=1.0) + evaluate_config=EvaluationConfig(sample_rate=1.0), ): with observe(name="child1", kind=ObserveKind.FUNCTION): with observe(name="grandchild", kind=ObserveKind.GENERATION): pass spans = exporter.get_finished_spans() - should_evaluate_values = [ - span.attributes.get(BasaltSpan.SHOULD_EVALUATE) - for span in spans - ] + should_evaluate_values = [span.attributes.get(BasaltSpan.SHOULD_EVALUATE) for span in spans] # All values should be identical - assert len(set(should_evaluate_values)) == 1, \ + assert len(set(should_evaluate_values)) == 1, ( f"All spans should have same should_evaluate value, got: {should_evaluate_values}" + ) assert should_evaluate_values[0] is True @@ -372,7 +372,7 @@ def test_experiment_string_forces_evaluation(self, setup_tracer): name="trace", feature_slug="test", experiment="exp_string", - evaluate_config=EvaluationConfig(sample_rate=0.0) + evaluate_config=EvaluationConfig(sample_rate=0.0), ): with observe(name="child", kind=ObserveKind.FUNCTION): pass @@ -391,7 +391,7 @@ def test_experiment_object_forces_evaluation(self, setup_tracer): exporter = setup_tracer class MockExperiment: - def __init__(self, id, name=None): + def __init__(self, id, name=None) -> None: self.id = id self.name = name @@ -401,7 +401,7 @@ def __init__(self, id, name=None): name="trace", feature_slug="test", experiment=exp, - evaluate_config=EvaluationConfig(sample_rate=0.0) + evaluate_config=EvaluationConfig(sample_rate=0.0), ): with observe(name="child", kind=ObserveKind.FUNCTION): pass @@ -418,7 +418,7 @@ def test_experiment_without_evaluate_config(self, setup_tracer): with start_observe( name="trace", feature_slug="test", - experiment="exp_no_config" + experiment="exp_no_config", # No evaluate_config at all ): with observe(name="child", kind=ObserveKind.FUNCTION): @@ -437,7 +437,7 @@ def test_no_experiment_respects_sample_rate_zero(self, setup_tracer): with start_observe( name="trace", feature_slug="test", - evaluate_config=EvaluationConfig(sample_rate=0.0) + evaluate_config=EvaluationConfig(sample_rate=0.0), # No experiment ): with observe(name="child", kind=ObserveKind.FUNCTION): diff --git a/tests/otel/otlp_utils.py b/tests/otel/otlp_utils.py index a3635c3..326d695 100644 --- a/tests/otel/otlp_utils.py +++ b/tests/otel/otlp_utils.py @@ -13,7 +13,7 @@ from typing import Any -def _to_otel_attribute_value(value: Any) -> dict[str, Any]: +def _to_otel_attribute_value(value: object) -> dict[str, object]: """ Convert a Python attribute value to the OTLP JSON AttributeValue shape. diff --git a/tests/otel/test_llm_instrumentation.py b/tests/otel/test_llm_instrumentation.py index edc8ae1..b2ebb5c 100644 --- a/tests/otel/test_llm_instrumentation.py +++ b/tests/otel/test_llm_instrumentation.py @@ -376,17 +376,21 @@ def test_instrumentation_initialization(provider_name: str, otel_exporter: InMem # Check if the instrumentor was registered if provider_name not in manager._provider_instrumentors: # Instrumentation package not installed - skip test - pkg_name = instrumentation_packages.get(provider_name, f"opentelemetry-instrumentation-{provider_name}") + pkg_name = instrumentation_packages.get( + provider_name, f"opentelemetry-instrumentation-{provider_name}" + ) pytest.skip(f"Instrumentation package not installed: {pkg_name}") instrumentor = manager._provider_instrumentors[provider_name] assert instrumentor is not None, f"Instrumentor for {provider_name} is None" # Verify the instrumentor has the expected interface - assert hasattr(instrumentor, 'instrument'), \ + assert hasattr(instrumentor, "instrument"), ( f"Instrumentor for {provider_name} missing instrument() method" - assert hasattr(instrumentor, 'uninstrument'), \ + ) + assert hasattr(instrumentor, "uninstrument"), ( f"Instrumentor for {provider_name} missing uninstrument() method" + ) finally: # Clean up instrumentation @@ -443,17 +447,21 @@ def test_instrumentation_span_attrs_real_provider( assert "gen_ai.request.model" in attrs, f"Missing gen_ai.request.model for {provider_name}" # With real calls, we should definitely have these - assert ( - "gen_ai.usage.input_tokens" in attrs or "gen_ai.usage.prompt_tokens" in attrs - ), f"Missing input token count for {provider_name} (real call)" + assert "gen_ai.usage.input_tokens" in attrs or "gen_ai.usage.prompt_tokens" in attrs, ( + f"Missing input token count for {provider_name} (real call)" + ) - assert ( - "gen_ai.usage.output_tokens" in attrs or "gen_ai.usage.completion_tokens" in attrs - ), f"Missing output token count for {provider_name} (real call)" + assert "gen_ai.usage.output_tokens" in attrs or "gen_ai.usage.completion_tokens" in attrs, ( + f"Missing output token count for {provider_name} (real call)" + ) # Verify token counts are positive integers - input_tokens = attrs.get("gen_ai.usage.input_tokens") or attrs.get("gen_ai.usage.prompt_tokens") - output_tokens = attrs.get("gen_ai.usage.output_tokens") or attrs.get("gen_ai.usage.completion_tokens") + input_tokens = attrs.get("gen_ai.usage.input_tokens") or attrs.get( + "gen_ai.usage.prompt_tokens" + ) + output_tokens = attrs.get("gen_ai.usage.output_tokens") or attrs.get( + "gen_ai.usage.completion_tokens" + ) assert input_tokens > 0, f"Input tokens should be > 0 for {provider_name}" assert output_tokens > 0, f"Output tokens should be > 0 for {provider_name}" @@ -524,7 +532,9 @@ def test_otlp_json_conversion_helpers(): # Validate resource assert "resource" in resource_span - resource_attrs = {item["key"]: item["value"] for item in resource_span["resource"]["attributes"]} + resource_attrs = { + item["key"]: item["value"] for item in resource_span["resource"]["attributes"] + } assert resource_attrs["service.name"]["stringValue"] == "test-service" # Validate scope @@ -742,21 +752,27 @@ def get_int(key: str) -> int | None: # Verify token usage attributes # Different instrumentors use different attribute names has_input_tokens = "gen_ai.usage.input_tokens" in attrs or "gen_ai.usage.prompt_tokens" in attrs - has_output_tokens = "gen_ai.usage.output_tokens" in attrs or "gen_ai.usage.completion_tokens" in attrs + has_output_tokens = ( + "gen_ai.usage.output_tokens" in attrs or "gen_ai.usage.completion_tokens" in attrs + ) assert has_input_tokens, f"Missing input token count for {provider_name}" assert has_output_tokens, f"Missing output token count for {provider_name}" # Verify token counts are positive integers (from mocked response) input_tokens = get_int("gen_ai.usage.input_tokens") or get_int("gen_ai.usage.prompt_tokens") - output_tokens = get_int("gen_ai.usage.output_tokens") or get_int("gen_ai.usage.completion_tokens") + output_tokens = get_int("gen_ai.usage.output_tokens") or get_int( + "gen_ai.usage.completion_tokens" + ) assert input_tokens and input_tokens > 0, f"Input tokens should be > 0 for {provider_name}" assert output_tokens and output_tokens > 0, f"Output tokens should be > 0 for {provider_name}" # ---- Validate Resource Attributes ---- - resource_attrs_dict = {item["key"]: item["value"] for item in resource_span["resource"]["attributes"]} + resource_attrs_dict = { + item["key"]: item["value"] for item in resource_span["resource"]["attributes"] + } # Verify key resource attributes exist assert "service.name" in resource_attrs_dict, "Missing service.name" @@ -765,9 +781,9 @@ def get_int(key: str) -> int | None: # Verify resource attributes have correct types assert resource_attrs_dict["service.name"].get("stringValue"), "service.name should be a string" - assert ( - resource_attrs_dict["telemetry.sdk.language"].get("stringValue") == "python" - ), "telemetry.sdk.language should be 'python'" + assert resource_attrs_dict["telemetry.sdk.language"].get("stringValue") == "python", ( + "telemetry.sdk.language should be 'python'" + ) # Optional: Print OTLP JSON for debugging or TS test fixture generation # Uncomment to see the full JSON output: diff --git a/tests/prompts/conftest.py b/tests/prompts/conftest.py deleted file mode 100644 index ff1cf55..0000000 --- a/tests/prompts/conftest.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Pytest fixtures for prompts tests.""" - -import pytest -from opentelemetry import trace -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - - -@pytest.fixture -def setup_tracing(): - """ - Setup OpenTelemetry tracing for tests. - - This fixture can be used by tests that need real span recording, - ensuring that spans are properly recorded and can be inspected. - - Usage: - def test_something(setup_tracing): - # test code here - """ - # Save the current tracer provider to restore later - original_provider = trace.get_tracer_provider() - - # Create an in-memory exporter - exporter = InMemorySpanExporter() - - # Create a tracer provider with the exporter - provider = TracerProvider() - processor = SimpleSpanProcessor(exporter) - provider.add_span_processor(processor) - - # Set as the global tracer provider - trace.set_tracer_provider(provider) - - yield exporter - - # Clean up after test - exporter.clear() - provider.shutdown() - - # Restore original tracer provider - if original_provider and not isinstance(original_provider, trace.ProxyTracerProvider): - trace.set_tracer_provider(original_provider) diff --git a/tests/prompts/test_client.py b/tests/prompts/test_client.py index 8eb3441..2730072 100644 --- a/tests/prompts/test_client.py +++ b/tests/prompts/test_client.py @@ -3,6 +3,7 @@ These tests were converted from unittest to pytest. They keep the same behaviour but use pytest fixtures, parametrization and asyncio support. """ + from unittest.mock import MagicMock, patch import pytest @@ -78,25 +79,30 @@ def test_get_sync_success(common_client): fallback_cache = common_client["fallback_cache"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "text": "Hello {{name}}", - "slug": "test-slug", - "version": "1.0.0", - "tag": "prod", - "systemText": "You are a helpful assistant", - "model": { - "provider": "open-ai", - "client": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.5, - "maxLength": 100, - "responseFormat": "json_object", - "topP": 0.5, + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "text": "Hello {{name}}", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are a helpful assistant", + "model": { + "provider": "open-ai", + "client": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.5, + "maxLength": 100, + "responseFormat": "json_object", + "topP": 0.5, + }, + }, }, - }, - }}) + } + ) prompt = client.get_sync("test-slug", version="1.0.0", tag="prod") @@ -109,6 +115,7 @@ def test_get_sync_success(common_client): # Verify prompt object (wrapped in PromptContextManager) from basalt.prompts.models import PromptContextManager + assert isinstance(prompt, PromptContextManager) # Verify attributes are forwarded correctly assert prompt.slug == "test-slug" @@ -124,24 +131,29 @@ def test_get_sync_with_variables(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "text": "Hello {{name}}", - "slug": "test-slug", - "version": "1.0.0", - "tag": "prod", - "systemText": "You are {{role}}", - "model": { - "provider": "open-ai", - "client": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.5, - "maxLength": 100, - "responseFormat": "json_object", + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "text": "Hello {{name}}", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are {{role}}", + "model": { + "provider": "open-ai", + "client": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.5, + "maxLength": 100, + "responseFormat": "json_object", + }, + }, }, - }, - }}) + } + ) variables = {"name": "World", "role": "helpful"} prompt = client.get_sync("test-slug", variables=variables) @@ -175,24 +187,29 @@ def test_get_sync_cache_disabled(common_client): fallback_cache = common_client["fallback_cache"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "text": "Hello", - "slug": "test-slug", - "version": "1.0.0", - "tag": "prod", - "systemText": "", - "model": { - "provider": "open-ai", - "client": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.5, - "maxLength": 100, - "responseFormat": "json_object", + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "text": "Hello", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "", + "model": { + "provider": "open-ai", + "client": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.5, + "maxLength": 100, + "responseFormat": "json_object", + }, + }, }, - }, - }}) + } + ) # Set cache to have a value cache.get.return_value = common_client["mock_prompt_response"] @@ -239,15 +256,26 @@ def test_describe_sync_success(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "slug": "test-slug", - "status": "live", - "name": "Test Prompt", - "description": "A test prompt", - "availableVersions": ["1.0.0", "1.1.0"], - "availableTags": ["latest", "production", "staging"], - "variables": [{"label": "Name", "type": "string", "description": "This is a description of the variable"}], - }}) + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "slug": "test-slug", + "status": "live", + "name": "Test Prompt", + "description": "A test prompt", + "availableVersions": ["1.0.0", "1.1.0"], + "availableTags": ["latest", "production", "staging"], + "variables": [ + { + "label": "Name", + "type": "string", + "description": "This is a description of the variable", + } + ], + }, + } + ) response = client.describe_sync("test-slug", version="1.0.0") @@ -273,24 +301,28 @@ def test_list_sync_success(common_client): client = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"prompts": [ + mock_fetch.return_value = make_response( { - "slug": "prompt-1", - "status": "live", - "name": "Prompt 1", - "description": "First prompt", - "availableVersions": ["1.0.0"], - "availableTags": ["latest"], - }, - { - "slug": "prompt-2", - "status": "live", - "name": "Prompt 2", - "description": "Second prompt", - "availableVersions": ["2.0.0"], - "availableTags": ["production"], - }, - ]}) + "prompts": [ + { + "slug": "prompt-1", + "status": "live", + "name": "Prompt 1", + "description": "First prompt", + "availableVersions": ["1.0.0"], + "availableTags": ["latest"], + }, + { + "slug": "prompt-2", + "status": "live", + "name": "Prompt 2", + "description": "Second prompt", + "availableVersions": ["2.0.0"], + "availableTags": ["production"], + }, + ] + } + ) prompts = client.list_sync() @@ -335,24 +367,29 @@ def test_get_sync_parameter_combinations(common_client, slug, version, tag): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "text": "Test", - "slug": slug, - "version": version or "latest", - "tag": tag or "default", - "systemText": "", - "model": { - "provider": "open-ai", - "client": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.5, - "maxLength": 100, - "responseFormat": "json_object", + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "text": "Test", + "slug": slug, + "version": version or "latest", + "tag": tag or "default", + "systemText": "", + "model": { + "provider": "open-ai", + "client": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.5, + "maxLength": 100, + "responseFormat": "json_object", + }, + }, }, - }, - }}) + } + ) client.get_sync(slug, version=version, tag=tag) @@ -370,31 +407,37 @@ async def test_get_async_success(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "text": "Hello {{name}}", - "slug": "test-slug", - "version": "1.0.0", - "tag": "prod", - "systemText": "You are a helpful assistant", - "model": { - "provider": "open-ai", - "client": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.5, - "maxLength": 100, - "responseFormat": "json_object", - "topP": 0.5, + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "text": "Hello {{name}}", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are a helpful assistant", + "model": { + "provider": "open-ai", + "client": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.5, + "maxLength": 100, + "responseFormat": "json_object", + "topP": 0.5, + }, + }, }, - }, - }}) + } + ) prompt = await client.get("test-slug", version="1.0.0") mock_fetch.assert_called_once() # Verify prompt object (wrapped in AsyncPromptContextManager) from basalt.prompts.models import AsyncPromptContextManager + assert isinstance(prompt, AsyncPromptContextManager) # Verify attributes are forwarded correctly assert prompt.slug == "test-slug" @@ -434,15 +477,20 @@ async def test_describe_async_success(common_client): client = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "slug": "test-slug", - "status": "live", - "name": "Test Prompt", - "description": "A test prompt", - "availableVersions": ["1.0.0"], - "availableTags": ["latest", "production", "staging"], - "variables": [], - }}) + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "slug": "test-slug", + "status": "live", + "name": "Test Prompt", + "description": "A test prompt", + "availableVersions": ["1.0.0"], + "availableTags": ["latest", "production", "staging"], + "variables": [], + }, + } + ) response = await client.describe("test-slug") @@ -455,16 +503,20 @@ async def test_list_async_success(common_client): client = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({"prompts": [ + mock_fetch.return_value = make_response( { - "slug": "prompt-1", - "status": "live", - "name": "Prompt 1", - "description": "First prompt", - "availableVersions": ["1.0.0"], - "availableTags": ["latest"], - }, - ]}) + "prompts": [ + { + "slug": "prompt-1", + "status": "live", + "name": "Prompt 1", + "description": "First prompt", + "availableVersions": ["1.0.0"], + "availableTags": ["latest"], + }, + ] + } + ) prompts = await client.list() @@ -477,24 +529,29 @@ async def test_get_async_with_variables(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({"warning": "", "prompt": { - "text": "Hello {{name}}", - "slug": "test-slug", - "version": "1.0.0", - "tag": "prod", - "systemText": "You are {{role}}", - "model": { - "provider": "open-ai", - "client": "open-ai", - "model": "gpt-4o", - "version": "latest", - "parameters": { - "temperature": 0.5, - "maxLength": 100, - "responseFormat": "json_object", + mock_fetch.return_value = make_response( + { + "warning": "", + "prompt": { + "text": "Hello {{name}}", + "slug": "test-slug", + "version": "1.0.0", + "tag": "prod", + "systemText": "You are {{role}}", + "model": { + "provider": "open-ai", + "client": "open-ai", + "model": "gpt-4o", + "version": "latest", + "parameters": { + "temperature": 0.5, + "maxLength": 100, + "responseFormat": "json_object", + }, + }, }, - }, - }}) + } + ) variables = {"name": "Alice", "role": "assistant"} prompt = await client.get("test-slug", variables=variables) @@ -542,12 +599,14 @@ def test_publish_prompt_sync_success(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "deploymentTag": { - "id": "tag-123", - "label": "production", + mock_fetch.return_value = make_response( + { + "deploymentTag": { + "id": "tag-123", + "label": "production", + } } - }) + ) response = client.publish_sync( slug="test-slug", @@ -572,12 +631,14 @@ def test_publish_prompt_sync_with_tag(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "deploymentTag": { - "id": "tag-456", - "label": "staging", + mock_fetch.return_value = make_response( + { + "deploymentTag": { + "id": "tag-456", + "label": "staging", + } } - }) + ) response = client.publish_sync( slug="test-slug", @@ -600,12 +661,14 @@ def test_publish_prompt_sync_minimal(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "deploymentTag": { - "id": "tag-789", - "label": "latest", + mock_fetch.return_value = make_response( + { + "deploymentTag": { + "id": "tag-789", + "label": "latest", + } } - }) + ) response = client.publish_sync( slug="test-slug", @@ -638,12 +701,14 @@ async def test_publish_prompt_async_success(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({ - "deploymentTag": { - "id": "tag-async-123", - "label": "production", + mock_fetch.return_value = make_response( + { + "deploymentTag": { + "id": "tag-async-123", + "label": "production", + } } - }) + ) response = await client.publish( slug="test-slug", @@ -669,12 +734,14 @@ async def test_publish_prompt_async_with_both_version_and_tag(common_client): client: PromptsClient = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch") as mock_fetch: - mock_fetch.return_value = make_response({ - "deploymentTag": { - "id": "tag-both", - "label": "release", + mock_fetch.return_value = make_response( + { + "deploymentTag": { + "id": "tag-both", + "label": "release", + } } - }) + ) response = await client.publish( slug="test-slug", @@ -718,12 +785,14 @@ def test_publish_prompt_sync_parameter_combinations(common_client, slug, new_tag client = common_client["client"] with patch("basalt.prompts.client.HTTPClient.fetch_sync") as mock_fetch: - mock_fetch.return_value = make_response({ - "deploymentTag": { - "id": "tag-param-test", - "label": new_tag, + mock_fetch.return_value = make_response( + { + "deploymentTag": { + "id": "tag-param-test", + "label": new_tag, + } } - }) + ) client.publish_sync( slug=slug, diff --git a/tests/prompts/test_context_manager.py b/tests/prompts/test_context_manager.py index b5e53ed..ca6c514 100644 --- a/tests/prompts/test_context_manager.py +++ b/tests/prompts/test_context_manager.py @@ -4,6 +4,7 @@ This module tests the context manager wrappers for prompts that enable observability spans for prompt fetches and GenAI call scoping. """ + from unittest.mock import MagicMock import pytest @@ -204,6 +205,7 @@ def test_wrapper_repr_and_str(mock_prompt): assert "Prompt" in repr(wrapper) assert "Prompt" in str(wrapper) + @pytest.mark.asyncio async def test_async_prompt_context_manager_sets_in_trace_attribute(mock_prompt): """Test that async prompt context manager uses _set_span_attributes which includes basalt.in_trace. diff --git a/tests/test_client_telemetry.py b/tests/test_client_telemetry.py index 71a70af..05af7b4 100644 --- a/tests/test_client_telemetry.py +++ b/tests/test_client_telemetry.py @@ -2,7 +2,6 @@ from __future__ import annotations -import unittest from unittest import mock from basalt.client import Basalt @@ -10,33 +9,40 @@ from basalt.observability.instrumentation import InstrumentationManager -class TestBasaltClientTelemetry(unittest.TestCase): - @mock.patch.object(InstrumentationManager, "initialize") - def test_enable_telemetry_false_disables_config(self, mock_initialize): - client = Basalt(api_key="key", enable_telemetry=False) +def test_enable_telemetry_false_disables_config(monkeypatch): + mock_initialize = mock.Mock() + monkeypatch.setattr(InstrumentationManager, "initialize", mock_initialize) - self.assertTrue(mock_initialize.called) - config_arg = mock_initialize.call_args[0][0] - self.assertFalse(config_arg.enabled) - self.assertEqual(mock_initialize.call_args.kwargs["api_key"], "key") + client = Basalt(api_key="key", enable_telemetry=False) - client.shutdown() + assert mock_initialize.called + config_arg = mock_initialize.call_args[0][0] + assert not config_arg.enabled + assert mock_initialize.call_args.kwargs["api_key"] == "key" - @mock.patch.object(InstrumentationManager, "shutdown") - @mock.patch.object(InstrumentationManager, "initialize") - def test_shutdown_invokes_instrumentation(self, mock_initialize, mock_shutdown): - client = Basalt(api_key="key") + client.shutdown() - client.shutdown() - mock_initialize.assert_called_once() - mock_shutdown.assert_called_once() +def test_shutdown_invokes_instrumentation(monkeypatch): + mock_initialize = mock.Mock() + mock_shutdown = mock.Mock() + monkeypatch.setattr(InstrumentationManager, "initialize", mock_initialize) + monkeypatch.setattr(InstrumentationManager, "shutdown", mock_shutdown) - @mock.patch.object(InstrumentationManager, "initialize") - def test_custom_telemetry_config_passed_through(self, mock_initialize): - telemetry = TelemetryConfig(service_name="custom") + client = Basalt(api_key="key") - Basalt(api_key="key", telemetry_config=telemetry) + client.shutdown() - mock_initialize.assert_called_once_with(telemetry, api_key="key") + mock_initialize.assert_called_once() + mock_shutdown.assert_called_once() + +def test_custom_telemetry_config_passed_through(monkeypatch): + mock_initialize = mock.Mock() + monkeypatch.setattr(InstrumentationManager, "initialize", mock_initialize) + + telemetry = TelemetryConfig(service_name="custom") + + Basalt(api_key="key", telemetry_config=telemetry) + + mock_initialize.assert_called_once_with(telemetry, api_key="key") diff --git a/tests/test_observe_decorators.py b/tests/test_observe_decorators.py index b0a5352..3e6f244 100644 --- a/tests/test_observe_decorators.py +++ b/tests/test_observe_decorators.py @@ -32,6 +32,7 @@ def test_decorator_definitions(): # Test that observe accepts kind parameter import inspect + sig = inspect.signature(observe) assert "kind" in sig.parameters @@ -40,7 +41,6 @@ def test_decorator_definitions(): return False - def test_basic_usage(): """Test basic decorator usage without actual execution.""" try: @@ -79,7 +79,6 @@ def main(): results = [test() for test in tests] - return all(results) diff --git a/uv.lock b/uv.lock index 9d9ade3..dd0ff80 100644 --- a/uv.lock +++ b/uv.lock @@ -71,9 +71,9 @@ wheels = [ [[package]] name = "basalt-sdk" -version = "1.1.0" source = { editable = "." } dependencies = [ + { name = "black" }, { name = "httpx" }, { name = "jinja2" }, { name = "opentelemetry-api" }, @@ -133,9 +133,6 @@ framework-all = [ { name = "opentelemetry-instrumentation-langchain" }, { name = "opentelemetry-instrumentation-llamaindex" }, ] -google-genai = [ - { name = "opentelemetry-instrumentation-google-genai" }, -] google-generativeai = [ { name = "opentelemetry-instrumentation-google-generativeai" }, ] @@ -177,10 +174,7 @@ vertex-ai = [ [package.metadata] requires-dist = [ { name = "anthropic", marker = "extra == 'dev'" }, - { name = "basalt-sdk", extras = ["chromadb", "pinecone", "qdrant"], marker = "extra == 'vector-all'" }, - { name = "basalt-sdk", extras = ["langchain", "llamaindex", "haystack"], marker = "extra == 'framework-all'" }, - { name = "basalt-sdk", extras = ["llm-all", "vector-all", "framework-all"], marker = "extra == 'all'" }, - { name = "basalt-sdk", extras = ["openai", "anthropic", "google-generativeai", "bedrock", "vertex-ai", "mistralai"], marker = "extra == 'llm-all'" }, + { name = "black", specifier = ">=26.1.0" }, { name = "coverage", marker = "extra == 'dev'" }, { name = "google-genai", marker = "extra == 'dev'" }, { name = "httpx", specifier = ">=0.28.1" }, @@ -190,23 +184,44 @@ requires-dist = [ { name = "opentelemetry-api", specifier = "~=1.39.1" }, { name = "opentelemetry-exporter-otlp", specifier = "~=1.39.1" }, { name = "opentelemetry-instrumentation", specifier = "~=0.59b0" }, + { name = "opentelemetry-instrumentation-anthropic", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-anthropic", marker = "extra == 'anthropic'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-anthropic", marker = "extra == 'dev'" }, + { name = "opentelemetry-instrumentation-anthropic", marker = "extra == 'llm-all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-bedrock", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-bedrock", marker = "extra == 'bedrock'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-bedrock", marker = "extra == 'llm-all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-chromadb", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-chromadb", marker = "extra == 'chromadb'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-chromadb", marker = "extra == 'vector-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-google-genai", marker = "extra == 'dev'" }, - { name = "opentelemetry-instrumentation-google-genai", marker = "extra == 'google-genai'", specifier = "~=0.5b0" }, + { name = "opentelemetry-instrumentation-google-generativeai", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-google-generativeai", marker = "extra == 'dev'" }, { name = "opentelemetry-instrumentation-google-generativeai", marker = "extra == 'google-generativeai'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-google-generativeai", marker = "extra == 'llm-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-httpx", specifier = "~=0.59b0" }, + { name = "opentelemetry-instrumentation-langchain", marker = "extra == 'all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-langchain", marker = "extra == 'framework-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-langchain", marker = "extra == 'langchain'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-llamaindex", marker = "extra == 'all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-llamaindex", marker = "extra == 'framework-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-llamaindex", marker = "extra == 'llamaindex'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-mistralai", marker = "extra == 'all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-mistralai", marker = "extra == 'llm-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-mistralai", marker = "extra == 'mistralai'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-openai", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-openai", marker = "extra == 'dev'" }, + { name = "opentelemetry-instrumentation-openai", marker = "extra == 'llm-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-openai", marker = "extra == 'openai'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-pinecone", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-pinecone", marker = "extra == 'pinecone'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-pinecone", marker = "extra == 'vector-all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-qdrant", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-qdrant", marker = "extra == 'qdrant'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-qdrant", marker = "extra == 'vector-all'", specifier = "~=0.51.0" }, + { name = "opentelemetry-instrumentation-vertexai", marker = "extra == 'all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-vertexai", marker = "extra == 'dev'" }, + { name = "opentelemetry-instrumentation-vertexai", marker = "extra == 'llm-all'", specifier = "~=0.51.0" }, { name = "opentelemetry-instrumentation-vertexai", marker = "extra == 'vertex-ai'", specifier = "~=0.51.0" }, { name = "opentelemetry-sdk", specifier = "~=1.39.1" }, { name = "opentelemetry-semantic-conventions", specifier = "~=0.59b0" }, @@ -221,7 +236,51 @@ requires-dist = [ { name = "wheel", marker = "extra == 'dev'" }, { name = "wrapt", specifier = "~=1.17.3" }, ] -provides-extras = ["openai", "anthropic", "google-generativeai", "google-genai", "bedrock", "vertex-ai", "mistralai", "chromadb", "pinecone", "qdrant", "langchain", "llamaindex", "llm-all", "vector-all", "framework-all", "all", "dev"] +provides-extras = ["all", "anthropic", "bedrock", "chromadb", "dev", "framework-all", "google-generativeai", "langchain", "llamaindex", "llm-all", "mistralai", "openai", "pinecone", "qdrant", "vector-all", "vertex-ai"] + +[[package]] +name = "black" +version = "26.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, + { name = "pytokens" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/88/560b11e521c522440af991d46848a2bde64b5f7202ec14e1f46f9509d328/black-26.1.0.tar.gz", hash = "sha256:d294ac3340eef9c9eb5d29288e96dc719ff269a88e27b396340459dd85da4c58", size = 658785, upload-time = "2026-01-18T04:50:11.993Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/1b/523329e713f965ad0ea2b7a047eeb003007792a0353622ac7a8cb2ee6fef/black-26.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ca699710dece84e3ebf6e92ee15f5b8f72870ef984bf944a57a777a48357c168", size = 1849661, upload-time = "2026-01-18T04:59:12.425Z" }, + { url = "https://files.pythonhosted.org/packages/14/82/94c0640f7285fa71c2f32879f23e609dd2aa39ba2641f395487f24a578e7/black-26.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e8e75dabb6eb83d064b0db46392b25cabb6e784ea624219736e8985a6b3675d", size = 1689065, upload-time = "2026-01-18T04:59:13.993Z" }, + { url = "https://files.pythonhosted.org/packages/f0/78/474373cbd798f9291ed8f7107056e343fd39fef42de4a51c7fd0d360840c/black-26.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb07665d9a907a1a645ee41a0df8a25ffac8ad9c26cdb557b7b88eeeeec934e0", size = 1751502, upload-time = "2026-01-18T04:59:15.971Z" }, + { url = "https://files.pythonhosted.org/packages/29/89/59d0e350123f97bc32c27c4d79563432d7f3530dca2bff64d855c178af8b/black-26.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:7ed300200918147c963c87700ccf9966dceaefbbb7277450a8d646fc5646bf24", size = 1400102, upload-time = "2026-01-18T04:59:17.8Z" }, + { url = "https://files.pythonhosted.org/packages/e1/bc/5d866c7ae1c9d67d308f83af5462ca7046760158bbf142502bad8f22b3a1/black-26.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:c5b7713daea9bf943f79f8c3b46f361cc5229e0e604dcef6a8bb6d1c37d9df89", size = 1207038, upload-time = "2026-01-18T04:59:19.543Z" }, + { url = "https://files.pythonhosted.org/packages/30/83/f05f22ff13756e1a8ce7891db517dbc06200796a16326258268f4658a745/black-26.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3cee1487a9e4c640dc7467aaa543d6c0097c391dc8ac74eb313f2fbf9d7a7cb5", size = 1831956, upload-time = "2026-01-18T04:59:21.38Z" }, + { url = "https://files.pythonhosted.org/packages/7d/f2/b2c570550e39bedc157715e43927360312d6dd677eed2cc149a802577491/black-26.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d62d14ca31c92adf561ebb2e5f2741bf8dea28aef6deb400d49cca011d186c68", size = 1672499, upload-time = "2026-01-18T04:59:23.257Z" }, + { url = "https://files.pythonhosted.org/packages/7a/d7/990d6a94dc9e169f61374b1c3d4f4dd3037e93c2cc12b6f3b12bc663aa7b/black-26.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb1dafbbaa3b1ee8b4550a84425aac8874e5f390200f5502cf3aee4a2acb2f14", size = 1735431, upload-time = "2026-01-18T04:59:24.729Z" }, + { url = "https://files.pythonhosted.org/packages/36/1c/cbd7bae7dd3cb315dfe6eeca802bb56662cc92b89af272e014d98c1f2286/black-26.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:101540cb2a77c680f4f80e628ae98bd2bd8812fb9d72ade4f8995c5ff019e82c", size = 1400468, upload-time = "2026-01-18T04:59:27.381Z" }, + { url = "https://files.pythonhosted.org/packages/59/b1/9fe6132bb2d0d1f7094613320b56297a108ae19ecf3041d9678aec381b37/black-26.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:6f3977a16e347f1b115662be07daa93137259c711e526402aa444d7a88fdc9d4", size = 1207332, upload-time = "2026-01-18T04:59:28.711Z" }, + { url = "https://files.pythonhosted.org/packages/f5/13/710298938a61f0f54cdb4d1c0baeb672c01ff0358712eddaf29f76d32a0b/black-26.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6eeca41e70b5f5c84f2f913af857cf2ce17410847e1d54642e658e078da6544f", size = 1878189, upload-time = "2026-01-18T04:59:30.682Z" }, + { url = "https://files.pythonhosted.org/packages/79/a6/5179beaa57e5dbd2ec9f1c64016214057b4265647c62125aa6aeffb05392/black-26.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd39eef053e58e60204f2cdf059e2442e2eb08f15989eefe259870f89614c8b6", size = 1700178, upload-time = "2026-01-18T04:59:32.387Z" }, + { url = "https://files.pythonhosted.org/packages/8c/04/c96f79d7b93e8f09d9298b333ca0d31cd9b2ee6c46c274fd0f531de9dc61/black-26.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9459ad0d6cd483eacad4c6566b0f8e42af5e8b583cee917d90ffaa3778420a0a", size = 1777029, upload-time = "2026-01-18T04:59:33.767Z" }, + { url = "https://files.pythonhosted.org/packages/49/f9/71c161c4c7aa18bdda3776b66ac2dc07aed62053c7c0ff8bbda8c2624fe2/black-26.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a19915ec61f3a8746e8b10adbac4a577c6ba9851fa4a9e9fbfbcf319887a5791", size = 1406466, upload-time = "2026-01-18T04:59:35.177Z" }, + { url = "https://files.pythonhosted.org/packages/4a/8b/a7b0f974e473b159d0ac1b6bcefffeb6bec465898a516ee5cc989503cbc7/black-26.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:643d27fb5facc167c0b1b59d0315f2674a6e950341aed0fc05cf307d22bf4954", size = 1216393, upload-time = "2026-01-18T04:59:37.18Z" }, + { url = "https://files.pythonhosted.org/packages/79/04/fa2f4784f7237279332aa735cdfd5ae2e7730db0072fb2041dadda9ae551/black-26.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ba1d768fbfb6930fc93b0ecc32a43d8861ded16f47a40f14afa9bb04ab93d304", size = 1877781, upload-time = "2026-01-18T04:59:39.054Z" }, + { url = "https://files.pythonhosted.org/packages/cf/ad/5a131b01acc0e5336740a039628c0ab69d60cf09a2c87a4ec49f5826acda/black-26.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b807c240b64609cb0e80d2200a35b23c7df82259f80bef1b2c96eb422b4aac9", size = 1699670, upload-time = "2026-01-18T04:59:41.005Z" }, + { url = "https://files.pythonhosted.org/packages/da/7c/b05f22964316a52ab6b4265bcd52c0ad2c30d7ca6bd3d0637e438fc32d6e/black-26.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1de0f7d01cc894066a1153b738145b194414cc6eeaad8ef4397ac9abacf40f6b", size = 1775212, upload-time = "2026-01-18T04:59:42.545Z" }, + { url = "https://files.pythonhosted.org/packages/a6/a3/e8d1526bea0446e040193185353920a9506eab60a7d8beb062029129c7d2/black-26.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:91a68ae46bf07868963671e4d05611b179c2313301bd756a89ad4e3b3db2325b", size = 1409953, upload-time = "2026-01-18T04:59:44.357Z" }, + { url = "https://files.pythonhosted.org/packages/c7/5a/d62ebf4d8f5e3a1daa54adaab94c107b57be1b1a2f115a0249b41931e188/black-26.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:be5e2fe860b9bd9edbf676d5b60a9282994c03fbbd40fe8f5e75d194f96064ca", size = 1217707, upload-time = "2026-01-18T04:59:45.719Z" }, + { url = "https://files.pythonhosted.org/packages/6a/83/be35a175aacfce4b05584ac415fd317dd6c24e93a0af2dcedce0f686f5d8/black-26.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9dc8c71656a79ca49b8d3e2ce8103210c9481c57798b48deeb3a8bb02db5f115", size = 1871864, upload-time = "2026-01-18T04:59:47.586Z" }, + { url = "https://files.pythonhosted.org/packages/a5/f5/d33696c099450b1274d925a42b7a030cd3ea1f56d72e5ca8bbed5f52759c/black-26.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:b22b3810451abe359a964cc88121d57f7bce482b53a066de0f1584988ca36e79", size = 1701009, upload-time = "2026-01-18T04:59:49.443Z" }, + { url = "https://files.pythonhosted.org/packages/1b/87/670dd888c537acb53a863bc15abbd85b22b429237d9de1b77c0ed6b79c42/black-26.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:53c62883b3f999f14e5d30b5a79bd437236658ad45b2f853906c7cbe79de00af", size = 1767806, upload-time = "2026-01-18T04:59:50.769Z" }, + { url = "https://files.pythonhosted.org/packages/fe/9c/cd3deb79bfec5bcf30f9d2100ffeec63eecce826eb63e3961708b9431ff1/black-26.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:f016baaadc423dc960cdddf9acae679e71ee02c4c341f78f3179d7e4819c095f", size = 1433217, upload-time = "2026-01-18T04:59:52.218Z" }, + { url = "https://files.pythonhosted.org/packages/4e/29/f3be41a1cf502a283506f40f5d27203249d181f7a1a2abce1c6ce188035a/black-26.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:66912475200b67ef5a0ab665011964bf924745103f51977a78b4fb92a9fc1bf0", size = 1245773, upload-time = "2026-01-18T04:59:54.457Z" }, + { url = "https://files.pythonhosted.org/packages/e4/3d/51bdb3ecbfadfaf825ec0c75e1de6077422b4afa2091c6c9ba34fbfc0c2d/black-26.1.0-py3-none-any.whl", hash = "sha256:1054e8e47ebd686e078c0bb0eaf31e6ce69c966058d122f2c0c950311f9f3ede", size = 204010, upload-time = "2026-01-18T04:50:09.978Z" }, +] [[package]] name = "cachetools" @@ -1163,6 +1222,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/8e/469e5a4a2f5855992e425f3cb33804cc07bf18d48f2db061aec61ce50270/more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b", size = 69667, upload-time = "2025-09-02T15:23:09.635Z" }, ] +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "nh3" version = "0.3.1" @@ -1592,6 +1660,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/2f/804f58f0b856ab3bf21617cccf5b39206e6c4c94c2cd227bde125ea6105f/parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b", size = 20475, upload-time = "2023-03-27T02:01:09.31Z" }, ] +[[package]] +name = "pathspec" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/b2/bb8e495d5262bfec41ab5cb18f522f1012933347fb5d9e62452d446baca2/pathspec-1.0.3.tar.gz", hash = "sha256:bac5cf97ae2c2876e2d25ebb15078eb04d76e4b98921ee31c6f85ade8b59444d", size = 130841, upload-time = "2026-01-09T15:46:46.009Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/2b/121e912bd60eebd623f873fd090de0e84f322972ab25a7f9044c056804ed/pathspec-1.0.3-py3-none-any.whl", hash = "sha256:e80767021c1cc524aa3fb14bedda9c34406591343cc42797b386ce7b9354fb6c", size = 55021, upload-time = "2026-01-09T15:46:44.652Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/86/0248f086a84f01b37aaec0fa567b397df1a119f73c16f6c7a9aac73ea309/platformdirs-4.5.1.tar.gz", hash = "sha256:61d5cdcc6065745cdd94f0f878977f8de9437be93de97c1c12f853c9c0cdcbda", size = 21715, upload-time = "2025-12-05T13:52:58.638Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -1842,6 +1928,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "pytokens" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/16/4b9cfd90d55e66ffdb277d7ebe3bc25250c2311336ec3fc73b2673c794d5/pytokens-0.4.0.tar.gz", hash = "sha256:6b0b03e6ea7c9f9d47c5c61164b69ad30f4f0d70a5d9fe7eac4d19f24f77af2d", size = 15039, upload-time = "2026-01-19T07:59:50.623Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/c5/c20818fef16c4ab5f9fd7bad699268ba21bf24f655711df4e33bb7a9ab47/pytokens-0.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:af0c3166aea367a9e755a283171befb92dd3043858b94ae9b3b7efbe9def26a3", size = 160682, upload-time = "2026-01-19T07:58:51.583Z" }, + { url = "https://files.pythonhosted.org/packages/46/c4/ad03e4abe05c6af57c4d7f8f031fafe80f0074796d09ab5a73bf2fac895f/pytokens-0.4.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:daae524ed14ca459932cbf51d74325bea643701ba8a8b0cc2d10f7cd4b3e2b63", size = 245748, upload-time = "2026-01-19T07:58:53.944Z" }, + { url = "https://files.pythonhosted.org/packages/6b/b9/4a7ee0a692603b16d8fdfbc5c44e0f6910d45eec6b2c2188daa4670f179d/pytokens-0.4.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e95cb158c44d642ed62f555bf8136bbe780dbd64d2fb0b9169e11ffb944664c3", size = 258671, upload-time = "2026-01-19T07:58:55.667Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a3/02bb29dc4985fb8d759d9c96f189c3a828e74f0879fdb843e9fb7a1db637/pytokens-0.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:df58d44630eaf25f587540e94bdf1fc50b4e6d5f212c786de0fb024bfcb8753a", size = 261749, upload-time = "2026-01-19T07:58:57.442Z" }, + { url = "https://files.pythonhosted.org/packages/10/f2/9a8bdcc5444d85d4dba4aa1b530d81af3edc4a9ab76bf1d53ea8bfe8479d/pytokens-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55efcc36f9a2e0e930cfba0ce7f83445306b02f8326745585ed5551864eba73a", size = 102805, upload-time = "2026-01-19T07:58:59.068Z" }, + { url = "https://files.pythonhosted.org/packages/b4/05/3196399a353dd4cd99138a88f662810979ee2f1a1cdb0b417cb2f4507836/pytokens-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:92eb3ef88f27c22dc9dbab966ace4d61f6826e02ba04dac8e2d65ea31df56c8e", size = 160075, upload-time = "2026-01-19T07:59:00.316Z" }, + { url = "https://files.pythonhosted.org/packages/28/1d/c8fc4ed0a1c4f660391b201cda00b1d5bbcc00e2998e8bcd48b15eefd708/pytokens-0.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f4b77858a680635ee9904306f54b0ee4781effb89e211ba0a773d76539537165", size = 247318, upload-time = "2026-01-19T07:59:01.636Z" }, + { url = "https://files.pythonhosted.org/packages/8e/0e/53e55ba01f3e858d229cd84b02481542f42ba59050483a78bf2447ee1af7/pytokens-0.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25cacc20c2ad90acb56f3739d87905473c54ca1fa5967ffcd675463fe965865e", size = 259752, upload-time = "2026-01-19T07:59:04.229Z" }, + { url = "https://files.pythonhosted.org/packages/dc/56/2d930d7f899e3f21868ca6e8ec739ac31e8fc532f66e09cbe45d3df0a84f/pytokens-0.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:628fab535ebc9079e4db35cd63cb401901c7ce8720a9834f9ad44b9eb4e0f1d4", size = 262842, upload-time = "2026-01-19T07:59:06.14Z" }, + { url = "https://files.pythonhosted.org/packages/42/dd/4e7e6920d23deffaf66e6f40d45f7610dcbc132ca5d90ab4faccef22f624/pytokens-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:4d0f568d7e82b7e96be56d03b5081de40e43c904eb6492bf09aaca47cd55f35b", size = 102620, upload-time = "2026-01-19T07:59:07.839Z" }, + { url = "https://files.pythonhosted.org/packages/3d/65/65460ebbfefd0bc1b160457904370d44f269e6e4582e0a9b6cba7c267b04/pytokens-0.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cd8da894e5a29ba6b6da8be06a4f7589d7220c099b5e363cb0643234b9b38c2a", size = 159864, upload-time = "2026-01-19T07:59:08.908Z" }, + { url = "https://files.pythonhosted.org/packages/25/70/a46669ec55876c392036b4da9808b5c3b1c5870bbca3d4cc923bf68bdbc1/pytokens-0.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:237ba7cfb677dbd3b01b09860810aceb448871150566b93cd24501d5734a04b1", size = 254448, upload-time = "2026-01-19T07:59:10.594Z" }, + { url = "https://files.pythonhosted.org/packages/62/0b/c486fc61299c2fc3b7f88ee4e115d4c8b6ffd1a7f88dc94b398b5b1bc4b8/pytokens-0.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01d1a61e36812e4e971cfe2c0e4c1f2d66d8311031dac8bf168af8a249fa04dd", size = 268863, upload-time = "2026-01-19T07:59:12.31Z" }, + { url = "https://files.pythonhosted.org/packages/79/92/b036af846707d25feaff7cafbd5280f1bd6a1034c16bb06a7c910209c1ab/pytokens-0.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e47e2ef3ec6ee86909e520d79f965f9b23389fda47460303cf715d510a6fe544", size = 267181, upload-time = "2026-01-19T07:59:13.856Z" }, + { url = "https://files.pythonhosted.org/packages/0d/c0/6d011fc00fefa74ce34816c84a923d2dd7c46b8dbc6ee52d13419786834c/pytokens-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3d36954aba4557fd5a418a03cf595ecbb1cdcce119f91a49b19ef09d691a22ae", size = 102814, upload-time = "2026-01-19T07:59:15.288Z" }, + { url = "https://files.pythonhosted.org/packages/98/63/627b7e71d557383da5a97f473ad50f8d9c2c1f55c7d3c2531a120c796f6e/pytokens-0.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73eff3bdd8ad08da679867992782568db0529b887bed4c85694f84cdf35eafc6", size = 159744, upload-time = "2026-01-19T07:59:16.88Z" }, + { url = "https://files.pythonhosted.org/packages/28/d7/16f434c37ec3824eba6bcb6e798e5381a8dc83af7a1eda0f95c16fe3ade5/pytokens-0.4.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d97cc1f91b1a8e8ebccf31c367f28225699bea26592df27141deade771ed0afb", size = 253207, upload-time = "2026-01-19T07:59:18.069Z" }, + { url = "https://files.pythonhosted.org/packages/ab/96/04102856b9527701ae57d74a6393d1aca5bad18a1b1ca48ccffb3c93b392/pytokens-0.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a2c8952c537cb73a1a74369501a83b7f9d208c3cf92c41dd88a17814e68d48ce", size = 267452, upload-time = "2026-01-19T07:59:19.328Z" }, + { url = "https://files.pythonhosted.org/packages/0e/ef/0936eb472b89ab2d2c2c24bb81c50417e803fa89c731930d9fb01176fe9f/pytokens-0.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5dbf56f3c748aed9310b310d5b8b14e2c96d3ad682ad5a943f381bdbbdddf753", size = 265965, upload-time = "2026-01-19T07:59:20.613Z" }, + { url = "https://files.pythonhosted.org/packages/ae/f5/64f3d6f7df4a9e92ebda35ee85061f6260e16eac82df9396020eebbca775/pytokens-0.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:e131804513597f2dff2b18f9911d9b6276e21ef3699abeffc1c087c65a3d975e", size = 102813, upload-time = "2026-01-19T07:59:22.012Z" }, + { url = "https://files.pythonhosted.org/packages/5f/f1/d07e6209f18ef378fc2ae9dee8d1dfe91fd2447c2e2dbfa32867b6dd30cf/pytokens-0.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0d7374c917197106d3c4761374718bc55ea2e9ac0fb94171588ef5840ee1f016", size = 159968, upload-time = "2026-01-19T07:59:23.07Z" }, + { url = "https://files.pythonhosted.org/packages/0a/73/0eb111400abd382a04f253b269819db9fcc748aa40748441cebdcb6d068f/pytokens-0.4.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cd3fa1caf9e47a72ee134a29ca6b5bea84712724bba165d6628baa190c6ea5b", size = 253373, upload-time = "2026-01-19T07:59:24.381Z" }, + { url = "https://files.pythonhosted.org/packages/bd/8d/9e4e2fdb5bcaba679e54afcc304e9f13f488eb4d626e6b613f9553e03dbd/pytokens-0.4.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c6986576b7b07fe9791854caa5347923005a80b079d45b63b0be70d50cce5f1", size = 267024, upload-time = "2026-01-19T07:59:25.74Z" }, + { url = "https://files.pythonhosted.org/packages/cb/b7/e0a370321af2deb772cff14ff337e1140d1eac2c29a8876bfee995f486f0/pytokens-0.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9940f7c2e2f54fb1cb5fe17d0803c54da7a2bf62222704eb4217433664a186a7", size = 270912, upload-time = "2026-01-19T07:59:27.072Z" }, + { url = "https://files.pythonhosted.org/packages/7c/54/4348f916c440d4c3e68b53b4ed0e66b292d119e799fa07afa159566dcc86/pytokens-0.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:54691cf8f299e7efabcc25adb4ce715d3cef1491e1c930eaf555182f898ef66a", size = 103836, upload-time = "2026-01-19T07:59:28.112Z" }, + { url = "https://files.pythonhosted.org/packages/e8/f8/a693c0cfa9c783a2a8c4500b7b2a8bab420f8ca4f2d496153226bf1c12e3/pytokens-0.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:94ff5db97a0d3cd7248a5b07ba2167bd3edc1db92f76c6db00137bbaf068ddf8", size = 167643, upload-time = "2026-01-19T07:59:29.292Z" }, + { url = "https://files.pythonhosted.org/packages/c0/dd/a64eb1e9f3ec277b69b33ef1b40ffbcc8f0a3bafcde120997efc7bdefebf/pytokens-0.4.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d0dd6261cd9cc95fae1227b1b6ebee023a5fd4a4b6330b071c73a516f5f59b63", size = 289553, upload-time = "2026-01-19T07:59:30.537Z" }, + { url = "https://files.pythonhosted.org/packages/df/22/06c1079d93dbc3bca5d013e1795f3d8b9ed6c87290acd6913c1c526a6bb2/pytokens-0.4.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdca8159df407dbd669145af4171a0d967006e0be25f3b520896bc7068f02c4", size = 302490, upload-time = "2026-01-19T07:59:32.352Z" }, + { url = "https://files.pythonhosted.org/packages/8d/de/a6f5e43115b4fbf4b93aa87d6c83c79932cdb084f9711daae04549e1e4ad/pytokens-0.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4b5770abeb2a24347380a1164a558f0ebe06e98aedbd54c45f7929527a5fb26e", size = 305652, upload-time = "2026-01-19T07:59:33.685Z" }, + { url = "https://files.pythonhosted.org/packages/ab/3d/c136e057cb622e36e0c3ff7a8aaa19ff9720050c4078235691da885fe6ee/pytokens-0.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:74500d72c561dad14c037a9e86a657afd63e277dd5a3bb7570932ab7a3b12551", size = 115472, upload-time = "2026-01-19T07:59:34.734Z" }, + { url = "https://files.pythonhosted.org/packages/7c/3c/6941a82f4f130af6e1c68c076b6789069ef10c04559bd4733650f902fd3b/pytokens-0.4.0-py3-none-any.whl", hash = "sha256:0508d11b4de157ee12063901603be87fb0253e8f4cb9305eb168b1202ab92068", size = 13224, upload-time = "2026-01-19T07:59:49.822Z" }, +] + [[package]] name = "pywin32-ctypes" version = "0.2.3"