Skip to content
Merged
37 changes: 36 additions & 1 deletion agent_assembly/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal
from typing import Literal, cast

# ── Event types (AAASM-1435) ──────────────────────────────────────────────────

Expand Down Expand Up @@ -102,6 +102,41 @@ class AuditEvent:
labels: dict[str, str] = field(default_factory=dict)
call_stack: list[CallStackNode] = field(default_factory=list)

def to_wire_bytes(self) -> bytes:
"""Encode this event to `aa_proto::AuditEvent` wire bytes.

Requires the native `agent_assembly._core` extension. Raises
`ImportError` when the SDK is installed without the native
wheel (pure-Python mode).
"""
try:
from agent_assembly._core import audit_event_to_wire_bytes # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"AuditEvent.to_wire_bytes() requires the native "
"agent_assembly._core extension; reinstall with the "
"native wheel or run `maturin develop` in rust/aa-ffi-python/."
) from exc
return cast(bytes, audit_event_to_wire_bytes(self))

@classmethod
def from_wire_bytes(cls, data: bytes) -> AuditEvent:
"""Decode `aa_proto::AuditEvent` wire bytes into a dataclass.

Requires the native `agent_assembly._core` extension. Raises
`ImportError` when the SDK is installed without the native
wheel.
"""
try:
from agent_assembly._core import audit_event_from_wire_bytes
except ImportError as exc:
raise ImportError(
"AuditEvent.from_wire_bytes() requires the native "
"agent_assembly._core extension; reinstall with the "
"native wheel or run `maturin develop` in rust/aa-ffi-python/."
) from exc
return cast(AuditEvent, audit_event_from_wire_bytes(data))


__all__ = [
"AuditEvent",
Expand Down
169 changes: 169 additions & 0 deletions rust/aa-ffi-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

use aa_core::AuditEntry;
use aa_proto::assembly::audit::v1::AuditEvent;
use aa_proto::assembly::audit::v1::CallStackNode as ProtoCallStackNode;
use aa_proto::assembly::common::v1::ActionType;
use aa_proto::assembly::common::v1::AgentId;
use aa_proto::assembly::common::v1::Decision;
use aa_proto::assembly::policy::v1::CheckActionRequest;
use aa_proto::assembly::policy::v1::CheckActionResponse;
Expand Down Expand Up @@ -403,6 +406,156 @@ fn make_event_id() -> String {
format!("py-{}-{}", now.as_secs(), now.subsec_nanos())
}

fn action_type_from_str(value: &str) -> i32 {
match value {
"llm_call" => ActionType::LlmCall as i32,
"tool_call" => ActionType::ToolCall as i32,
"file_op" | "file_operation" => ActionType::FileOperation as i32,
"network_call" => ActionType::NetworkCall as i32,
"process_exec" => ActionType::ProcessExec as i32,
"agent_spawn" => ActionType::AgentSpawn as i32,
_ => ActionType::ActionUnspecified as i32,
}
}

fn action_type_to_str(value: i32) -> &'static str {
match ActionType::try_from(value).unwrap_or(ActionType::ActionUnspecified) {
ActionType::LlmCall => "llm_call",
ActionType::ToolCall => "tool_call",
ActionType::FileOperation => "file_op",
ActionType::NetworkCall => "network_call",
ActionType::ProcessExec => "process_exec",
ActionType::AgentSpawn => "agent_spawn",
ActionType::ActionUnspecified => "",
}
}

fn decision_from_str(value: &str) -> i32 {
match value {
"allow" => Decision::Allow as i32,
"deny" => Decision::Deny as i32,
"pending" => Decision::Pending as i32,
"redact" => Decision::Redact as i32,
_ => Decision::Unspecified as i32,
}
}

fn audit_event_from_py(event: &PyAny) -> PyResult<AuditEvent> {
let event_id = event.getattr("event_id")?.extract::<String>()?;
let agent_id_str = event.getattr("agent_id")?.extract::<String>()?;
let action_type_str = event.getattr("action_type")?.extract::<String>()?;
let decision_str = event.getattr("decision")?.extract::<String>()?;
let trace_id = event.getattr("trace_id")?.extract::<String>()?;
let span_id = event.getattr("span_id")?.extract::<String>()?;
let parent_span_id = event.getattr("parent_span_id")?.extract::<String>()?;
let labels = event
.getattr("labels")?
.extract::<std::collections::HashMap<String, String>>()?;
let call_stack_py = event.getattr("call_stack")?;
let mut call_stack = Vec::new();
for node in call_stack_py.iter()? {
call_stack.push(call_stack_node_from_py(node?)?);
}
Ok(AuditEvent {
event_id,
agent_id: Some(AgentId {
org_id: String::new(),
team_id: String::new(),
agent_id: agent_id_str,
}),
action_type: action_type_from_str(&action_type_str),
decision: decision_from_str(&decision_str),
trace_id,
span_id,
parent_span_id,
labels,
call_stack,
..Default::default()
})
}

fn audit_event_to_py(py: Python<'_>, event: &AuditEvent) -> PyResult<PyObject> {
let types_module = PyModule::import(py, "agent_assembly.types")?;
let cls = types_module.getattr("AuditEvent")?;
let kwargs = PyDict::new(py);
kwargs.set_item("event_id", &event.event_id)?;
let agent_id_str = event
.agent_id
.as_ref()
.map(|id| id.agent_id.clone())
.unwrap_or_default();
kwargs.set_item("agent_id", agent_id_str)?;
kwargs.set_item("action_type", action_type_to_str(event.action_type))?;
kwargs.set_item("decision", decision_to_str(event.decision))?;
kwargs.set_item("trace_id", &event.trace_id)?;
kwargs.set_item("span_id", &event.span_id)?;
kwargs.set_item("parent_span_id", &event.parent_span_id)?;
let labels = PyDict::new(py);
for (k, v) in &event.labels {
labels.set_item(k, v)?;
}
kwargs.set_item("labels", labels)?;
let call_stack = pyo3::types::PyList::empty(py);
for node in &event.call_stack {
call_stack.append(call_stack_node_to_py(py, node)?)?;
}
kwargs.set_item("call_stack", call_stack)?;
Ok(cls.call((), Some(kwargs))?.into())
}

fn call_stack_node_from_py(node: &PyAny) -> PyResult<ProtoCallStackNode> {
let id = node.getattr("id")?.extract::<String>()?;
let kind = node.getattr("kind")?.extract::<String>()?;
let label = node.getattr("label")?.extract::<String>()?;
let latency_ms = node
.getattr("latency_ms")?
.extract::<Option<i64>>()?
.unwrap_or(0);
let children_py = node.getattr("children")?;
let mut children = Vec::new();
for child in children_py.iter()? {
children.push(call_stack_node_from_py(child?)?);
}
Ok(ProtoCallStackNode {
id,
kind,
label,
latency_ms,
children,
})
}

fn call_stack_node_to_py(py: Python<'_>, node: &ProtoCallStackNode) -> PyResult<PyObject> {
let types_module = PyModule::import(py, "agent_assembly.types")?;
let cls = types_module.getattr("CallStackNode")?;
let kwargs = PyDict::new(py);
kwargs.set_item("id", &node.id)?;
kwargs.set_item("kind", &node.kind)?;
kwargs.set_item("label", &node.label)?;
let latency: Option<i64> = if node.latency_ms == 0 {
None
} else {
Some(node.latency_ms)
};
kwargs.set_item("latency_ms", latency)?;
let children = pyo3::types::PyList::empty(py);
for child in &node.children {
children.append(call_stack_node_to_py(py, child)?)?;
}
kwargs.set_item("children", children)?;
Ok(cls.call((), Some(kwargs))?.into())
}

fn decision_to_str(value: i32) -> &'static str {
match Decision::try_from(value).unwrap_or(Decision::Unspecified) {
Decision::Allow => "allow",
Decision::Deny => "deny",
Decision::Pending => "pending",
Decision::Redact => "redact",
Decision::Unspecified => "",
}
}

fn bytes_to_hex(bytes: &[u8; 16]) -> String {
const HEX: &[u8; 16] = b"0123456789abcdef";
let mut result = String::with_capacity(bytes.len() * 2);
Expand Down Expand Up @@ -540,11 +693,27 @@ fn wait_for_worker_response(
.map_err(|_| WorkerWaitError::Disconnected)
}

#[pyfunction]
fn audit_event_to_wire_bytes(py: Python<'_>, event: &PyAny) -> PyResult<PyObject> {
let proto = audit_event_from_py(event)?;
let encoded = proto.encode_to_vec();
Ok(pyo3::types::PyBytes::new(py, &encoded).into())
}

#[pyfunction]
fn audit_event_from_wire_bytes(py: Python<'_>, data: &pyo3::types::PyBytes) -> PyResult<PyObject> {
let proto = AuditEvent::decode(data.as_bytes())
.map_err(|error| PyValueError::new_err(format!("failed to decode AuditEvent wire bytes: {error}")))?;
audit_event_to_py(py, &proto)
}

#[pymodule]
fn _core(py: Python<'_>, module: &PyModule) -> PyResult<()> {
module.add("PolicyTimeoutError", py.get_type::<PolicyTimeoutError>())?;
module.add_class::<GovernanceEvent>()?;
module.add_class::<PolicyResult>()?;
module.add_class::<RuntimeClient>()?;
module.add_function(wrap_pyfunction!(audit_event_to_wire_bytes, module)?)?;
module.add_function(wrap_pyfunction!(audit_event_from_wire_bytes, module)?)?;
Ok(())
}
111 changes: 111 additions & 0 deletions test/unit/test_audit_event_wire_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Wire-protocol round-trip tests for `AuditEvent` / `CallStackNode`.

Covers the AAASM-1442 bridge between the pure-Python dataclasses in
`agent_assembly.types` and the Rust `aa_proto` encoder via PyO3.

All tests skip cleanly when the native `agent_assembly._core` module is
not built (pure-Python install), matching the existing
`test/bench/test_report_llm_call_roundtrip.py` convention.
"""

from __future__ import annotations

import pytest

from agent_assembly import AuditEvent, CallStackNode

pytest.importorskip(
"agent_assembly._core",
reason="native _core module not built (requires maturin develop)",
)


def test_three_level_call_stack_round_trips_without_data_loss() -> None:
original = AuditEvent(
event_id="evt-1",
agent_id="support-agent",
action_type="llm_call",
decision="allow",
trace_id="trace-1",
span_id="span-1",
parent_span_id="span-0",
labels={"team": "platform", "env": "prod"},
call_stack=[
CallStackNode(
id="n0",
kind="llm",
label="gpt-4o",
latency_ms=300,
children=[
CallStackNode(
id="n1",
kind="tool",
label="gmail.send",
latency_ms=120,
children=[
CallStackNode(
id="n2",
kind="result",
label="200 OK",
latency_ms=5,
),
],
),
],
),
],
)

decoded = AuditEvent.from_wire_bytes(original.to_wire_bytes())

assert decoded == original


def test_legacy_payload_without_call_stack_decodes_to_empty_list() -> None:
"""Events emitted before AAASM-1419 added `call_stack` must still decode.

Proto3 elides default-valued repeated fields on the wire, so an
event with `call_stack=[]` produces bytes indistinguishable from
a pre-1419 event that did not set the field at all. The decoded
dataclass must surface this as the empty list (not None or a
missing attribute).
"""
original = AuditEvent(
event_id="evt-legacy",
agent_id="legacy-agent",
action_type="tool_call",
decision="deny",
)

decoded = AuditEvent.from_wire_bytes(original.to_wire_bytes())

assert decoded.call_stack == []
assert decoded == original


def test_call_stack_node_kind_outside_literal_round_trips_unchanged() -> None:
"""`kind` is proto `string`, not enum β€” the bridge accepts any value.

The Python `CallStackNodeKind` `Literal` narrows the type for
Python authors but does not restrict what arrives from a future
producer that emits a new node category. The wire layer must
preserve such values verbatim instead of normalising them.
"""
original = AuditEvent(
event_id="evt-invalid-kind",
agent_id="future-agent",
action_type="llm_call",
decision="allow",
call_stack=[
CallStackNode(
id="n0",
kind="unknown", # type: ignore[arg-type]
label="future-node-type",
),
],
)

decoded = AuditEvent.from_wire_bytes(original.to_wire_bytes())

assert decoded.call_stack[0].kind == "unknown"
assert decoded == original