diff --git a/.github/workflows/native-core-build.yml b/.github/workflows/native-core-build.yml
new file mode 100644
index 0000000..71dc467
--- /dev/null
+++ b/.github/workflows/native-core-build.yml
@@ -0,0 +1,36 @@
+name: Native Core Build Check
+
+on:
+  pull_request:
+    paths:
+      - "rust/**"
+      - "agent_assembly/**"
+      - ".github/workflows/native-core-build.yml"
+  workflow_dispatch:
+
+jobs:
+  build-native-core:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v5
+
+      - name: Setup Python
+        uses: actions/setup-python@v6
+        with:
+          python-version: "3.13"
+
+      - name: Setup uv
+        uses: astral-sh/setup-uv@v6
+
+      - name: Setup Rust
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: Build native module with maturin  # NOTE(review): maturin develop installs into an active or .venv virtualenv and no step above creates one - confirm
+        env:
+          PYO3_USE_ABI3_FORWARD_COMPATIBILITY: "1"
+        run: uv tool run maturin develop --manifest-path rust/aa-ffi-python/Cargo.toml --release
+
+      - name: Verify native module import
+        run: |
+          uv run python -c "from agent_assembly._core import RuntimeClient, GovernanceEvent, PolicyResult"
diff --git a/.gitignore b/.gitignore
index f9c5eee..cff69bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,10 @@ build/
 dist/
 *.egg-info/
 
+## Rust build artifacts
+rust/target/
+rust/Cargo.lock
+
 ## Script
 scripts/ci/get-all-tests.sh
 
diff --git a/README.md b/README.md
index 716cc25..cb9c6bb 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,27 @@ uv run ruff check .
 uv run mypy agent_assembly
 ```
 
+## Native Core Extension (AAASM-55)
+
+Build and install the PyO3 extension locally:
+
+```bash
+uv tool run maturin develop --manifest-path rust/aa-ffi-python/Cargo.toml --release
+```
+
+Validate native module import:
+
+```python
+from agent_assembly._core import RuntimeClient, GovernanceEvent, PolicyResult
+```
+
+Run opt-in native integration tests:
+
+```bash
+AAASM_RUN_NATIVE_CORE_TESTS=1 uv run pytest test/integration/test_native_core_runtime.py
+AAASM_RUN_MATURIN_TESTS=1 uv run pytest test/integration/test_native_core_maturin.py
+```
+
 ## Documentation
 
 - Project docs source: `docs/`
diff --git a/agent_assembly/__init__.py b/agent_assembly/__init__.py
index f5aaab8..fcdf256 100644
--- a/agent_assembly/__init__.py
+++ b/agent_assembly/__init__.py
@@ -12,6 +12,16 @@
     ToolExecutionBlockedError,
 )
 
+try:
+    from agent_assembly._core import (  # type: ignore[attr-defined]
+        GovernanceEvent,
+        PolicyResult,
+        PolicyTimeoutError,
+        RuntimeClient,
+    )
+except ImportError:
+    pass  # the extension is optional: without it the native names are simply not exported
+
 __version__ = "0.0.0"
 
 __all__ = [
@@ -28,3 +38,13 @@
     "AdapterValidationError",
     "ToolExecutionBlockedError",
 ]
+
+if "RuntimeClient" in globals():  # true only when the guarded import above succeeded
+    __all__.extend(
+        [
+            "RuntimeClient",
+            "GovernanceEvent",
+            "PolicyResult",
+            "PolicyTimeoutError",
+        ]
+    )
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
new file mode 100644
index 0000000..8839457
--- /dev/null
+++ b/rust/Cargo.toml
@@ -0,0 +1,3 @@
+[workspace]
+members = ["aa-ffi-python"]
+resolver = "2"
diff --git a/rust/aa-ffi-python/Cargo.toml b/rust/aa-ffi-python/Cargo.toml
new file mode 100644
index 0000000..fe93a98
--- /dev/null
+++ b/rust/aa-ffi-python/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "aa-ffi-python"
+version = "0.0.0"
+edition = "2021"
+description = "PyO3 bridge crate for the Agent Assembly Python SDK"
+license = "MIT"
+
+[lib]
+name = "aa_ffi_python"
+crate-type = ["cdylib"]
+
+[dependencies]  # NOTE(review): the git dependencies below carry no rev/tag pin and rust/Cargo.lock is git-ignored - confirm builds are reproducible
+aa-core = { git = "https://github.com/AI-agent-assembly/agent-assembly.git", package = "aa-core", features = ["serde"] }
+aa-proto = { git = "https://github.com/AI-agent-assembly/agent-assembly.git", package = "aa-proto" }
+once_cell = "1.20"
+prost = "0.13"
+pyo3 = { version = "0.20", features = ["extension-module"] }
+pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+tokio = { version = "1.41", features = ["io-util", "net", "rt-multi-thread", "sync", "time"] }
diff --git a/rust/aa-ffi-python/pyproject.toml b/rust/aa-ffi-python/pyproject.toml
new file mode 100644
index 0000000..b277e05
--- /dev/null
+++ b/rust/aa-ffi-python/pyproject.toml
@@ -0,0 +1,14 @@
+[build-system]
+requires = ["maturin>=1.7,<2.0"]
+build-backend = "maturin"
+
+[project]
+name = "agent-assembly-core"
+version = "0.0.0"
+description = "Native core extension for Agent Assembly Python SDK"
+requires-python = ">=3.12,<4.0"
+
+[tool.maturin]
+python-source = "../.."
+module-name = "agent_assembly._core"
+features = ["pyo3/extension-module"]
diff --git a/rust/aa-ffi-python/src/lib.rs b/rust/aa-ffi-python/src/lib.rs
new file mode 100644
index 0000000..5ae2616
--- /dev/null
+++ b/rust/aa-ffi-python/src/lib.rs
@@ -0,0 +1,550 @@
+//! aa-ffi-python crate bootstrap.
+
+use aa_core::AuditEntry;
+use aa_proto::assembly::audit::v1::AuditEvent;
+use aa_proto::assembly::common::v1::Decision;
+use aa_proto::assembly::policy::v1::CheckActionRequest;
+use aa_proto::assembly::policy::v1::CheckActionResponse;
+use once_cell::sync::Lazy;
+use prost::Message;
+use pyo3::exceptions::PyValueError;
+use pyo3::exceptions::PyRuntimeError;
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::time::Duration;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+use tokio::net::UnixStream;
+use tokio::runtime::Runtime;
+use tokio::sync::{mpsc, oneshot};
+use tokio::time;
+
+pyo3::create_exception!(_core, PolicyTimeoutError, pyo3::exceptions::PyTimeoutError);
+
+static TOKIO_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
+    tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .thread_name("aa-ffi-python")
+        .build()
+        .expect("failed to build aa-ffi-python tokio runtime")
+});
+
+const TAG_POLICY_QUERY: u8 = 1; // client -> runtime frame tags
+const TAG_EVENT_REPORT: u8 = 2;
+const TAG_HEARTBEAT: u8 = 4;
+
+const TAG_POLICY_RESPONSE: u8 = 1; // runtime -> client frame tags
+const TAG_ACK: u8 = 3;
+
+#[pyclass(module = "agent_assembly._core")]
+#[derive(Clone)]
+struct GovernanceEvent {
+    #[pyo3(get)]
+    payload_json: String,
+    audit_entry: AuditEntry,
+}
+
+#[pymethods]
+impl GovernanceEvent {
+    #[new]
+    fn new(payload_json: String) -> PyResult<Self> {
+        let audit_entry = serde_json::from_str::<AuditEntry>(&payload_json).map_err(|error| {
+            PyValueError::new_err(format!(
+                "GovernanceEvent payload must be serialized aa_core::AuditEntry JSON: {error}"
+            ))
+        })?;
+        Ok(Self {
+            payload_json,
+            audit_entry,
+        })
+    }
+}
+
+#[pyclass(module = "agent_assembly._core")]
+#[derive(Clone)]
+struct PolicyResult {
+    #[pyo3(get)]
+    allowed: bool,
+    #[pyo3(get)]
+    reason: String,
+}
+
+#[pymethods]
+impl PolicyResult {
+    #[new]
+    fn new(allowed: bool, reason: Option<String>) -> Self {
+        Self {
+            allowed,
+            reason: reason.unwrap_or_default(),
+        }
+    }
+}
+
+#[pyclass(module = "agent_assembly._core")]
+struct RuntimeClient {
+    #[pyo3(get)]
+    socket_path: String,
+    sender: Option<mpsc::UnboundedSender<WorkerMessage>>,
+    closed: Arc<AtomicBool>,
+    last_error: Arc<Mutex<Option<String>>>,
+}
+
+enum WorkerMessage {
+    Event(GovernanceEvent),
+    PolicyQuery {
+        action_json: String,
+        timeout_ms: u64,
+        response_tx: oneshot::Sender<Result<PolicyResultPayload, WorkerError>>,
+    },
+    Close,
+}
+
+#[derive(Clone)]
+struct PolicyResultPayload {
+    allowed: bool,
+    reason: String,
+}
+
+#[derive(Debug)]
+enum WorkerError {
+    Timeout,
+    Disconnected,
+    Transport(String),
+    Decode(String),
+}
+
+enum WorkerWaitError {
+    Timeout,
+    Disconnected,
+}
+
+#[pymethods]
+impl RuntimeClient {
+    #[new]
+    fn new(socket_path: String) -> Self {
+        Self {
+            socket_path,
+            sender: None,
+            closed: Arc::new(AtomicBool::new(true)), // no worker is spawned here, so the client starts closed; connect() yields a live one
+            last_error: Arc::new(Mutex::new(None)),
+        }
+    }
+
+    #[staticmethod]
+    fn connect(socket_path: String) -> Self {
+        let _ = &*TOKIO_RUNTIME;
+
+        let (sender, receiver) = mpsc::unbounded_channel::<WorkerMessage>();
+        let closed = Arc::new(AtomicBool::new(false));
+        let last_error = Arc::new(Mutex::new(None));
+
+        TOKIO_RUNTIME.spawn(worker_loop(
+            socket_path.clone(),
+            receiver,
+            Arc::clone(&closed),
+            Arc::clone(&last_error),
+        ));
+
+        Self {
+            socket_path,
+            sender: Some(sender),
+            closed,
+            last_error,
+        }
+    }
+
+    fn send_event(&self, event: GovernanceEvent) -> PyResult<()> {
+        ensure_client_open(self.closed.as_ref(), self.last_error.as_ref())?;
+        let sender = self
+            .sender
+            .as_ref()
+            .ok_or_else(|| PyRuntimeError::new_err("runtime event queue is unavailable"))?;
+        sender
+            .send(WorkerMessage::Event(event))
+            .map_err(|_| PyRuntimeError::new_err("failed to enqueue governance event"))?;
+        Ok(())
+    }
+
+    fn query_policy(&self, py: Python<'_>, action: &PyAny) -> PyResult<PolicyResult> {
+        ensure_client_open(self.closed.as_ref(), self.last_error.as_ref())?;
+        let action_json = serialize_action_to_json(py, action)?;
+        let timeout_ms = extract_timeout_ms(action);
+        let sender = self
+            .sender
+            .as_ref()
+            .ok_or_else(|| PyRuntimeError::new_err("runtime event queue is unavailable"))?;
+
+        let (response_tx, response_rx) = oneshot::channel::<Result<PolicyResultPayload, WorkerError>>();
+        sender
+            .send(WorkerMessage::PolicyQuery {
+                action_json,
+                timeout_ms,
+                response_tx,
+            })
+            .map_err(|_| PyRuntimeError::new_err("failed to enqueue policy query"))?;
+
+        let worker_result = py.allow_threads(|| wait_for_worker_response(timeout_ms.saturating_add(100), response_rx)); // saturating: timeout_ms comes from the caller's dict and may be near u64::MAX
+        let worker_result = worker_result.map_err(|error| match error {
+            WorkerWaitError::Timeout => PolicyTimeoutError::new_err("policy query timed out"),
+            WorkerWaitError::Disconnected => PyRuntimeError::new_err("policy worker disconnected"),
+        })?;
+
+        let payload = worker_result.map_err(map_worker_error_to_py)?;
+        Ok(PolicyResult {
+            allowed: payload.allowed,
+            reason: payload.reason,
+        })
+    }
+
+    fn close(&mut self) {
+        if self.closed.swap(true, Ordering::SeqCst) {
+            return;
+        }
+        if let Some(sender) = self.sender.take() {
+            let _ = sender.send(WorkerMessage::Close);
+        }
+    }
+}
+
+async fn worker_loop(
+    socket_path: String,
+    mut receiver: mpsc::UnboundedReceiver<WorkerMessage>,
+    closed: Arc<AtomicBool>,
+    last_error: Arc<Mutex<Option<String>>>,
+) {
+    let stream = match UnixStream::connect(&socket_path).await {
+        Ok(stream) => stream,
+        Err(error) => {
+            set_worker_error(last_error.as_ref(), format!("failed to connect runtime socket: {error}"));
+            closed.store(true, Ordering::SeqCst);
+            return;
+        }
+    };
+
+    let (mut reader, mut writer) = stream.into_split();
+    if let Err(error) = write_heartbeat(&mut writer).await {
+        set_worker_error(last_error.as_ref(), format!("failed to send heartbeat: {error:?}"));
+        closed.store(true, Ordering::SeqCst);
+        return;
+    }
+
+    match read_runtime_response(&mut reader).await {
+        Ok(RuntimeResponse::Ack) => {}
+        Ok(_) => {
+            set_worker_error(last_error.as_ref(), "unexpected heartbeat response from runtime".to_string());
+            closed.store(true, Ordering::SeqCst);
+            return;
+        }
+        Err(error) => {
+            set_worker_error(last_error.as_ref(), format!("failed to read heartbeat ack: {error}"));
+            closed.store(true, Ordering::SeqCst);
+            return;
+        }
+    }
+
+    while let Some(message) = receiver.recv().await {
+        match message {
+            WorkerMessage::Event(event) => {
+                let send_result = send_event_frame(&mut writer, &event).await;
+                if let Err(error) = send_result {
+                    set_worker_error(last_error.as_ref(), format!("failed to send event: {error:?}"));
+                    break;
+                }
+
+                match read_runtime_response(&mut reader).await {
+                    Ok(RuntimeResponse::Ack) => {}
+                    Ok(_) => {
+                        set_worker_error(last_error.as_ref(), "unexpected event ack response from runtime".to_string());
+                        break;
+                    }
+                    Err(error) => {
+                        set_worker_error(last_error.as_ref(), format!("failed to read event ack: {error}"));
+                        break;
+                    }
+                }
+            }
+            WorkerMessage::PolicyQuery {
+                action_json,
+                timeout_ms,
+                response_tx,
+            } => {
+                let response = process_policy_query(&mut reader, &mut writer, action_json, timeout_ms).await; // NOTE(review): after a Timeout the late reply is still unread on the socket and would be consumed by the next read - confirm/handle
+                let _ = response_tx.send(response);
+            }
+            WorkerMessage::Close => break,
+        }
+    }
+
+    closed.store(true, Ordering::SeqCst);
+}
+
+async fn send_event_frame<W>(writer: &mut W, event: &GovernanceEvent) -> Result<(), WorkerError>
+where
+    W: AsyncWrite + Unpin,
+{
+    let entry = &event.audit_entry;
+    let event_type = format!("{:?}", entry.event_type());
+    let agent_id_hex = bytes_to_hex(entry.agent_id().as_bytes());
+    let session_id_hex = bytes_to_hex(entry.session_id().as_bytes());
+    let audit_event = AuditEvent {
+        event_id: make_event_id(),
+        trace_id: "python-sdk".to_string(),
+        span_id: "ffi-send-event".to_string(),
+        decision: Decision::Allow as i32,
+        labels: std::collections::HashMap::from([
+            (String::from("payload_json"), event.payload_json.clone()),
+            (String::from("event_type"), event_type),
+            (String::from("agent_id_hex"), agent_id_hex),
+            (String::from("session_id_hex"), session_id_hex),
+            (String::from("payload"), entry.payload().to_string()),
+        ]),
+        ..Default::default()
+    };
+    let payload = audit_event.encode_to_vec();
+    write_frame(writer, TAG_EVENT_REPORT, &payload).await
+}
+
+async fn process_policy_query<R, W>(
+    reader: &mut R,
+    writer: &mut W,
+    action_json: String,
+    timeout_ms: u64,
+) -> Result<PolicyResultPayload, WorkerError>
+where
+    R: AsyncRead + Unpin,
+    W: AsyncWrite + Unpin,
+{
+    let request = CheckActionRequest {
+        trace_id: action_json, // NOTE(review): the action JSON travels in trace_id - confirm the runtime expects it there
+        span_id: "ffi-query-policy".to_string(),
+        ..Default::default()
+    };
+    let payload = request.encode_to_vec();
+    write_frame(writer, TAG_POLICY_QUERY, &payload).await?;
+
+    let response = time::timeout(
+        Duration::from_millis(timeout_ms),
+        read_runtime_response(reader),
+    )
+    .await
+    .map_err(|_| WorkerError::Timeout)?
+    .map_err(|error| WorkerError::Transport(error))?;
+
+    match response {
+        RuntimeResponse::PolicyResponse(bytes) => {
+            let policy = CheckActionResponse::decode(bytes.as_slice())
+                .map_err(|error| WorkerError::Decode(error.to_string()))?;
+            let allowed = matches!(policy.decision, x if x == Decision::Allow as i32 || x == Decision::Redact as i32);
+            Ok(PolicyResultPayload {
+                allowed,
+                reason: policy.reason,
+            })
+        }
+        RuntimeResponse::Ack => Err(WorkerError::Transport(
+            "runtime returned ACK instead of policy response".to_string(),
+        )),
+        RuntimeResponse::Unknown(tag, _) => Err(WorkerError::Transport(format!(
+            "runtime returned unexpected tag {tag} for policy query"
+        ))),
+    }
+}
+
+fn map_worker_error_to_py(error: WorkerError) -> PyErr {
+    match error {
+        WorkerError::Timeout => PolicyTimeoutError::new_err("policy query timed out"),
+        WorkerError::Disconnected => PyRuntimeError::new_err("policy worker disconnected"),
+        WorkerError::Transport(message) | WorkerError::Decode(message) => PyRuntimeError::new_err(message),
+    }
+}
+
+fn ensure_client_open(closed: &AtomicBool, last_error: &Mutex<Option<String>>) -> PyResult<()> {
+    if !closed.load(Ordering::SeqCst) {
+        return Ok(());
+    }
+
+    if let Ok(guard) = last_error.lock() {
+        if let Some(message) = guard.as_ref() {
+            return Err(PyRuntimeError::new_err(message.clone()));
+        }
+    }
+
+    Err(PyRuntimeError::new_err("runtime client is closed"))
+}
+
+fn extract_timeout_ms(action: &PyAny) -> u64 {
+    action
+        .downcast::<PyDict>()
+        .ok()
+        .and_then(|dict| dict.get_item("timeout_ms").ok().flatten())
+        .and_then(|value| value.extract::<u64>().ok())
+        .unwrap_or(50) // default when action is not a dict or carries no usable timeout_ms
+}
+
+fn serialize_action_to_json(py: Python<'_>, action: &PyAny) -> PyResult<String> {
+    let json_module = PyModule::import(py, "json")?;
+    let dumped = json_module.call_method1("dumps", (action,))?;
+    dumped.extract::<String>()
+}
+
+fn set_worker_error(last_error: &Mutex<Option<String>>, message: String) {
+    if let Ok(mut guard) = last_error.lock() {
+        *guard = Some(message);
+    }
+}
+
+fn make_event_id() -> String {
+    use std::time::{SystemTime, UNIX_EPOCH};
+
+    let now = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap_or_default();
+    format!("py-{}-{}", now.as_secs(), now.subsec_nanos())
+}
+
+fn bytes_to_hex(bytes: &[u8; 16]) -> String {
+    const HEX: &[u8; 16] = b"0123456789abcdef";
+    let mut result = String::with_capacity(bytes.len() * 2);
+    for byte in bytes {
+        result.push(HEX[(byte >> 4) as usize] as char);
+        result.push(HEX[(byte & 0x0F) as usize] as char);
+    }
+    result
+}
+
+enum RuntimeResponse {
+    Ack,
+    PolicyResponse(Vec<u8>),
+    Unknown(u8, Vec<u8>),
+}
+
+async fn write_heartbeat<W>(writer: &mut W) -> Result<(), WorkerError>
+where
+    W: AsyncWrite + Unpin,
+{
+    writer
+        .write_u8(TAG_HEARTBEAT)
+        .await
+        .map_err(|error| WorkerError::Transport(error.to_string()))?;
+    writer
+        .flush()
+        .await
+        .map_err(|error| WorkerError::Transport(error.to_string()))
+}
+
+async fn write_frame<W>(writer: &mut W, tag: u8, payload: &[u8]) -> Result<(), WorkerError>
+where
+    W: AsyncWrite + Unpin,
+{
+    writer
+        .write_u8(tag)
+        .await
+        .map_err(|error| WorkerError::Transport(error.to_string()))?;
+    write_varint(writer, payload.len() as u64).await?;
+    writer
+        .write_all(payload)
+        .await
+        .map_err(|error| WorkerError::Transport(error.to_string()))?;
+    writer
+        .flush()
+        .await
+        .map_err(|error| WorkerError::Transport(error.to_string()))
+}
+
+async fn read_runtime_response<R>(reader: &mut R) -> Result<RuntimeResponse, String>
+where
+    R: AsyncRead + Unpin,
+{
+    let tag = reader.read_u8().await.map_err(|error| error.to_string())?;
+    match tag {
+        TAG_ACK => {
+            let _ = read_length_delimited(reader).await?;
+            Ok(RuntimeResponse::Ack)
+        }
+        TAG_POLICY_RESPONSE => {
+            let payload = read_length_delimited(reader).await?;
+            Ok(RuntimeResponse::PolicyResponse(payload))
+        }
+        other => {
+            let payload = read_length_delimited(reader).await?;
+            Ok(RuntimeResponse::Unknown(other, payload))
+        }
+    }
+}
+
+async fn read_length_delimited<R>(reader: &mut R) -> Result<Vec<u8>, String>
+where
+    R: AsyncRead + Unpin,
+{
+    let len = read_varint(reader).await? as usize;
+    let mut payload = vec![0u8; len]; // NOTE(review): len is read from the peer with no upper bound
+    reader
+        .read_exact(&mut payload)
+        .await
+        .map_err(|error| error.to_string())?;
+    Ok(payload)
+}
+
+async fn read_varint<R>(reader: &mut R) -> Result<u64, String>
+where
+    R: AsyncRead + Unpin,
+{
+    let mut result: u64 = 0;
+    let mut shift = 0u32;
+    loop {
+        let byte = reader.read_u8().await.map_err(|error| error.to_string())?;
+        result |= ((byte & 0x7F) as u64) << shift;
+        if byte & 0x80 == 0 {
+            break;
+        }
+        shift += 7;
+        if shift >= 64 {
+            return Err("varint too long".to_string());
+        }
+    }
+    Ok(result)
+}
+
+async fn write_varint<W>(writer: &mut W, mut value: u64) -> Result<(), WorkerError>
+where
+    W: AsyncWrite + Unpin,
+{
+    loop {
+        let byte = (value & 0x7F) as u8;
+        value >>= 7;
+        if value == 0 {
+            writer
+                .write_u8(byte)
+                .await
+                .map_err(|error| WorkerError::Transport(error.to_string()))?;
+            break;
+        }
+
+        writer
+            .write_u8(byte | 0x80)
+            .await
+            .map_err(|error| WorkerError::Transport(error.to_string()))?;
+    }
+
+    Ok(())
+}
+
+fn wait_for_worker_response(
+    timeout_ms: u64,
+    response_rx: oneshot::Receiver<Result<PolicyResultPayload, WorkerError>>,
+) -> Result<Result<PolicyResultPayload, WorkerError>, WorkerWaitError> {
+    TOKIO_RUNTIME
+        .block_on(async move { time::timeout(Duration::from_millis(timeout_ms), response_rx).await })
+        .map_err(|_| WorkerWaitError::Timeout)?
+        .map_err(|_| WorkerWaitError::Disconnected)
+}
+
+#[pymodule]
+fn _core(py: Python<'_>, module: &PyModule) -> PyResult<()> {
+    module.add("PolicyTimeoutError", py.get_type::<PolicyTimeoutError>())?;
+    module.add_class::<RuntimeClient>()?;
+    module.add_class::<GovernanceEvent>()?;
+    module.add_class::<PolicyResult>()?;
+    Ok(())
+}
diff --git a/test/integration/test_native_core_maturin.py b/test/integration/test_native_core_maturin.py
new file mode 100644
index 0000000..ee09135
--- /dev/null
+++ b/test/integration/test_native_core_maturin.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+
+import pytest
+
+
+@pytest.mark.integration
+def test_maturin_develop_exposes_runtime_client() -> None:
+    if os.getenv("AAASM_RUN_MATURIN_TESTS") != "1":
+        pytest.skip("Set AAASM_RUN_MATURIN_TESTS=1 to run maturin integration smoke tests.")
+
+    command = [
+        "uv",
+        "tool",
+        "run",
+        "maturin",
+        "develop",
+        "--manifest-path",
+        "rust/aa-ffi-python/Cargo.toml",
+        "--release",
+    ]
+    env = os.environ.copy()
+    env.setdefault("PYO3_USE_ABI3_FORWARD_COMPATIBILITY", "1")
+    subprocess.run(command, check=True, env=env)
+
+    from agent_assembly._core import RuntimeClient  # imported only after the build step above has run
+
+    assert RuntimeClient is not None
+    assert hasattr(RuntimeClient, "connect")
+    assert sys.modules.get("agent_assembly._core") is not None
diff --git a/test/integration/test_native_core_runtime.py b/test/integration/test_native_core_runtime.py
new file mode 100644
index 0000000..114e683
--- /dev/null
+++ b/test/integration/test_native_core_runtime.py
@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+import gc
+import json
+import os
+import socket
+import tempfile
+import threading
+import time
+import tracemalloc
+from pathlib import Path
+
+import pytest
+
+
+class MockRuntimeServer:
+    def __init__(self, *, policy_delay_ms: int = 0) -> None:
+        self._policy_delay_ms = policy_delay_ms
+        self._stop = threading.Event()
+        self._ready = threading.Event()
+        self._thread = threading.Thread(target=self._run, daemon=True)
+        self._socket_dir = tempfile.TemporaryDirectory(prefix="aaasm55-")
+        self.socket_path = str(Path(self._socket_dir.name) / "runtime.sock")
+
+    def start(self) -> None:
+        self._thread.start()
+        if not self._ready.wait(timeout=5):
+            raise RuntimeError("mock runtime server did not start")
+
+    def close(self) -> None:
+        self._stop.set()
+        self._thread.join(timeout=5)
+        self._socket_dir.cleanup()
+
+    def _run(self) -> None:
+        if os.path.exists(self.socket_path):
+            os.remove(self.socket_path)
+
+        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
+            server.bind(self.socket_path)
+            server.listen(1)
+            self._ready.set()
+
+            conn, _ = server.accept()
+            with conn:
+                conn.settimeout(0.2)
+
+                # Heartbeat frame from client is a single tag byte.
+                tag = self._read_u8(conn)
+                if tag != 4:
+                    return
+                self._write_frame(conn, 3, b"")
+
+                while not self._stop.is_set():
+                    try:
+                        tag = self._read_u8(conn)
+                    except TimeoutError:  # idle poll: loop back and re-check the stop flag
+                        continue
+                    except OSError:  # peer closed or socket failed: stop instead of spinning on a dead socket
+                        return
+
+                    if tag in (1, 2, 3):
+                        _ = self._read_length_delimited(conn)
+
+                    if tag == 1:
+                        if self._policy_delay_ms > 0:
+                            time.sleep(self._policy_delay_ms / 1000.0)
+                        # CheckActionResponse { decision: ALLOW } protobuf payload
+                        try:
+                            self._write_frame(conn, 1, b"\x08\x01")
+                        except OSError:
+                            return
+                    elif tag == 2:
+                        try:
+                            self._write_frame(conn, 3, b"")
+                        except OSError:
+                            return
+                    elif tag == 3:
+                        try:
+                            self._write_frame(conn, 3, b"")
+                        except OSError:
+                            return
+                    else:
+                        return
+
+    @staticmethod
+    def _read_u8(conn: socket.socket) -> int:
+        chunk = conn.recv(1)
+        if not chunk:
+            raise OSError("socket closed")
+        return chunk[0]
+
+    @staticmethod
+    def _read_varint(conn: socket.socket) -> int:
+        result = 0
+        shift = 0
+        while True:
+            byte = MockRuntimeServer._read_u8(conn)
+            result |= (byte & 0x7F) << shift
+            if byte & 0x80 == 0:
+                return result
+            shift += 7
+            if shift >= 64:
+                raise ValueError("varint too long")
+
+    @staticmethod
+    def _read_exact(conn: socket.socket, size: int) -> bytes:
+        data = b""
+        while len(data) < size:
+            chunk = conn.recv(size - len(data))
+            if not chunk:
+                raise OSError("socket closed")
+            data += chunk
+        return data
+
+    @classmethod
+    def _read_length_delimited(cls, conn: socket.socket) -> bytes:
+        size = cls._read_varint(conn)
+        return cls._read_exact(conn, size)
+
+    @staticmethod
+    def _write_varint(conn: socket.socket, value: int) -> None:
+        while True:
+            byte = value & 0x7F
+            value >>= 7
+            if value == 0:
+                conn.sendall(bytes([byte]))
+                return
+            conn.sendall(bytes([byte | 0x80]))
+
+    @classmethod
+    def _write_frame(cls, conn: socket.socket, tag: int, payload: bytes) -> None:
+        conn.sendall(bytes([tag]))
+        cls._write_varint(conn, len(payload))
+        if payload:
+            conn.sendall(payload)
+
+
+def make_audit_entry_payload(index: int, *, worker_id: int = 0) -> str:
+    return json.dumps(
+        {
+            "seq": index,
+            "timestamp_ns": 1_700_000_000_000_000_000 + index,
+            "event_type": "ToolCallIntercepted",
+            "agent_id": [worker_id % 255] * 16,
+            "session_id": [index % 255] * 16,
+            "payload": json.dumps({"index": index, "worker": worker_id}),
+            "previous_hash": [0] * 32,
+            "entry_hash": [0] * 32,
+        }
+    )
+
+
+@pytest.fixture()
+def native_core():
+    if os.getenv("AAASM_RUN_NATIVE_CORE_TESTS") != "1":
+        pytest.skip("Set AAASM_RUN_NATIVE_CORE_TESTS=1 to run native core runtime tests.")
+    return pytest.importorskip("agent_assembly._core")
+
+
+@pytest.mark.integration
+def test_send_event_is_non_blocking(native_core) -> None:
+    server = MockRuntimeServer()
+    server.start()
+
+    client = native_core.RuntimeClient.connect(server.socket_path)
+    try:
+        events = [native_core.GovernanceEvent(make_audit_entry_payload(index)) for index in range(500)]
+        start = time.perf_counter()
+        for event in events:
+            client.send_event(event)
+        elapsed_ms = (time.perf_counter() - start) * 1000.0
+        assert elapsed_ms < 50.0
+    finally:
+        client.close()
+        server.close()
+
+
+@pytest.mark.integration
+def test_query_policy_returns_quickly_and_times_out(native_core) -> None:
+    fast_server = MockRuntimeServer(policy_delay_ms=0)
+    fast_server.start()
+
+    fast_client = native_core.RuntimeClient.connect(fast_server.socket_path)
+    try:
+        start = time.perf_counter()
+        result = fast_client.query_policy({"action": "tool.call", "timeout_ms": 50})
+        elapsed_ms = (time.perf_counter() - start) * 1000.0
+        assert elapsed_ms < 50.0
+        assert result.allowed is True
+    finally:
+        fast_client.close()
+        fast_server.close()
+
+    slow_server = MockRuntimeServer(policy_delay_ms=200)
+    slow_server.start()
+
+    slow_client = native_core.RuntimeClient.connect(slow_server.socket_path)
+    try:
+        with pytest.raises(native_core.PolicyTimeoutError):
+            slow_client.query_policy({"action": "slow.call", "timeout_ms": 10})
+    finally:
+        slow_client.close()
+        slow_server.close()
+
+
+@pytest.mark.integration
+def test_runtime_client_has_no_thread_deadlock(native_core) -> None:
+    server = MockRuntimeServer()
+    server.start()
+
+    client = native_core.RuntimeClient.connect(server.socket_path)
+    errors: list[Exception] = []
+
+    def worker(worker_id: int) -> None:
+        try:
+            for index in range(100):
+                client.send_event(
+                    native_core.GovernanceEvent(make_audit_entry_payload(index, worker_id=worker_id))
+                )
+                client.query_policy({"action": "tool.call", "timeout_ms": 50})
+        except Exception as error:  # pragma: no cover - runtime guard
+            errors.append(error)
+
+    threads = [threading.Thread(target=worker, args=(worker_id,)) for worker_id in range(8)]
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join(timeout=5)
+
+    try:
+        assert all(not thread.is_alive() for thread in threads)
+        assert errors == []
+    finally:
+        client.close()
+        server.close()
+
+
+@pytest.mark.integration
+def test_runtime_client_tracemalloc_leak_guard(native_core) -> None:
+    server = MockRuntimeServer()
+    server.start()
+
+    client = native_core.RuntimeClient.connect(server.socket_path)
+    tracemalloc.start()
+    baseline_current, _ = tracemalloc.get_traced_memory()
+    try:
+        for index in range(10_000):
+            client.send_event(native_core.GovernanceEvent(make_audit_entry_payload(index)))
+        gc.collect()
+        current, _ = tracemalloc.get_traced_memory()
+        assert current - baseline_current < 1_000_000
+    finally:
+        tracemalloc.stop()
+        client.close()
+        server.close()
diff --git a/test/unit/test_init_exports.py b/test/unit/test_init_exports.py
new file mode 100644
index 0000000..81aeb3a
--- /dev/null
+++ b/test/unit/test_init_exports.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+import importlib
+import sys
+import types
+
+
+def test_all_includes_native_core_symbols_when_extension_is_available() -> None:
+    fake_core = types.ModuleType("agent_assembly._core")
+
+    class RuntimeClient: ...
+
+    class GovernanceEvent: ...
+
+    class PolicyResult: ...
+
+    class PolicyTimeoutError(Exception): ...
+
+    fake_core.RuntimeClient = RuntimeClient
+    fake_core.GovernanceEvent = GovernanceEvent
+    fake_core.PolicyResult = PolicyResult
+    fake_core.PolicyTimeoutError = PolicyTimeoutError
+
+    original_package = sys.modules.pop("agent_assembly", None)  # drop the cached package so the import below re-executes it
+    original_core = sys.modules.get("agent_assembly._core")
+
+    try:
+        sys.modules["agent_assembly._core"] = fake_core
+        module = importlib.import_module("agent_assembly")
+
+        assert "RuntimeClient" in module.__all__
+        assert "GovernanceEvent" in module.__all__
+        assert "PolicyResult" in module.__all__
+        assert "PolicyTimeoutError" in module.__all__
+    finally:
+        sys.modules.pop("agent_assembly", None)
+        if original_package is not None:
+            sys.modules["agent_assembly"] = original_package
+
+        if original_core is None:
+            sys.modules.pop("agent_assembly._core", None)
+        else:
+            sys.modules["agent_assembly._core"] = original_core