diff --git a/basalt/observability/__init__.py b/basalt/observability/__init__.py index 164130d..03e77d0 100644 --- a/basalt/observability/__init__.py +++ b/basalt/observability/__init__.py @@ -19,6 +19,7 @@ EventSpanHandle, FunctionSpanHandle, LLMSpanHandle, + NoOpSpanHandle, RetrievalSpanHandle, SpanHandle, StartSpanHandle, @@ -70,6 +71,7 @@ "ToolSpanHandle", "FunctionSpanHandle", "EventSpanHandle", + "NoOpSpanHandle", ] _instrumentation = InstrumentationManager() diff --git a/basalt/observability/api.py b/basalt/observability/api.py index 1f5be1e..165105e 100644 --- a/basalt/observability/api.py +++ b/basalt/observability/api.py @@ -20,6 +20,7 @@ EventSpanHandle, FunctionSpanHandle, LLMSpanHandle, + NoOpSpanHandle, RetrievalSpanHandle, SpanHandle, StartSpanHandle, @@ -117,9 +118,7 @@ def _resolve_kind_str(kind: ObserveKind | str) -> str: kind_str = str(kind).lower() valid_kinds = {k.value for k in ObserveKind} if kind_str not in valid_kinds: - raise ValueError( - f"Invalid kind '{kind_str}'. Must be one of: {', '.join(sorted(valid_kinds))}" - ) + raise ValueError(f"Invalid kind '{kind_str}'. Must be one of: {', '.join(sorted(valid_kinds))}") return kind_str @@ -255,9 +254,7 @@ async def async_wrapper( **kwargs: object, ) -> object: bound = resolve_bound_arguments(func, args, kwargs) - user_identity, org_identity = resolve_identity_payload( - self.identity_resolver, bound - ) + user_identity, org_identity = resolve_identity_payload(self.identity_resolver, bound) pre_evaluators = resolve_evaluators_payload(self.evaluators, bound) span_name = self.name @@ -337,8 +334,7 @@ def __init__( # Validate name is provided and non-empty if not name or not isinstance(name, str) or not name.strip(): raise ValueError( - "name is required and must be a non-empty string. " - "Please provide a descriptive name for this span (e.g., 'llm_generation', 'vector_search')." + "name is required and must be a non-empty string. Please provide a descriptive name for this span (e.g., 'llm_generation', 'vector_search')." ) self.name = name.strip() @@ -349,24 +345,21 @@ def __init__( self.output_resolver = output self.variables_resolver = variables self.prompt = prompt - self._span_handle: SpanHandle | None = None + self._span_handle: SpanHandle | NoOpSpanHandle | None = None self._ctx_manager = None @staticmethod def _get_config_for_kind(kind_str: str): return _get_observe_config_for_kind(kind_str) - def __enter__(self) -> SpanHandle: + def __enter__(self) -> SpanHandle | NoOpSpanHandle: span_name = self.name kind_str = _resolve_kind_str(self.kind) # Reject ROOT kind if kind_str == ObserveKind.ROOT.value: - raise ValueError( - f"Cannot use kind='{ObserveKind.ROOT.value}' with Observe. " - f"Use StartObserve (start_observe) for root spans." - ) + raise ValueError(f"Cannot use kind='{ObserveKind.ROOT.value}' with Observe. Use StartObserve (start_observe) for root spans.") handle_cls, tracer_name, _, _ = self._get_config_for_kind(kind_str) @@ -387,16 +380,23 @@ def __enter__(self) -> SpanHandle: if self.prompt.variables: prompt_attrs["basalt.prompt.variables"] = json.dumps(self.prompt.variables) - # Check for root span + # Check for root span - if no root span, return a no-op handle + # This ensures observe() is ignored unless start_observe has been called first from opentelemetry import context as otel_context + from .context_managers import NoOpSpanHandle if not otel_context.get_value(ROOT_SPAN_CONTEXT_KEY): - import logging + # No root span exists - return a no-op handle + self._ctx_manager = NoOpSpanHandle() + self._span_handle = self._ctx_manager.__enter__() + return self._span_handle - logger = logging.getLogger(__name__) - logger.warning( - "Observe used without a preceding start_observe. This may lead to missing trace context." - ) + # Inherit feature_slug from parent context if available + from .trace_context import FEATURE_SLUG_CONTEXT_KEY + + current_feature_slug = otel_context.get_value(FEATURE_SLUG_CONTEXT_KEY) + if not isinstance(current_feature_slug, str): + current_feature_slug = None self._ctx_manager = _with_span_handle( name=span_name, @@ -406,6 +406,7 @@ def __enter__(self) -> SpanHandle: span_type=kind_str, evaluators=self.evaluators, metadata=self._metadata, + feature_slug=current_feature_slug, # In context manager mode, we don't auto-resolve input/vars from args # User must call observe.input() or pass explicit input_payload if we added it to __init__ # But __init__ has resolvers, not values. @@ -420,6 +421,10 @@ def __exit__(self, exc_type, exc_value, traceback) -> bool | None: return None def __call__(self, func: F) -> F: + # Import context module once for all nested functions + from opentelemetry import context as otel_context + from .trace_context import FEATURE_SLUG_CONTEXT_KEY + if isinstance(self.kind, ObserveKind): kind_str = self.kind.value else: @@ -427,18 +432,13 @@ def __call__(self, func: F) -> F: # Reject ROOT kind if kind_str == ObserveKind.ROOT.value: - raise ValueError( - f"Cannot use kind='{ObserveKind.ROOT.value}' with Observe. " - f"Use StartObserve (start_observe) for root spans." - ) + raise ValueError(f"Cannot use kind='{ObserveKind.ROOT.value}' with Observe. Use StartObserve (start_observe) for root spans.") handle_cls, tracer_name, default_input, default_vars = self._get_config_for_kind(kind_str) # Use defaults if not provided input_resolver = self.input_resolver if self.input_resolver is not None else default_input - variables_resolver = ( - self.variables_resolver if self.variables_resolver is not None else default_vars - ) + variables_resolver = self.variables_resolver if self.variables_resolver is not None else default_vars # Process prompt parameter if provided prompt_metadata = {} @@ -476,25 +476,12 @@ def prepare_call_data( list[Any] | None, ]: computed_metadata_raw = resolve_attributes(self._metadata, args, kwargs) - computed_metadata = ( - computed_metadata_raw if isinstance(computed_metadata_raw, Mapping) else None - ) + computed_metadata = computed_metadata_raw if isinstance(computed_metadata_raw, Mapping) else None bound = resolve_bound_arguments(func, args, kwargs) input_payload = resolve_payload_from_bound(input_resolver, bound) variables_payload = resolve_variables_payload(variables_resolver, bound) pre_evaluators = resolve_evaluators_payload(self.evaluators, bound) - # Check for root span - from opentelemetry import context as otel_context - - if not otel_context.get_value(ROOT_SPAN_CONTEXT_KEY): - import logging - - logger = logging.getLogger(__name__) - logger.warning( - "Observe used without a preceding start_observe. This may lead to missing trace context." - ) - return computed_metadata, bound, input_payload, variables_payload, pre_evaluators # Pre-hooks @@ -513,9 +500,17 @@ def apply_post(span, result): @functools.wraps(func) def wrapper(*args, **kwargs): - computed_metadata, bound, input_payload, variables_payload, pre_evaluators = ( - prepare_call_data(args, kwargs) - ) + # Early exit: if no root span exists, just call the function (no-op) + if not otel_context.get_value(ROOT_SPAN_CONTEXT_KEY): + return func(*args, **kwargs) + + computed_metadata, bound, input_payload, variables_payload, pre_evaluators = prepare_call_data(args, kwargs) + # Inherit feature_slug from parent context if available + from .trace_context import FEATURE_SLUG_CONTEXT_KEY + + current_feature_slug = otel_context.get_value(FEATURE_SLUG_CONTEXT_KEY) + if not isinstance(current_feature_slug, str): + current_feature_slug = None with _with_span_handle( name=self.name, @@ -527,6 +522,7 @@ def wrapper(*args, **kwargs): variables=variables_payload, evaluators=pre_evaluators, metadata=computed_metadata, + feature_slug=current_feature_slug, ) as span: if apply_pre: apply_pre(span, bound) @@ -556,9 +552,18 @@ async def async_wrapper( *args: object, **kwargs: object, ) -> object: - computed_metadata, bound, input_payload, variables_payload, pre_evaluators = ( - prepare_call_data(args, kwargs) - ) + # Early exit: if no root span exists, just call the function (no-op) + if not otel_context.get_value(ROOT_SPAN_CONTEXT_KEY): + return await func(*args, **kwargs) + + computed_metadata, bound, input_payload, variables_payload, pre_evaluators = prepare_call_data(args, kwargs) + + # Inherit feature_slug from parent context if available + from .trace_context import FEATURE_SLUG_CONTEXT_KEY + + current_feature_slug = otel_context.get_value(FEATURE_SLUG_CONTEXT_KEY) + if not isinstance(current_feature_slug, str): + current_feature_slug = None with _with_span_handle( name=self.name, @@ -570,6 +575,7 @@ async def async_wrapper( variables=variables_payload, evaluators=pre_evaluators, metadata=computed_metadata, + feature_slug=current_feature_slug, ) as span: if apply_pre: apply_pre(span, bound) @@ -598,9 +604,7 @@ async def async_wrapper( # Static Domain Methods @staticmethod - def _identify( - user: str | dict[str, Any] | None = None, organization: str | dict[str, Any] | None = None - ) -> None: + def _identify(user: str | dict[str, Any] | None = None, organization: str | dict[str, Any] | None = None) -> None: """Set the user and/or organization identity for the current context.""" if user: if isinstance(user, str): @@ -612,9 +616,7 @@ def _identify( if isinstance(organization, str): _set_trace_organization(organization_id=organization) elif isinstance(organization, dict): - _set_trace_organization( - organization_id=organization.get("id", "unknown"), name=organization.get("name") - ) + _set_trace_organization(organization_id=organization.get("id", "unknown"), name=organization.get("name")) @staticmethod def _root_span() -> StartSpanHandle | None: @@ -870,10 +872,7 @@ def _set_evaluation_config(config: EvaluationConfig | dict[str, Any]) -> None: import logging logger = logging.getLogger(__name__) - logger.warning( - "_set_evaluation_config() can only be called on root spans (StartSpanHandle). " - "This call will be ignored." - ) + logger.warning("_set_evaluation_config() can only be called on root spans (StartSpanHandle). This call will be ignored.") return handle.set_evaluation_config(config) @@ -1038,8 +1037,7 @@ def __init__( # Validate name is provided and non-empty if not name or not isinstance(name, str) or not name.strip(): raise ValueError( - "name is required and must be a non-empty string. " - "Please provide a descriptive name for this span (e.g., 'async_workflow', 'async_operation')." + "name is required and must be a non-empty string. Please provide a descriptive name for this span (e.g., 'async_workflow', 'async_operation')." ) self.name = name.strip() @@ -1129,8 +1127,7 @@ def __init__( # Validate name is provided and non-empty if not name or not isinstance(name, str) or not name.strip(): raise ValueError( - "name is required and must be a non-empty string. " - "Please provide a descriptive name for this span (e.g., 'async_operation', 'async_fetch')." + "name is required and must be a non-empty string. Please provide a descriptive name for this span (e.g., 'async_operation', 'async_fetch')." ) self.name = name.strip() @@ -1155,10 +1152,7 @@ async def __aenter__(self) -> SpanHandle: # Reject ROOT kind if kind_str == ObserveKind.ROOT.value: - raise ValueError( - f"Cannot use kind='{ObserveKind.ROOT.value}' with AsyncObserve. " - f"Use AsyncStartObserve (async_start_observe) for root spans." - ) + raise ValueError(f"Cannot use kind='{ObserveKind.ROOT.value}' with AsyncObserve. Use AsyncStartObserve (async_start_observe) for root spans.") handle_cls, tracer_name, _, _ = self._get_config_for_kind(kind_str) @@ -1191,9 +1185,7 @@ async def __aenter__(self) -> SpanHandle: import logging logger = logging.getLogger(__name__) - logger.warning( - "AsyncObserve used without a preceding async_start_observe. This may lead to missing trace context." - ) + logger.warning("AsyncObserve used without a preceding async_start_observe. This may lead to missing trace context.") self._ctx_manager = _async_with_span_handle( name=span_name, diff --git a/basalt/observability/context_managers.py b/basalt/observability/context_managers.py index e063ee6..bfa9ba6 100644 --- a/basalt/observability/context_managers.py +++ b/basalt/observability/context_managers.py @@ -43,6 +43,80 @@ logger = logging.getLogger(__name__) +class NoOpSpanHandle: + """No-op span handle that does nothing. + + Used when observe() is called without a preceding start_observe. + This prevents creating Basalt spans when not in a trace context. + """ + + def set_attribute(self, key: str, value: str | int | float | bool | None) -> None: + """No-op: does nothing.""" + pass + + def set_attributes(self, attributes: dict[str, Any]) -> None: + """No-op: does nothing.""" + pass + + def set_metadata(self, metadata: Mapping[str, Any] | None) -> None: + """No-op: does nothing.""" + pass + + def set_prompt(self, prompt: Any) -> None: + """No-op: does nothing.""" + pass + + def set_input(self, payload: JSONValue) -> None: + """No-op: does nothing.""" + pass + + def set_output(self, payload: JSONValue) -> None: + """No-op: does nothing.""" + pass + + def set_io( + self, + *, + input_payload: JSONValue | None = None, + output_payload: JSONValue | None = None, + variables: Mapping[str, Any] | None = None, + ) -> None: + """No-op: does nothing.""" + pass + + def add_evaluator(self, evaluator_slug: str) -> None: + """No-op: does nothing.""" + pass + + def add_evaluators(self, *evaluators: Sequence[str]) -> None: + """No-op: does nothing.""" + pass + + def set_identity(self, identity: Mapping[str, Any] | None = None) -> None: + """No-op: does nothing.""" + pass + + def set_model(self, model: str) -> None: + """No-op: does nothing.""" + pass + + def set_response_model(self, model: str) -> None: + """No-op: does nothing.""" + pass + + def set_operation_name(self, operation: str) -> None: + """No-op: does nothing.""" + pass + + def __enter__(self): + """No-op: does nothing.""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + """No-op: does nothing.""" + pass + + @dataclass(slots=True, frozen=True) class EvaluationConfig: """ @@ -232,9 +306,7 @@ def __init__( ) -> None: self._span = span self._io_payload: dict[str, Any] = {"input": None, "output": None, "variables": None} - self._parent_span = ( - parent_span if parent_span and parent_span.get_span_context().is_valid else None - ) + self._parent_span = parent_span if parent_span and parent_span.get_span_context().is_valid else None self._evaluators: dict[str, EvaluatorAttachment] = {} self._evaluator_config: EvaluationConfig | None = None self._hydrate_existing_evaluators() @@ -350,9 +422,7 @@ def set_io( if trace_content_enabled(): _set_serialized_attribute(self._span, semconv.BasaltSpan.VARIABLES, dict(variables)) if self._parent_span: - _set_serialized_attribute( - self._parent_span, semconv.BasaltSpan.VARIABLES, dict(variables) - ) + _set_serialized_attribute(self._parent_span, semconv.BasaltSpan.VARIABLES, dict(variables)) def _io_snapshot(self) -> dict[str, Any]: """Return a shallow copy of the tracked IO payload.""" @@ -510,9 +580,7 @@ def set_evaluation_config(self, config: EvaluationConfig | Mapping[str, Any]) -> """ if isinstance(config, EvaluationConfig): self._evaluator_config = config - _set_serialized_attribute( - self._span, semconv.BasaltSpan.EVALUATION_CONFIG, config.to_dict() - ) + _set_serialized_attribute(self._span, semconv.BasaltSpan.EVALUATION_CONFIG, config.to_dict()) elif isinstance(config, Mapping): config_dict = dict(config) self._evaluator_config = EvaluationConfig(**config_dict) @@ -703,9 +771,7 @@ def _with_span_handle( defaults = _current_trace_defaults() parent_span = trace.get_current_span() - if parent_span and ( - not parent_span.get_span_context().is_valid or not parent_span.is_recording() - ): + if parent_span and (not parent_span.get_span_context().is_valid or not parent_span.is_recording()): parent_span = None # Prepare context tokens for user/org propagation @@ -805,9 +871,7 @@ def _with_span_handle( root_span_token = attach(set_value(ROOT_SPAN_CONTEXT_KEY, span)) # Also store the parent span (if any) for trace-level helpers if parent_span is not None: - root_parent_span_token = attach( - set_value(ROOT_PARENT_SPAN_CONTEXT_KEY, parent_span) - ) + root_parent_span_token = attach(set_value(ROOT_PARENT_SPAN_CONTEXT_KEY, parent_span)) # Set basalt.root attribute span.set_attribute("basalt.root", True) elif in_basalt_trace: diff --git a/examples/microservices/service_a/main.py b/examples/microservices/service_a/main.py index 81641dc..996d678 100644 --- a/examples/microservices/service_a/main.py +++ b/examples/microservices/service_a/main.py @@ -68,7 +68,7 @@ def build_basalt_client() -> Basalt: service_name="service-a-orchestrator", enabled_providers=["openai"], # Auto-instrument OpenAI SDK calls trace_content=True, # Capture prompt and completion content - # exporter=[exporter], # Use custom local exporter + # exporter=[exporter], # Use custom local exporter ) return Basalt(api_key=basalt_key, telemetry_config=telemetry) @@ -100,15 +100,15 @@ async def lifespan(app: FastAPI): # Shutdown logger.info("Shutting down Service A...") - + # Close OpenAI client if openai_client: await openai_client.close() logger.info("OpenAI client closed") - + # Uninstrument httpx HTTPXClientInstrumentor().uninstrument() - + if basalt_client: logger.info("Flushing telemetry...") basalt_client.shutdown() @@ -120,7 +120,7 @@ async def lifespan(app: FastAPI): # Auto-instrument FastAPI for incoming HTTP requests (distributed tracing!) # This now works with Basalt thanks to the smart root detection fix -FastAPIInstrumentor.instrument_app(app) +# FastAPIInstrumentor.instrument_app(app) logger.info("FastAPI instrumentation enabled - distributed tracing active") @@ -158,7 +158,7 @@ async def call_service_b() -> dict: async def summarize_with_llm(service_b_response: dict) -> dict: """ Use OpenAI to generate a summary of Service B's analysis. - + This function demonstrates auto-instrumented OpenAI calls: - GENERATION spans created automatically - Token usage captured automatically @@ -173,24 +173,21 @@ async def summarize_with_llm(service_b_response: dict) -> dict: # Get prompt from Basalt API prompt_cm = await basalt_client.prompts.get( slug="joke-analyzer", - variables={ - "analysis": str(service_b_response.get("analysis", {})), - "ticket_id": service_b_response.get("ticket_id", "unknown") - } + variables={"analysis": str(service_b_response.get("analysis", {})), "ticket_id": service_b_response.get("ticket_id", "unknown")}, ) async with prompt_cm as prompt: logger.info(f"Retrieved prompt: {prompt.slug} v{prompt.version}") - + # Auto-instrumented OpenAI call - GENERATION span created automatically response = await openai_client.chat.completions.create( model=OPENAI_MODEL, messages=[ {"role": "system", "content": "You are a helpful assistant that summarizes support ticket analyses."}, - {"role": "user", "content": f"Summarize this analysis in 2-3 sentences: {service_b_response}"} + {"role": "user", "content": f"Summarize this analysis in 2-3 sentences: {service_b_response}"}, ], temperature=0.7, - max_tokens=150 + max_tokens=150, ) return { @@ -198,7 +195,7 @@ async def summarize_with_llm(service_b_response: dict) -> dict: "model": OPENAI_MODEL, "tokens_used": response.usage.total_tokens, "prompt_slug": prompt.slug, - "prompt_version": prompt.version + "prompt_version": prompt.version, } except Exception as e: @@ -212,14 +209,19 @@ async def process_support_request(): Main endpoint for processing support requests. This demonstrates: - 1. start_observe with feature_slug="support-ticket" - 2. HTTP call to Service B (auto-instrumented via httpx) - 3. Distributed tracing (trace context propagated via HTTP headers) - 4. Proper input/output tracking + 1. Get prompt BEFORE start_observe (tests feature_slug propagation fix) + 2. start_observe with feature_slug="support-ticket" + 3. HTTP call to Service B (auto-instrumented via httpx) + 4. Distributed tracing (trace context propagated via HTTP headers) + 5. Proper input/output tracking """ - async with async_start_observe( - name="process_support_request", feature_slug="support-ticket" - ) as root_span: + # Get prompt BEFORE start_observe - this tests that feature_slug propagates + # correctly to the prompt request span even without an active trace context + pre_prompt_cm = await basalt_client.prompts.get(slug="joke-analyzer", variables={"request_type": "support_ticket_processing"}) + async with pre_prompt_cm as pre_prompt: + logger.info(f"Pre-trace prompt retrieved: {pre_prompt.slug}") + + async with async_start_observe(name="process_support_request", feature_slug="support-ticket") as root_span: # Set input for observability request_data = {"request_type": "support_ticket_processing", "source": "service-a"} root_span.set_input(request_data) diff --git a/examples/microservices/service_b/main.py b/examples/microservices/service_b/main.py index bc9720c..5a3a120 100644 --- a/examples/microservices/service_b/main.py +++ b/examples/microservices/service_b/main.py @@ -61,7 +61,7 @@ def build_basalt_client() -> Basalt: service_name="service-b-analysis", enabled_providers=["google_generativeai"], # NEW Google GenAI SDK (from google import genai) trace_content=True, # Capture prompt and completion content - # exporter=[exporter], # Use custom local exporter + # exporter=[exporter], # Use custom local exporter ) return Basalt(api_key=basalt_key, telemetry_config=telemetry) @@ -127,7 +127,7 @@ async def perform_retrieval(query: str) -> dict: async def analyze_with_prompt(ticket_data: dict, context: dict) -> dict: """ Analyze ticket using Gemini LLM with the joke-analyzer prompt. - + This function demonstrates auto-instrumented Gemini calls: - GENERATION spans created automatically - Token usage captured automatically @@ -137,19 +137,12 @@ async def analyze_with_prompt(ticket_data: dict, context: dict) -> dict: gemini_api_key = os.getenv("GEMINI_API_KEY") if not gemini_api_key: logger.warning("GEMINI_API_KEY not set. Using fallback.") - return { - "analysis": "Fallback analysis (API key not set)", - "error": "GEMINI_API_KEY not configured" - } + return {"analysis": "Fallback analysis (API key not set)", "error": "GEMINI_API_KEY not configured"} try: # Retrieve prompt from Basalt prompt_cm = await basalt_client.prompts.get( - slug="joke-analyzer", - variables={ - "ticket_id": ticket_data.get("ticket_id", "unknown"), - "context_count": context["count"] - } + slug="joke-analyzer", variables={"ticket_id": ticket_data.get("ticket_id", "unknown"), "context_count": context["count"]} ) async with prompt_cm as prompt: @@ -171,7 +164,7 @@ async def analyze_with_prompt(ticket_data: dict, context: dict) -> dict: 3. Brief analysis (2-3 sentences) Context: {context.get("results", [])} -""" +""", ) # Extract response and metadata @@ -184,7 +177,7 @@ async def analyze_with_prompt(ticket_data: dict, context: dict) -> dict: tokens_used = { "prompt_tokens": getattr(usage_metadata, "prompt_token_count", 0), "completion_tokens": getattr(usage_metadata, "candidates_token_count", 0), - "total_tokens": getattr(usage_metadata, "total_token_count", 0) + "total_tokens": getattr(usage_metadata, "total_token_count", 0), } return { @@ -210,14 +203,19 @@ async def analyze_ticket(): Main endpoint for ticket analysis. This demonstrates: - 1. start_observe with feature_slug="support-ticket" - 2. Nested observe span with kind=RETRIEVAL - 3. Prompt retrieval using the prompts API - 4. Proper input/output tracking + 1. Get prompt BEFORE start_observe (tests feature_slug propagation fix) + 2. start_observe with feature_slug="support-ticket" + 3. Nested observe span with kind=RETRIEVAL + 4. Prompt retrieval using the prompts API + 5. Proper input/output tracking """ - async with async_start_observe( - name="analyze_support_ticket", feature_slug="support-ticket" - ) as root_span: + # Get prompt BEFORE start_observe - this tests that feature_slug propagates + # correctly to the prompt request span even without an active trace context + pre_prompt_cm = await basalt_client.prompts.get(slug="joke-analyzer", variables={"request_type": "ticket_analysis"}) + async with pre_prompt_cm as pre_prompt: + logger.info(f"Pre-trace prompt retrieved: {pre_prompt.slug}") + + async with async_start_observe(name="analyze_support_ticket", feature_slug="support-ticket") as root_span: # Set input for observability ticket_data = {"ticket_id": "DEMO-001", "request_type": "analysis"} root_span.set_input(ticket_data) diff --git a/tests/observability/test_api.py b/tests/observability/test_api.py index e0716af..33477e0 100644 --- a/tests/observability/test_api.py +++ b/tests/observability/test_api.py @@ -1,8 +1,11 @@ # File: tests/test_api.py +import json + import pytest from basalt.observability import ObserveKind, SpanHandle from basalt.observability.api import Observe, StartObserve +from basalt.observability import semconv def test_get_config_for_kind_generation(): @@ -123,11 +126,20 @@ def decorated_function(x, y): def test_observe_as_context_manager(): - """Test Observe when used as a context manager.""" + """Test Observe when used as a context manager with start_observe.""" + + from .utils import get_exporter + + exporter = get_exporter() + exporter.clear() - with Observe(name="test_context_manager", kind=ObserveKind.EVENT) as span: - assert isinstance(span, SpanHandle) - span.set_attribute("test_key", "test_value") + with StartObserve(name="test_root", feature_slug="test"): + with Observe(name="test_context_manager", kind=ObserveKind.EVENT) as span: + assert isinstance(span, SpanHandle) + span.set_attribute("test_key", "test_value") + + spans = exporter.get_finished_spans() + assert len(spans) > 0 def test_invalid_observe_kind(): @@ -150,10 +162,7 @@ def function_raises_exception(): def test_observe_static_metadata(): - """Test adding static metadata using Observe.""" - import json - - from basalt.observability import semconv + """Test Observe with static metadata (requires start_observe).""" from .utils import get_exporter @@ -162,8 +171,9 @@ def test_observe_static_metadata(): metadata = {"key1": "value1", "key2": "value2"} - with Observe(name="test_static_metadata", kind=ObserveKind.SPAN, metadata=metadata) as span: - assert span is not None + with StartObserve(name="test_root", feature_slug="test"): + with Observe(name="test_static_metadata", kind=ObserveKind.SPAN, metadata=metadata) as span: + assert span is not None # Verify metadata was set as aggregated JSON at basalt.metadata spans = exporter.get_finished_spans() @@ -243,7 +253,8 @@ async def async_root_function(x: int, y: int) -> int: def test_get_root_span(): - """Test retrieving the root span using Observe.root_span().""" + """Test Observe._root_span() returns the root span handle when called within start_observe.""" + from .utils import get_exporter exporter = get_exporter() @@ -257,7 +268,9 @@ def root_function(): root_span = Observe._root_span() return True - result = root_function() + # Call within start_observe context + with StartObserve(name="test_root", feature_slug="test"): + result = root_function() assert result is True assert root_span is not None @@ -409,14 +422,16 @@ class MockPrompt: def generate_text(): return "Generated response" - result = generate_text() + with StartObserve(name="test_root", feature_slug="test"): + result = generate_text() assert result == "Generated response" # Verify span attributes contain prompt metadata spans = exporter.get_finished_spans() assert len(spans) > 0 - span = spans[-1] + span = next((s for s in spans if s.name == "test_with_prompt"), None) + assert span is not None assert span.attributes.get("basalt.prompt.slug") == "test-prompt" assert span.attributes.get("basalt.prompt.version") == "1.0.0" assert span.attributes.get("basalt.prompt.model.provider") == "openai" @@ -476,16 +491,16 @@ class MockPrompt: variables={"test": "value"}, ) - with Observe( - kind=ObserveKind.GENERATION, name="test_context_with_prompt", prompt=mock_prompt - ) as span: - pass + with StartObserve(name="test_root", feature_slug="test"): + with Observe(kind=ObserveKind.GENERATION, name="test_context_with_prompt", prompt=mock_prompt) as span: + pass # Verify span attributes contain prompt metadata spans = exporter.get_finished_spans() assert len(spans) > 0 - span = spans[-1] + span = next((s for s in spans if s.name == "test_context_with_prompt"), None) + assert span is not None assert span.attributes.get("basalt.prompt.slug") == "context-prompt" assert span.attributes.get("basalt.prompt.version") == "2.0.0" assert span.attributes.get("basalt.prompt.model.provider") == "anthropic" @@ -549,14 +564,16 @@ class MockPrompt: def generate_text(): return "Response" - result = generate_text() + with StartObserve(name="test_root", feature_slug="test"): + result = generate_text() assert result == "Response" # Verify span attributes contain prompt metadata but no variables spans = exporter.get_finished_spans() assert len(spans) > 0 - span = spans[-1] + span = next((s for s in spans if s.name == "test_no_vars"), None) + assert span is not None assert span.attributes.get("basalt.prompt.slug") == "no-vars-prompt" assert span.attributes.get("basalt.prompt.version") == "1.0.0" assert "basalt.prompt.variables" not in span.attributes diff --git a/tests/observability/test_context_managers.py b/tests/observability/test_context_managers.py index a0b88fd..96a58db 100644 --- a/tests/observability/test_context_managers.py +++ b/tests/observability/test_context_managers.py @@ -9,6 +9,7 @@ EVALUATOR_CONTEXT_KEY, ROOT_SPAN_CONTEXT_KEY, EvaluatorAttachment, + NoOpSpanHandle, SpanHandle, _normalize_evaluator_entry, get_root_span_handle, @@ -214,9 +215,7 @@ def test_set_attribute(mock_span): def test_set_input(mock_span): """Test SpanHandle.set_input sets input payload and serializes it if tracing is enabled.""" with pytest.MonkeyPatch().context() as monkeypatch: - monkeypatch.setattr( - "basalt.observability.context_managers.trace_content_enabled", lambda: True - ) + monkeypatch.setattr("basalt.observability.context_managers.trace_content_enabled", lambda: True) span_handle = SpanHandle(span=mock_span) payload = {"key": "value"} span_handle.set_input(payload) @@ -227,9 +226,7 @@ def test_set_input(mock_span): def test_set_output(mock_span): """Test SpanHandle.set_output sets output payload and serializes it if tracing is enabled.""" with pytest.MonkeyPatch().context() as monkeypatch: - monkeypatch.setattr( - "basalt.observability.context_managers.trace_content_enabled", lambda: True - ) + monkeypatch.setattr("basalt.observability.context_managers.trace_content_enabled", lambda: True) span_handle = SpanHandle(span=mock_span) payload = {"result": "success"} span_handle.set_output(payload) @@ -240,16 +237,12 @@ def test_set_output(mock_span): def test_set_io(mock_span): """Test SpanHandle.set_io sets all I/O payloads correctly.""" with pytest.MonkeyPatch().context() as monkeypatch: - monkeypatch.setattr( - "basalt.observability.context_managers.trace_content_enabled", lambda: False - ) + monkeypatch.setattr("basalt.observability.context_managers.trace_content_enabled", lambda: False) span_handle = SpanHandle(span=mock_span) input_payload = {"input": "data"} output_payload = {"output": "data"} variables = {"key": "value"} - span_handle.set_io( - input_payload=input_payload, output_payload=output_payload, variables=variables - ) + span_handle.set_io(input_payload=input_payload, output_payload=output_payload, variables=variables) assert span_handle._io_payload["input"] == input_payload assert span_handle._io_payload["output"] == output_payload assert span_handle._io_payload["variables"] == variables @@ -422,9 +415,7 @@ async def test_async_start_observe_with_metadata(): exporter.clear() metadata = {"custom_key": "custom_value", "test_id": 42} - async with AsyncStartObserve( - name="test_with_metadata", feature_slug="test_metadata", metadata=metadata - ) as span: + async with AsyncStartObserve(name="test_with_metadata", feature_slug="test_metadata", metadata=metadata) as span: # Verify span was created successfully assert span is not None assert isinstance(span._span, object) # Span exists @@ -448,9 +439,7 @@ async def test_async_observe_with_evaluators(): async with AsyncStartObserve(name="test_root", feature_slug="test_evaluators"): evaluator_slugs = ["test_evaluator_1", "test_evaluator_2"] - async with AsyncObserve( - name="test_with_evaluators", kind=ObserveKind.GENERATION, evaluators=evaluator_slugs - ) as span: + async with AsyncObserve(name="test_with_evaluators", kind=ObserveKind.GENERATION, evaluators=evaluator_slugs) as span: # The evaluators should be attached to the span # We can verify by checking the context or span attributes assert span is not None @@ -513,9 +502,7 @@ async def test_async_start_observe_with_identity(): "organization": {"id": "org-456", "name": "Test Org"}, } - async with AsyncStartObserve( - name="test_identity", feature_slug="test_identity", identity=identity - ) as span: + async with AsyncStartObserve(name="test_identity", feature_slug="test_identity", identity=identity) as span: # Verify span and identity setup completed successfully assert span is not None @@ -540,13 +527,14 @@ def test_child_observe_has_in_trace_attribute(setup_tracing): def test_standalone_observe_has_in_trace_attribute(setup_tracing): - """Verify that standalone observe spans have basalt.in_trace=true.""" - from basalt.observability import Observe, ObserveKind + """Verify that standalone observe spans (without start_observe) are no-ops.""" + from basalt.observability import Observe, ObserveKind, NoOpSpanHandle with Observe(kind=ObserveKind.GENERATION, name="standalone") as span: - # Standalone observe creates a root span - assert span._span.attributes.get("basalt.root") is True - assert span._span.attributes.get("basalt.in_trace") is True + # Standalone observe now returns a NoOpSpanHandle (no-op) + assert isinstance(span, NoOpSpanHandle) + # No spans should be created + assert span._span is None if hasattr(span, "_span") else True def test_deeply_nested_spans_have_in_trace_attribute(setup_tracing): diff --git a/tests/observability/test_feature_slug_propagation.py b/tests/observability/test_feature_slug_propagation.py index 54d5e3c..59b7a98 100644 --- a/tests/observability/test_feature_slug_propagation.py +++ b/tests/observability/test_feature_slug_propagation.py @@ -101,3 +101,89 @@ def test_feature_slug_propagates_to_basalt_fastapi_and_autoinstrumented_spans( assert span.attributes.get(BasaltSpan.FEATURE_SLUG) == "support-ticket", ( f"Span {name} missing feature_slug propagation" ) + + +def test_feature_slug_propagates_to_api_request_spans(setup_feature_slug_tracer): + """Ensure feature_slug from start_observe reaches API request spans created via observe().""" + exporter = setup_feature_slug_tracer + + with start_observe(name="root", feature_slug="test-feature"): + # Simulate an API request span using observe (as used in request_tracing.py) + with observe(name="api.prompts.get", kind=ObserveKind.SPAN): + pass + + # Simulate nested observe calls + with observe(name="nested.operation", kind=ObserveKind.FUNCTION): + with observe(name="deeply.nested", kind=ObserveKind.SPAN): + pass + + spans = exporter.get_finished_spans() + span_by_name = {span.name: span for span in spans} + + expected_names = { + "root", + "api.prompts.get", + "nested.operation", + "deeply.nested", + } + assert expected_names.issubset(span_by_name.keys()), ( + f"Missing spans: {expected_names - set(span_by_name.keys())}" + ) + + # Verify all spans have the feature_slug + for name in expected_names: + span = span_by_name[name] + assert span.attributes.get(BasaltSpan.FEATURE_SLUG) == "test-feature", ( + f"Span '{name}' missing feature_slug propagation. " + f"Expected 'test-feature', got {span.attributes.get(BasaltSpan.FEATURE_SLUG)}" + ) + + +def test_feature_slug_propagates_to_decorated_functions(setup_feature_slug_tracer): + """Ensure feature_slug from start_observe reaches decorated function spans.""" + exporter = setup_feature_slug_tracer + + @observe(name="decorated_function", kind=ObserveKind.FUNCTION) + def my_function(): + return "result" + + @observe(name="decorated_async_function", kind=ObserveKind.FUNCTION) + async def my_async_function(): + return "async_result" + + with start_observe(name="root", feature_slug="decorator-test"): + my_function() + + spans = exporter.get_finished_spans() + span_by_name = {span.name: span for span in spans} + + # Verify decorated function span has feature_slug + assert "decorated_function" in span_by_name, "decorated_function span not found" + function_span = span_by_name["decorated_function"] + assert function_span.attributes.get(BasaltSpan.FEATURE_SLUG) == "decorator-test", ( + f"Decorated function span missing feature_slug propagation. " + f"Expected 'decorator-test', got {function_span.attributes.get(BasaltSpan.FEATURE_SLUG)}" + ) + + # Clear for next test + exporter.clear() + + # Test async decorated function + import asyncio + + async def test_async(): + with start_observe(name="async_root", feature_slug="async-decorator-test"): + await my_async_function() + + asyncio.run(test_async()) + + spans = exporter.get_finished_spans() + span_by_name = {span.name: span for span in spans} + + # Verify async decorated function span has feature_slug + assert "decorated_async_function" in span_by_name, "decorated_async_function span not found" + async_function_span = span_by_name["decorated_async_function"] + assert async_function_span.attributes.get(BasaltSpan.FEATURE_SLUG) == "async-decorator-test", ( + f"Async decorated function span missing feature_slug propagation. " + f"Expected 'async-decorator-test', got {async_function_span.attributes.get(BasaltSpan.FEATURE_SLUG)}" + )