Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions sdks/python/src/agent_control/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ async def handle_input(user_message: str) -> str:
from ._control_registry import (
clear as clear_step_registry,
)

# Import client and operations modules
from .client import AgentControlClient

# Import control decorator
from .control_decorators import ControlSteerError, ControlViolationError, control
from .evaluation import check_evaluation_with_local, evaluate_controls
from .observability import (
Expand All @@ -98,8 +94,11 @@ async def handle_input(user_message: str) -> str:
shutdown_observability,
sync_shutdown_observability,
)

# Import tracing and observability
from .telemetry import (
clear_trace_context_provider,
get_trace_context_from_provider,
set_trace_context_provider,
)
from .tracing import (
get_current_span_id,
get_current_trace_id,
Expand Down Expand Up @@ -1305,6 +1304,9 @@ async def main():
"get_current_span_id",
"with_trace",
"is_otel_available",
"set_trace_context_provider",
"get_trace_context_from_provider",
"clear_trace_context_provider",
# Observability
"init_observability",
"add_event",
Expand Down
20 changes: 14 additions & 6 deletions sdks/python/src/agent_control/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ._state import state
from .client import AgentControlClient
from .observability import add_event, get_logger, is_observability_enabled
from .tracing import get_trace_and_span_ids
from .validation import ensure_agent_name

_logger = get_logger(__name__)
Expand Down Expand Up @@ -291,6 +292,13 @@ async def check_evaluation_with_local(
httpx.HTTPError: If server request fails
"""
normalized_name = ensure_agent_name(agent_name)
resolved_trace_id = trace_id
resolved_span_id = span_id
if trace_id is None or span_id is None:
current_trace_id, current_span_id = get_trace_and_span_ids()
resolved_trace_id = trace_id or current_trace_id
resolved_span_id = span_id or current_span_id

# Partition controls by local flag
local_controls: list[_ControlAdapter] = []
parse_errors: list[ControlMatch] = []
Expand Down Expand Up @@ -389,8 +397,8 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult:
local_result,
request,
applicable_local_controls,
trace_id,
span_id,
resolved_trace_id,
resolved_span_id,
agent_name=event_agent_name,
)

Expand All @@ -409,10 +417,10 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult:
if _has_applicable_prefiltered_server_controls(server_control_payloads, request):
request_payload = request.model_dump(mode="json", exclude_none=True)
headers: dict[str, str] = {}
if trace_id:
headers["X-Trace-Id"] = trace_id
if span_id:
headers["X-Span-Id"] = span_id
if resolved_trace_id:
headers["X-Trace-Id"] = resolved_trace_id
if resolved_span_id:
headers["X-Span-Id"] = resolved_span_id

response = await client.http_client.post(
"/api/v1/evaluation",
Expand Down
17 changes: 17 additions & 0 deletions sdks/python/src/agent_control/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Telemetry interfaces for provider-agnostic tracing."""

from .trace_context import (
TraceContext,
TraceContextProvider,
clear_trace_context_provider,
get_trace_context_from_provider,
set_trace_context_provider,
)

__all__ = [
"TraceContext",
"TraceContextProvider",
"clear_trace_context_provider",
"get_trace_context_from_provider",
"set_trace_context_provider",
]
55 changes: 55 additions & 0 deletions sdks/python/src/agent_control/telemetry/trace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Provider-agnostic trace context interface for external tracing systems."""

from collections.abc import Callable
from typing import TypedDict


class TraceContext(TypedDict):
"""Resolved trace context for a control evaluation."""

trace_id: str
span_id: str


TraceContextProvider = Callable[[], TraceContext | None]

_trace_context_provider: TraceContextProvider | None = None


def set_trace_context_provider(provider: TraceContextProvider | None) -> None:
"""Register a provider that returns the current trace context."""
global _trace_context_provider
_trace_context_provider = provider


def get_trace_context_from_provider() -> TraceContext | None:
"""Return trace context from the registered provider, if any."""
if _trace_context_provider is None:
return None

try:
trace_context = _trace_context_provider()
except Exception:
# Provider failures should not break control evaluation.
return None

if trace_context is None:
return None

trace_id = trace_context.get("trace_id")
span_id = trace_context.get("span_id")
if not isinstance(trace_id, str) or not isinstance(span_id, str):
return None
if not trace_id or not span_id:
return None

return {
"trace_id": trace_id,
"span_id": span_id,
}


def clear_trace_context_provider() -> None:
"""Clear the registered trace context provider."""
global _trace_context_provider
_trace_context_provider = None
22 changes: 20 additions & 2 deletions sdks/python/src/agent_control/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from contextlib import contextmanager
from contextvars import ContextVar, Token

from .telemetry.trace_context import get_trace_context_from_provider

# Context variables for trace/span propagation
_trace_id_var: ContextVar[str | None] = ContextVar("trace_id", default=None)
_span_id_var: ContextVar[str | None] = ContextVar("span_id", default=None)
Expand Down Expand Up @@ -94,8 +96,9 @@ def get_trace_and_span_ids() -> tuple[str, str]:

Priority:
1. Context variable (set by with_trace or explicitly)
2. OpenTelemetry context (if OTEL is installed and active)
3. Generate new OTEL-compatible IDs
2. External provider
3. OpenTelemetry context (if OTEL is installed and active)
4. Generate new OTEL-compatible IDs

Returns:
Tuple of (trace_id, span_id) - both are hex strings
Expand All @@ -114,6 +117,11 @@ def get_trace_and_span_ids() -> tuple[str, str]:
if trace_id is not None and span_id is not None:
return trace_id, span_id

# Try external provider
trace_context = get_trace_context_from_provider()
if trace_context:
return trace_context["trace_id"], trace_context["span_id"]

# Try OpenTelemetry context
otel_trace_id, otel_span_id = _get_otel_ids()

Expand All @@ -136,6 +144,11 @@ def get_current_trace_id() -> str | None:
if trace_id is not None:
return trace_id

# Try external provider
trace_context = get_trace_context_from_provider()
if trace_context:
return trace_context["trace_id"]

# Try OpenTelemetry
otel_trace_id, _ = _get_otel_ids()
return otel_trace_id
Expand All @@ -153,6 +166,11 @@ def get_current_span_id() -> str | None:
if span_id is not None:
return span_id

# Try external provider
trace_context = get_trace_context_from_provider()
if trace_context:
return trace_context["span_id"]

# Try OpenTelemetry
_, otel_span_id = _get_otel_ids()
return otel_span_id
Expand Down
72 changes: 69 additions & 3 deletions sdks/python/tests/test_observability_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
_map_applies_to,
_merge_results,
)
from agent_control.telemetry.trace_context import (
clear_trace_context_provider,
set_trace_context_provider,
)
from agent_control_models import ControlDefinition

# =============================================================================
Expand Down Expand Up @@ -326,6 +330,9 @@ def test_fallback_warning_logged_only_once(self):
class TestCheckEvaluationWithLocal:
"""Tests for check_evaluation_with_local event emission and non_matches."""

def teardown_method(self) -> None:
clear_trace_context_provider()

@pytest.mark.asyncio
async def test_emits_events_when_trace_context_provided(self):
"""Should emit observability events when trace_id and span_id are passed."""
Expand Down Expand Up @@ -398,7 +405,7 @@ async def test_emits_events_when_trace_context_provided(self):

@pytest.mark.asyncio
async def test_emits_events_without_trace_context(self):
"""Should still emit events when trace_id/span_id not provided (fallback IDs)."""
"""Should resolve trace context from the provider when IDs are omitted."""
from agent_control_models import EvaluationResponse, Step

mock_response = EvaluationResponse(
Expand All @@ -424,6 +431,12 @@ async def test_emits_events_without_trace_context(self):
client = MagicMock()
client.http_client = AsyncMock()
step = Step(type="llm", name="test-step", input="hello")
set_trace_context_provider(
lambda: {
"trace_id": "a" * 32,
"span_id": "b" * 16,
}
)

with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \
patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \
Expand All @@ -438,8 +451,8 @@ async def test_emits_events_without_trace_context(self):
)
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[0][3] is None # trace_id passed as None
assert call_args[0][4] is None # span_id passed as None
assert call_args[0][3] == "a" * 32
assert call_args[0][4] == "b" * 16

@pytest.mark.asyncio
async def test_forwards_trace_headers_to_server(self):
Expand Down Expand Up @@ -492,6 +505,59 @@ async def test_forwards_trace_headers_to_server(self):
assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444"
assert headers["X-Span-Id"] == "eeee5555ffff6666"

@pytest.mark.asyncio
async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self):
"""Server POST should resolve trace headers from the provider when omitted."""
from agent_control_models import Step

controls = [{
"id": 1,
"name": "server-ctrl",
"control": {
"condition": {
"evaluator": {"name": "regex", "config": {"pattern": "test"}},
"selector": {"path": "input"},
},
"action": {"decision": "deny"},
"execution": "server",
},
}]

mock_http_response = MagicMock()
mock_http_response.json.return_value = {
"is_safe": True,
"confidence": 1.0,
"matches": None,
"errors": None,
"non_matches": None,
}
mock_http_response.raise_for_status = MagicMock()

client = MagicMock()
client.http_client = AsyncMock()
client.http_client.post = AsyncMock(return_value=mock_http_response)
step = Step(type="llm", name="test-step", input="hello")
set_trace_context_provider(
lambda: {
"trace_id": "c" * 32,
"span_id": "d" * 16,
}
)

with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]):
await evaluation.check_evaluation_with_local(
client=client,
agent_name="agent-000000000001",
step=step,
stage="pre",
controls=controls,
)

call_kwargs = client.http_client.post.call_args
headers = call_kwargs.kwargs.get("headers", {})
assert headers["X-Trace-Id"] == "c" * 32
assert headers["X-Span-Id"] == "d" * 16


# =============================================================================
# control_decorators non_matches dict conversion
Expand Down
65 changes: 65 additions & 0 deletions sdks/python/tests/test_trace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Tests for the telemetry trace context provider interface."""

from agent_control.telemetry.trace_context import (
clear_trace_context_provider,
get_trace_context_from_provider,
set_trace_context_provider,
)


def teardown_function() -> None:
clear_trace_context_provider()


def test_get_trace_context_from_provider_returns_registered_context() -> None:
set_trace_context_provider(
lambda: {
"trace_id": "a" * 32,
"span_id": "b" * 16,
}
)

assert get_trace_context_from_provider() == {
"trace_id": "a" * 32,
"span_id": "b" * 16,
}


def test_get_trace_context_from_provider_returns_none_when_unset() -> None:
assert get_trace_context_from_provider() is None


def test_get_trace_context_from_provider_returns_none_when_provider_returns_none() -> None:
set_trace_context_provider(lambda: None)

assert get_trace_context_from_provider() is None


def test_get_trace_context_from_provider_swallows_provider_failures() -> None:
def _raising_provider():
raise RuntimeError("boom")

set_trace_context_provider(_raising_provider)

assert get_trace_context_from_provider() is None


def test_get_trace_context_from_provider_returns_none_for_invalid_shape() -> None:
set_trace_context_provider( # type: ignore[arg-type]
lambda: {
"trace_id": "a" * 32,
}
)

assert get_trace_context_from_provider() is None


def test_get_trace_context_from_provider_returns_none_for_empty_ids() -> None:
set_trace_context_provider(
lambda: {
"trace_id": "",
"span_id": "",
}
)

assert get_trace_context_from_provider() is None
Loading
Loading