diff --git a/agent_assembly/types.py b/agent_assembly/types.py index 01d60b7..de601bc 100644 --- a/agent_assembly/types.py +++ b/agent_assembly/types.py @@ -17,7 +17,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, cast # ── Event types (AAASM-1435) ────────────────────────────────────────────────── @@ -102,6 +102,41 @@ class AuditEvent: labels: dict[str, str] = field(default_factory=dict) call_stack: list[CallStackNode] = field(default_factory=list) + def to_wire_bytes(self) -> bytes: + """Encode this event to `aa_proto::AuditEvent` wire bytes. + + Requires the native `agent_assembly._core` extension. Raises + `ImportError` when the SDK is installed without the native + wheel (pure-Python mode). + """ + try: + from agent_assembly._core import audit_event_to_wire_bytes # type: ignore[import-not-found] + except ImportError as exc: + raise ImportError( + "AuditEvent.to_wire_bytes() requires the native " + "agent_assembly._core extension; reinstall with the " + "native wheel or run `maturin develop` in rust/aa-ffi-python/." + ) from exc + return cast(bytes, audit_event_to_wire_bytes(self)) + + @classmethod + def from_wire_bytes(cls, data: bytes) -> AuditEvent: + """Decode `aa_proto::AuditEvent` wire bytes into a dataclass. + + Requires the native `agent_assembly._core` extension. Raises + `ImportError` when the SDK is installed without the native + wheel. + """ + try: + from agent_assembly._core import audit_event_from_wire_bytes + except ImportError as exc: + raise ImportError( + "AuditEvent.from_wire_bytes() requires the native " + "agent_assembly._core extension; reinstall with the " + "native wheel or run `maturin develop` in rust/aa-ffi-python/." + ) from exc + return cast(AuditEvent, audit_event_from_wire_bytes(data)) + __all__ = [ "AuditEvent", diff --git a/rust/aa-ffi-python/src/lib.rs b/rust/aa-ffi-python/src/lib.rs index 5ae2616..4cbf47f 100644 --- a/rust/aa-ffi-python/src/lib.rs +++ b/rust/aa-ffi-python/src/lib.rs @@ -2,6 +2,9 @@ use aa_core::AuditEntry; use aa_proto::assembly::audit::v1::AuditEvent; +use aa_proto::assembly::audit::v1::CallStackNode as ProtoCallStackNode; +use aa_proto::assembly::common::v1::ActionType; +use aa_proto::assembly::common::v1::AgentId; use aa_proto::assembly::common::v1::Decision; use aa_proto::assembly::policy::v1::CheckActionRequest; use aa_proto::assembly::policy::v1::CheckActionResponse; @@ -403,6 +406,156 @@ fn make_event_id() -> String { format!("py-{}-{}", now.as_secs(), now.subsec_nanos()) } +fn action_type_from_str(value: &str) -> i32 { + match value { + "llm_call" => ActionType::LlmCall as i32, + "tool_call" => ActionType::ToolCall as i32, + "file_op" | "file_operation" => ActionType::FileOperation as i32, + "network_call" => ActionType::NetworkCall as i32, + "process_exec" => ActionType::ProcessExec as i32, + "agent_spawn" => ActionType::AgentSpawn as i32, + _ => ActionType::ActionUnspecified as i32, + } +} + +fn action_type_to_str(value: i32) -> &'static str { + match ActionType::try_from(value).unwrap_or(ActionType::ActionUnspecified) { + ActionType::LlmCall => "llm_call", + ActionType::ToolCall => "tool_call", + ActionType::FileOperation => "file_op", + ActionType::NetworkCall => "network_call", + ActionType::ProcessExec => "process_exec", + ActionType::AgentSpawn => "agent_spawn", + ActionType::ActionUnspecified => "", + } +} + +fn decision_from_str(value: &str) -> i32 { + match value { + "allow" => Decision::Allow as i32, + "deny" => Decision::Deny as i32, + "pending" => Decision::Pending as i32, + "redact" => Decision::Redact as i32, + _ => Decision::Unspecified as i32, + } +} + +fn audit_event_from_py(event: &PyAny) -> PyResult { + let event_id = event.getattr("event_id")?.extract::()?; + let agent_id_str = event.getattr("agent_id")?.extract::()?; + let action_type_str = event.getattr("action_type")?.extract::()?; + let decision_str = event.getattr("decision")?.extract::()?; + let trace_id = event.getattr("trace_id")?.extract::()?; + let span_id = event.getattr("span_id")?.extract::()?; + let parent_span_id = event.getattr("parent_span_id")?.extract::()?; + let labels = event + .getattr("labels")? + .extract::>()?; + let call_stack_py = event.getattr("call_stack")?; + let mut call_stack = Vec::new(); + for node in call_stack_py.iter()? { + call_stack.push(call_stack_node_from_py(node?)?); + } + Ok(AuditEvent { + event_id, + agent_id: Some(AgentId { + org_id: String::new(), + team_id: String::new(), + agent_id: agent_id_str, + }), + action_type: action_type_from_str(&action_type_str), + decision: decision_from_str(&decision_str), + trace_id, + span_id, + parent_span_id, + labels, + call_stack, + ..Default::default() + }) +} + +fn audit_event_to_py(py: Python<'_>, event: &AuditEvent) -> PyResult { + let types_module = PyModule::import(py, "agent_assembly.types")?; + let cls = types_module.getattr("AuditEvent")?; + let kwargs = PyDict::new(py); + kwargs.set_item("event_id", &event.event_id)?; + let agent_id_str = event + .agent_id + .as_ref() + .map(|id| id.agent_id.clone()) + .unwrap_or_default(); + kwargs.set_item("agent_id", agent_id_str)?; + kwargs.set_item("action_type", action_type_to_str(event.action_type))?; + kwargs.set_item("decision", decision_to_str(event.decision))?; + kwargs.set_item("trace_id", &event.trace_id)?; + kwargs.set_item("span_id", &event.span_id)?; + kwargs.set_item("parent_span_id", &event.parent_span_id)?; + let labels = PyDict::new(py); + for (k, v) in &event.labels { + labels.set_item(k, v)?; + } + kwargs.set_item("labels", labels)?; + let call_stack = pyo3::types::PyList::empty(py); + for node in &event.call_stack { + call_stack.append(call_stack_node_to_py(py, node)?)?; + } + kwargs.set_item("call_stack", call_stack)?; + Ok(cls.call((), Some(kwargs))?.into()) +} + +fn call_stack_node_from_py(node: &PyAny) -> PyResult { + let id = node.getattr("id")?.extract::()?; + let kind = node.getattr("kind")?.extract::()?; + let label = node.getattr("label")?.extract::()?; + let latency_ms = node + .getattr("latency_ms")? + .extract::>()? + .unwrap_or(0); + let children_py = node.getattr("children")?; + let mut children = Vec::new(); + for child in children_py.iter()? { + children.push(call_stack_node_from_py(child?)?); + } + Ok(ProtoCallStackNode { + id, + kind, + label, + latency_ms, + children, + }) +} + +fn call_stack_node_to_py(py: Python<'_>, node: &ProtoCallStackNode) -> PyResult { + let types_module = PyModule::import(py, "agent_assembly.types")?; + let cls = types_module.getattr("CallStackNode")?; + let kwargs = PyDict::new(py); + kwargs.set_item("id", &node.id)?; + kwargs.set_item("kind", &node.kind)?; + kwargs.set_item("label", &node.label)?; + let latency: Option = if node.latency_ms == 0 { + None + } else { + Some(node.latency_ms) + }; + kwargs.set_item("latency_ms", latency)?; + let children = pyo3::types::PyList::empty(py); + for child in &node.children { + children.append(call_stack_node_to_py(py, child)?)?; + } + kwargs.set_item("children", children)?; + Ok(cls.call((), Some(kwargs))?.into()) +} + +fn decision_to_str(value: i32) -> &'static str { + match Decision::try_from(value).unwrap_or(Decision::Unspecified) { + Decision::Allow => "allow", + Decision::Deny => "deny", + Decision::Pending => "pending", + Decision::Redact => "redact", + Decision::Unspecified => "", + } +} + fn bytes_to_hex(bytes: &[u8; 16]) -> String { const HEX: &[u8; 16] = b"0123456789abcdef"; let mut result = String::with_capacity(bytes.len() * 2); @@ -540,11 +693,27 @@ fn wait_for_worker_response( .map_err(|_| WorkerWaitError::Disconnected) } +#[pyfunction] +fn audit_event_to_wire_bytes(py: Python<'_>, event: &PyAny) -> PyResult { + let proto = audit_event_from_py(event)?; + let encoded = proto.encode_to_vec(); + Ok(pyo3::types::PyBytes::new(py, &encoded).into()) +} + +#[pyfunction] +fn audit_event_from_wire_bytes(py: Python<'_>, data: &pyo3::types::PyBytes) -> PyResult { + let proto = AuditEvent::decode(data.as_bytes()) + .map_err(|error| PyValueError::new_err(format!("failed to decode AuditEvent wire bytes: {error}")))?; + audit_event_to_py(py, &proto) +} + #[pymodule] fn _core(py: Python<'_>, module: &PyModule) -> PyResult<()> { module.add("PolicyTimeoutError", py.get_type::())?; module.add_class::()?; module.add_class::()?; module.add_class::()?; + module.add_function(wrap_pyfunction!(audit_event_to_wire_bytes, module)?)?; + module.add_function(wrap_pyfunction!(audit_event_from_wire_bytes, module)?)?; Ok(()) } diff --git a/test/unit/test_audit_event_wire_roundtrip.py b/test/unit/test_audit_event_wire_roundtrip.py new file mode 100644 index 0000000..5021ca5 --- /dev/null +++ b/test/unit/test_audit_event_wire_roundtrip.py @@ -0,0 +1,111 @@ +"""Wire-protocol round-trip tests for `AuditEvent` / `CallStackNode`. + +Covers the AAASM-1442 bridge between the pure-Python dataclasses in +`agent_assembly.types` and the Rust `aa_proto` encoder via PyO3. + +All tests skip cleanly when the native `agent_assembly._core` module is +not built (pure-Python install), matching the existing +`test/bench/test_report_llm_call_roundtrip.py` convention. +""" + +from __future__ import annotations + +import pytest + +from agent_assembly import AuditEvent, CallStackNode + +pytest.importorskip( + "agent_assembly._core", + reason="native _core module not built (requires maturin develop)", +) + + +def test_three_level_call_stack_round_trips_without_data_loss() -> None: + original = AuditEvent( + event_id="evt-1", + agent_id="support-agent", + action_type="llm_call", + decision="allow", + trace_id="trace-1", + span_id="span-1", + parent_span_id="span-0", + labels={"team": "platform", "env": "prod"}, + call_stack=[ + CallStackNode( + id="n0", + kind="llm", + label="gpt-4o", + latency_ms=300, + children=[ + CallStackNode( + id="n1", + kind="tool", + label="gmail.send", + latency_ms=120, + children=[ + CallStackNode( + id="n2", + kind="result", + label="200 OK", + latency_ms=5, + ), + ], + ), + ], + ), + ], + ) + + decoded = AuditEvent.from_wire_bytes(original.to_wire_bytes()) + + assert decoded == original + + +def test_legacy_payload_without_call_stack_decodes_to_empty_list() -> None: + """Events emitted before AAASM-1419 added `call_stack` must still decode. + + Proto3 elides default-valued repeated fields on the wire, so an + event with `call_stack=[]` produces bytes indistinguishable from + a pre-1419 event that did not set the field at all. The decoded + dataclass must surface this as the empty list (not None or a + missing attribute). + """ + original = AuditEvent( + event_id="evt-legacy", + agent_id="legacy-agent", + action_type="tool_call", + decision="deny", + ) + + decoded = AuditEvent.from_wire_bytes(original.to_wire_bytes()) + + assert decoded.call_stack == [] + assert decoded == original + + +def test_call_stack_node_kind_outside_literal_round_trips_unchanged() -> None: + """`kind` is proto `string`, not enum — the bridge accepts any value. + + The Python `CallStackNodeKind` `Literal` narrows the type for + Python authors but does not restrict what arrives from a future + producer that emits a new node category. The wire layer must + preserve such values verbatim instead of normalising them. + """ + original = AuditEvent( + event_id="evt-invalid-kind", + agent_id="future-agent", + action_type="llm_call", + decision="allow", + call_stack=[ + CallStackNode( + id="n0", + kind="unknown", # type: ignore[arg-type] + label="future-node-type", + ), + ], + ) + + decoded = AuditEvent.from_wire_bytes(original.to_wire_bytes()) + + assert decoded.call_stack[0].kind == "unknown" + assert decoded == original