diff --git a/README.md b/README.md index 65f4a44..716cc25 100644 --- a/README.md +++ b/README.md @@ -29,19 +29,20 @@ from agent_assembly import init_assembly async def main() -> None: - client = init_assembly( + context = init_assembly( gateway_url="http://localhost:8080", + api_key="required-api-key", agent_id="my-agent-001", - api_key="optional-api-key", + mode="auto", ) try: - registration = await client.register_agent() - decision = await client.check_policy_compliance("tool.call") + registration = await context.client.register_agent() + decision = await context.client.check_policy_compliance("tool.call") print(registration) print(decision) finally: - client.close() + context.shutdown() asyncio.run(main()) @@ -49,7 +50,7 @@ asyncio.run(main()) ## Public API -- `init_assembly(gateway_url, agent_id, api_key=None) -> GatewayClient` +- `init_assembly(gateway_url, api_key, agent_id=None, mode="auto") -> AssemblyContext` - `GatewayClient.register_agent() -> dict` - `GatewayClient.check_policy_compliance(action: str) -> dict` - Exceptions: `AssemblyError`, `AgentError`, `PolicyError`, `GatewayError`, `ConfigurationError` @@ -62,7 +63,7 @@ from agent_assembly import init_assembly from agent_assembly.exceptions import ConfigurationError try: - client = init_assembly(gateway_url="", agent_id="my-agent-001") + context = init_assembly(gateway_url="", api_key="my-api-key", agent_id="my-agent-001") except ConfigurationError as exc: print(f"Invalid configuration: {exc}") ``` diff --git a/agent_assembly/__init__.py b/agent_assembly/__init__.py index 071f5cb..f5aaab8 100644 --- a/agent_assembly/__init__.py +++ b/agent_assembly/__init__.py @@ -1,7 +1,7 @@ """Agent Assembly Python SDK.""" from agent_assembly.adapters import FrameworkAdapter, GovernanceInterceptor -from agent_assembly.core import init_assembly +from agent_assembly.core import AssemblyContext, init_assembly from agent_assembly.exceptions import ( AdapterValidationError, AgentError, @@ -17,6 +17,7 @@ __all__ = [ "__version__", "init_assembly", + "AssemblyContext", "GovernanceInterceptor", "FrameworkAdapter", "AssemblyError", diff --git a/agent_assembly/adapters/crewai/patch.py b/agent_assembly/adapters/crewai/patch.py index 261a7ec..d99f24a 100644 --- a/agent_assembly/adapters/crewai/patch.py +++ b/agent_assembly/adapters/crewai/patch.py @@ -34,6 +34,17 @@ def apply(self) -> bool: _apply_task_execute_sync_patch(task_cls, self.callback_handler) return True + def revert(self) -> None: + """Revert CrewAI runtime monkey patches when available.""" + base_tool_cls = _load_crewai_basetool_class() + if base_tool_cls is not None: + _revert_basetool_run_patch(base_tool_cls) + + task_cls = _load_crewai_task_class() + if task_cls is not None: + _revert_task_execute_sync_patch(task_cls) + return None + def _load_crewai_basetool_class() -> type[Any] | None: try: @@ -263,6 +274,21 @@ def patched_run(self: Any, *args: Any, **kwargs: Any) -> Any: setattr(base_tool_cls, _TOOLS_PATCHED_FLAG, True) +def _revert_basetool_run_patch(base_tool_cls: type[Any]) -> None: + if not getattr(base_tool_cls, _TOOLS_PATCHED_FLAG, False): + return None + + original_run = getattr(base_tool_cls, _ORIGINAL_TOOL_RUN, None) + if callable(original_run): + setattr(base_tool_cls, "run", original_run) + + if hasattr(base_tool_cls, _ORIGINAL_TOOL_RUN): + delattr(base_tool_cls, _ORIGINAL_TOOL_RUN) + if hasattr(base_tool_cls, _TOOLS_PATCHED_FLAG): + delattr(base_tool_cls, _TOOLS_PATCHED_FLAG) + return None + + def _record_task_start(callback_handler: Any, task: Any) -> None: method = getattr(callback_handler, "record", None) if callable(method): @@ -313,3 +339,18 @@ def patched_execute_sync(self: Any, *args: Any, **kwargs: Any) -> Any: setattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC, original_execute_sync) setattr(task_cls, "execute_sync", patched_execute_sync) setattr(task_cls, _TASK_PATCHED_FLAG, True) + + +def _revert_task_execute_sync_patch(task_cls: type[Any]) -> None: + if not getattr(task_cls, _TASK_PATCHED_FLAG, False): + return None + + original_execute_sync = getattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC, None) + if callable(original_execute_sync): + setattr(task_cls, "execute_sync", original_execute_sync) + + if hasattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC): + delattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC) + if hasattr(task_cls, _TASK_PATCHED_FLAG): + delattr(task_cls, _TASK_PATCHED_FLAG) + return None diff --git a/agent_assembly/adapters/langchain/__init__.py b/agent_assembly/adapters/langchain/__init__.py index fa6cf78..15a820a 100644 --- a/agent_assembly/adapters/langchain/__init__.py +++ b/agent_assembly/adapters/langchain/__init__.py @@ -3,6 +3,7 @@ from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler from agent_assembly.adapters.langgraph import LangGraphPatch from agent_assembly.adapters.langchain.langgraph_patch import patch_stategraph_compile +from agent_assembly.adapters.langchain.patch import LangChainPatch from agent_assembly.adapters.langchain.runtime import ( auto_inject_callback_handler, get_active_callback_handler, @@ -10,6 +11,7 @@ __all__ = [ "AssemblyCallbackHandler", + "LangChainPatch", "LangGraphPatch", "patch_stategraph_compile", "auto_inject_callback_handler", diff --git a/agent_assembly/adapters/langchain/patch.py b/agent_assembly/adapters/langchain/patch.py new file mode 100644 index 0000000..3111bae --- /dev/null +++ b/agent_assembly/adapters/langchain/patch.py @@ -0,0 +1,36 @@ +"""LangChain patch orchestrator.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler +from agent_assembly.adapters.langchain.runtime import ( + _reset_runtime_state_for_tests, + auto_inject_callback_handler, +) +from agent_assembly.adapters.pydantic_ai.patch import set_process_agent_id + + +@dataclass(slots=True) +class LangChainPatch: + """Apply/remove LangChain callback handler injection.""" + + interceptor: Any + process_agent_id: str | None = None + + def apply(self) -> bool: + auto_inject_callback_handler( + self.interceptor, + process_agent_id=self.process_agent_id, + ) + return True + + def revert(self) -> None: + set_process_agent_id(None) + _reset_runtime_state_for_tests() + return None + + +__all__ = ["LangChainPatch", "AssemblyCallbackHandler"] diff --git a/agent_assembly/adapters/langchain/runtime.py b/agent_assembly/adapters/langchain/runtime.py index 70ef52d..6b0fa0c 100644 --- a/agent_assembly/adapters/langchain/runtime.py +++ b/agent_assembly/adapters/langchain/runtime.py @@ -5,10 +5,8 @@ from threading import Lock from typing import Any -from agent_assembly.adapters.crewai.patch import CrewAIPatch from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler -from agent_assembly.adapters.langgraph import LangGraphPatch -from agent_assembly.adapters.pydantic_ai.patch import PydanticAIPatch, set_process_agent_id +from agent_assembly.adapters.pydantic_ai.patch import set_process_agent_id _ACTIVE_CALLBACK_HANDLER: AssemblyCallbackHandler | None = None _RUNTIME_LOCK = Lock() @@ -27,16 +25,10 @@ def auto_inject_callback_handler( set_process_agent_id(process_agent_id) if _ACTIVE_CALLBACK_HANDLER is not None: - LangGraphPatch(_ACTIVE_CALLBACK_HANDLER).apply() - CrewAIPatch(interceptor).apply() - PydanticAIPatch(interceptor).apply() return _ACTIVE_CALLBACK_HANDLER handler = AssemblyCallbackHandler(interceptor) _ACTIVE_CALLBACK_HANDLER = handler - LangGraphPatch(handler).apply() - CrewAIPatch(interceptor).apply() - PydanticAIPatch(interceptor).apply() return handler diff --git a/agent_assembly/adapters/langgraph/patch.py b/agent_assembly/adapters/langgraph/patch.py index 51aa7db..cdca45d 100644 --- a/agent_assembly/adapters/langgraph/patch.py +++ b/agent_assembly/adapters/langgraph/patch.py @@ -29,6 +29,14 @@ def apply(self) -> bool: _apply_stategraph_compile_patch(state_graph_cls, self.callback_handler) return True + def revert(self) -> None: + """Revert state graph compile patch when it is currently active.""" + state_graph_cls = _load_stategraph_class() + if state_graph_cls is None: + return None + _revert_stategraph_compile_patch(state_graph_cls) + return None + def patch_stategraph_compile(callback_handler: Any) -> bool: """Backward-compatible helper used by existing runtime wiring.""" @@ -239,6 +247,21 @@ def patched_compile(self: Any, *args: Any, **kwargs: Any) -> Any: setattr(state_graph_cls, _PATCHED_FLAG, True) +def _revert_stategraph_compile_patch(state_graph_cls: type[Any]) -> None: + if not getattr(state_graph_cls, _PATCHED_FLAG, False): + return None + + original_compile = getattr(state_graph_cls, _ORIGINAL_COMPILE, None) + if callable(original_compile): + setattr(state_graph_cls, "compile", original_compile) + + if hasattr(state_graph_cls, _ORIGINAL_COMPILE): + delattr(state_graph_cls, _ORIGINAL_COMPILE) + if hasattr(state_graph_cls, _PATCHED_FLAG): + delattr(state_graph_cls, _PATCHED_FLAG) + return None + + def _wrap_graph_invoke_fallback(compiled_graph: Any, callback_handler: Any) -> None: invoke = getattr(compiled_graph, "invoke", None) if not callable(invoke) or getattr(invoke, _INVOKE_WRAPPED_FLAG, False): diff --git a/agent_assembly/adapters/mcp/__init__.py b/agent_assembly/adapters/mcp/__init__.py new file mode 100644 index 0000000..6dd19af --- /dev/null +++ b/agent_assembly/adapters/mcp/__init__.py @@ -0,0 +1,5 @@ +"""MCP adapter package.""" + +from agent_assembly.adapters.mcp.patch import MCPClientPatch + +__all__ = ["MCPClientPatch"] diff --git a/agent_assembly/adapters/mcp/patch.py b/agent_assembly/adapters/mcp/patch.py new file mode 100644 index 0000000..4923f0b --- /dev/null +++ b/agent_assembly/adapters/mcp/patch.py @@ -0,0 +1,25 @@ +"""MCP client patch module.""" + +from __future__ import annotations + +from dataclasses import dataclass +import importlib.util +from typing import Any + + +@dataclass(slots=True) +class MCPClientPatch: + """Patch placeholder for MCP client interception.""" + + callback_handler: Any + + def apply(self) -> bool: + _ = self.callback_handler + return _is_mcp_available() + + def revert(self) -> None: + return None + + +def _is_mcp_available() -> bool: + return importlib.util.find_spec("mcp") is not None diff --git a/agent_assembly/adapters/openai_agents/__init__.py b/agent_assembly/adapters/openai_agents/__init__.py new file mode 100644 index 0000000..bcfe1ac --- /dev/null +++ b/agent_assembly/adapters/openai_agents/__init__.py @@ -0,0 +1,5 @@ +"""OpenAI Agents adapter package.""" + +from agent_assembly.adapters.openai_agents.patch import OpenAIAgentsPatch + +__all__ = ["OpenAIAgentsPatch"] diff --git a/agent_assembly/adapters/openai_agents/patch.py b/agent_assembly/adapters/openai_agents/patch.py new file mode 100644 index 0000000..65b8372 --- /dev/null +++ b/agent_assembly/adapters/openai_agents/patch.py @@ -0,0 +1,25 @@ +"""OpenAI Agents patch module.""" + +from __future__ import annotations + +from dataclasses import dataclass +import importlib.util +from typing import Any + + +@dataclass(slots=True) +class OpenAIAgentsPatch: + """Patch placeholder for OpenAI Agents SDK interception.""" + + callback_handler: Any + + def apply(self) -> bool: + _ = self.callback_handler + return _is_openai_agents_available() + + def revert(self) -> None: + return None + + +def _is_openai_agents_available() -> bool: + return importlib.util.find_spec("openai.agents") is not None diff --git a/agent_assembly/adapters/pydantic_ai/patch.py b/agent_assembly/adapters/pydantic_ai/patch.py index 79139d4..96140ae 100644 --- a/agent_assembly/adapters/pydantic_ai/patch.py +++ b/agent_assembly/adapters/pydantic_ai/patch.py @@ -33,6 +33,14 @@ def apply(self) -> bool: _apply_tool_run_patch(tool_cls, self.callback_handler) return True + def revert(self) -> None: + """Revert Pydantic AI tool patch when available.""" + tool_cls = _load_pydantic_ai_tool_class() + if tool_cls is not None: + _revert_tool_run_patch(tool_cls) + set_process_agent_id(None) + return None + class AssemblyModelWrapper: """Optional model wrapper for LLM input scan-forward interception.""" @@ -131,6 +139,21 @@ async def patched_run(self: Any, ctx: Any, args: Any, **kwargs: Any) -> Any: setattr(tool_cls, _TOOLS_PATCHED_FLAG, True) +def _revert_tool_run_patch(tool_cls: type[Any]) -> None: + if not getattr(tool_cls, _TOOLS_PATCHED_FLAG, False): + return None + + original_run = getattr(tool_cls, _ORIGINAL_TOOL_RUN, None) + if callable(original_run): + setattr(tool_cls, "_run", original_run) + + if hasattr(tool_cls, _ORIGINAL_TOOL_RUN): + delattr(tool_cls, _ORIGINAL_TOOL_RUN) + if hasattr(tool_cls, _TOOLS_PATCHED_FLAG): + delattr(tool_cls, _TOOLS_PATCHED_FLAG) + return None + + def set_process_agent_id(agent_id: str | None) -> None: global _PROCESS_AGENT_ID _PROCESS_AGENT_ID = agent_id diff --git a/agent_assembly/core/__init__.py b/agent_assembly/core/__init__.py index a7cf528..ab3a226 100644 --- a/agent_assembly/core/__init__.py +++ b/agent_assembly/core/__init__.py @@ -1,5 +1,5 @@ """Core module for Agent Assembly initialization.""" -from agent_assembly.core.assembly import init_assembly +from agent_assembly.core.assembly import AssemblyContext, init_assembly -__all__ = ["init_assembly"] +__all__ = ["init_assembly", "AssemblyContext"] diff --git a/agent_assembly/core/assembly.py b/agent_assembly/core/assembly.py index c92e5f6..2347681 100644 --- a/agent_assembly/core/assembly.py +++ b/agent_assembly/core/assembly.py @@ -2,41 +2,276 @@ from __future__ import annotations -from typing import Optional +from dataclasses import dataclass, field +import importlib.util +import sys +from threading import Lock +from typing import Any, Callable, Literal, Protocol -from agent_assembly.adapters.langchain.runtime import auto_inject_callback_handler +from agent_assembly.adapters.crewai.patch import CrewAIPatch +from agent_assembly.adapters.langchain.patch import LangChainPatch +from agent_assembly.adapters.langchain.runtime import get_active_callback_handler +from agent_assembly.adapters.langgraph import LangGraphPatch +from agent_assembly.adapters.mcp import MCPClientPatch +from agent_assembly.adapters.openai_agents import OpenAIAgentsPatch +from agent_assembly.adapters.pydantic_ai.patch import PydanticAIPatch from agent_assembly.client.gateway import GatewayClient -from agent_assembly.exceptions import ConfigurationError +from agent_assembly.exceptions import AssemblyError, ConfigurationError + +RuntimeMode = Literal["auto", "ebpf", "proxy", "sdk-only"] +NetworkMode = Literal["ebpf", "proxy", "sdk-only"] + +_DEFAULT_AGENT_ID = "agent-assembly-default" +_VALID_RUNTIME_MODES = {"auto", "ebpf", "proxy", "sdk-only"} +_INIT_LOCK = Lock() +_ACTIVE_CONTEXT: AssemblyContext | None = None + + +class RuntimePatch(Protocol): + def apply(self) -> bool: ... + + def revert(self) -> None: ... + + +@dataclass(slots=True) +class AssemblyContext: + """Represents an active assembly runtime session.""" + + client: GatewayClient + patches: list[RuntimePatch] + network_mode: NetworkMode + _network_shutdown: Callable[[], None] + _lock: Lock = field(default_factory=Lock, init=False, repr=False) + _is_shutdown: bool = field(default=False, init=False, repr=False) + + def __enter__(self) -> AssemblyContext: + return self + + def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: + del exc_type, exc, tb + self.shutdown() + return False + + @property + def is_shutdown(self) -> bool: + with self._lock: + return self._is_shutdown + + def shutdown(self) -> None: + shutdown_errors: list[str] = [] + should_clear_active_context = False + with self._lock: + if self._is_shutdown: + return None + + try: + self._network_shutdown() + except Exception as error: # pragma: no cover - defensive guard + shutdown_errors.append(f"network shutdown failed: {error}") + + for patch in reversed(self.patches): + try: + patch.revert() + except Exception as error: # pragma: no cover - defensive guard + shutdown_errors.append(f"patch revert failed: {error}") + + try: + self.client.close() + except Exception as error: # pragma: no cover - defensive guard + shutdown_errors.append(f"client close failed: {error}") + + self._is_shutdown = True + should_clear_active_context = True + + if should_clear_active_context: + _clear_active_context(self) + if shutdown_errors: + raise AssemblyError("; ".join(shutdown_errors)) def init_assembly( gateway_url: str, - agent_id: str, - api_key: Optional[str] = None, -) -> GatewayClient: - """ - Initialize the Agent Assembly SDK. - - Args: - gateway_url: URL of the governance gateway - agent_id: Unique identifier for the agent - api_key: Optional API key for authentication - - Returns: - Configured GatewayClient instance - - Raises: - ConfigurationError: If initialization fails - """ + api_key: str, + agent_id: str | None = None, + mode: RuntimeMode = "auto", + **kwargs: Any, +) -> AssemblyContext: + """Initialize the Agent Assembly SDK runtime for this process.""" + del kwargs + _validate_inputs(gateway_url=gateway_url, api_key=api_key, mode=mode) + resolved_agent_id = agent_id or _DEFAULT_AGENT_ID + + global _ACTIVE_CONTEXT + with _INIT_LOCK: + if _ACTIVE_CONTEXT is not None and not _ACTIVE_CONTEXT.is_shutdown: + _validate_active_context_compatibility( + _ACTIVE_CONTEXT, + gateway_url=gateway_url, + api_key=api_key, + agent_id=resolved_agent_id, + ) + return _ACTIVE_CONTEXT + + client = GatewayClient( + gateway_url=gateway_url, + agent_id=resolved_agent_id, + api_key=api_key, + ) + + patches: list[RuntimePatch] = [] + network_mode: NetworkMode = "sdk-only" + network_shutdown: Callable[[], None] = _noop_shutdown + try: + patches = _apply_runtime_patches( + client=client, + process_agent_id=resolved_agent_id, + ) + network_mode, network_shutdown = _start_network_layer(client=client, mode=mode) + except Exception as error: + _revert_patches(patches) + client.close() + raise ConfigurationError(f"Failed to initialize assembly runtime: {error}") from error + + context = AssemblyContext( + client=client, + patches=patches, + network_mode=network_mode, + _network_shutdown=network_shutdown, + ) + _ACTIVE_CONTEXT = context + return context + + +def _validate_inputs(*, gateway_url: str, api_key: str, mode: RuntimeMode) -> None: if not gateway_url: raise ConfigurationError("gateway_url is required") - if not agent_id: - raise ConfigurationError("agent_id is required") - - client = GatewayClient( - gateway_url=gateway_url, - agent_id=agent_id, - api_key=api_key, - ) - auto_inject_callback_handler(interceptor=object(), process_agent_id=agent_id) - return client + if not api_key: + raise ConfigurationError("api_key is required") + if mode not in _VALID_RUNTIME_MODES: + raise ConfigurationError( + "mode must be one of: auto, ebpf, proxy, sdk-only" + ) + + +def _is_installed(package: str) -> bool: + """Check if a package is importable without importing it.""" + try: + return importlib.util.find_spec(package) is not None + except (ImportError, AttributeError, ValueError): + return False + + +def _has_agents_sdk() -> bool: + """Check specifically for openai.agents module (not just openai base).""" + return _is_installed("openai.agents") + + +def _build_patch_plan(client: GatewayClient, process_agent_id: str) -> list[RuntimePatch]: + patch_plan: list[RuntimePatch] = [] + langchain_installed = _is_installed("langchain") + langgraph_installed = _is_installed("langgraph") + callback_target: Any = client + + if langchain_installed or langgraph_installed: + patch_plan.append(LangChainPatch(client, process_agent_id=process_agent_id)) + callback_handler = get_active_callback_handler() + if callback_handler is not None: + callback_target = callback_handler + + if langgraph_installed: + patch_plan.append(LangGraphPatch(callback_target)) + + if _is_installed("crewai"): + patch_plan.append(CrewAIPatch(callback_target)) + if _is_installed("pydantic_ai"): + patch_plan.append(PydanticAIPatch(callback_target)) + if _is_installed("openai") and _has_agents_sdk(): + patch_plan.append(OpenAIAgentsPatch(callback_target)) + if _is_installed("mcp"): + # Keep MCP patch last as fallback for remaining tool dispatch paths. + patch_plan.append(MCPClientPatch(callback_target)) + + return patch_plan + + +def _apply_runtime_patches(client: GatewayClient, process_agent_id: str) -> list[RuntimePatch]: + applied: list[RuntimePatch] = [] + patch_plan = _build_patch_plan(client=client, process_agent_id=process_agent_id) + for index, patch in enumerate(patch_plan): + if patch.apply(): + applied.append(patch) + callback_handler = get_active_callback_handler() + if callback_handler is not None: + _replace_callback_targets(patch_plan[index + 1 :], callback_handler) + return applied + + +def _revert_patches(patches: list[RuntimePatch]) -> None: + for patch in reversed(patches): + try: + patch.revert() + except Exception: + continue + + +def _start_network_layer(*, client: GatewayClient, mode: RuntimeMode) -> tuple[NetworkMode, Callable[[], None]]: + if mode == "sdk-only": + return "sdk-only", _noop_shutdown + + if mode == "ebpf": + if not _platform_supports_ebpf(): + raise ConfigurationError("eBPF mode is not supported on this platform.") + return "ebpf", _start_ebpf_probes(client) + + if mode == "proxy": + return "proxy", _start_mitm_proxy(client) + + if _platform_supports_ebpf(): + return "ebpf", _start_ebpf_probes(client) + return "proxy", _start_mitm_proxy(client) + + +def _platform_supports_ebpf() -> bool: + return sys.platform.startswith("linux") + + +def _start_ebpf_probes(client: GatewayClient) -> Callable[[], None]: + del client + return _noop_shutdown + + +def _start_mitm_proxy(client: GatewayClient) -> Callable[[], None]: + del client + return _noop_shutdown + + +def _noop_shutdown() -> None: + return None + + +def _clear_active_context(context: AssemblyContext) -> None: + global _ACTIVE_CONTEXT + with _INIT_LOCK: + if _ACTIVE_CONTEXT is context: + _ACTIVE_CONTEXT = None + + +def _replace_callback_targets(patches: list[RuntimePatch], callback_handler: Any) -> None: + for patch in patches: + if hasattr(patch, "callback_handler"): + setattr(patch, "callback_handler", callback_handler) + + +def _validate_active_context_compatibility( + context: AssemblyContext, + *, + gateway_url: str, + api_key: str, + agent_id: str, +) -> None: + if context.client.gateway_url != gateway_url.rstrip("/"): + raise ConfigurationError("init_assembly already initialized with a different gateway_url.") + if context.client.api_key != api_key: + raise ConfigurationError("init_assembly already initialized with a different api_key.") + if context.client.agent_id != agent_id: + raise ConfigurationError("init_assembly already initialized with a different agent_id.") diff --git a/docs/contents/document/api-references/index.mdx b/docs/contents/document/api-references/index.mdx index a996844..8325e67 100644 --- a/docs/contents/document/api-references/index.mdx +++ b/docs/contents/document/api-references/index.mdx @@ -21,9 +21,15 @@ from agent_assembly import ( ## Initialization -### `init_assembly(gateway_url: str, agent_id: str, api_key: str | None = None) -> GatewayClient` +### `init_assembly(gateway_url: str, api_key: str, agent_id: str | None = None, mode: str = "auto", **kwargs) -> AssemblyContext` -Creates a configured `GatewayClient` instance. +Initializes the runtime patch manager, network interception mode, and returns an `AssemblyContext`. + +`AssemblyContext` exposes: + +- `.client`: configured `GatewayClient` +- `.patches`: list of applied runtime patch instances +- `.shutdown()`: idempotent teardown of active patches and network layer ## Gateway client @@ -91,3 +97,12 @@ from agent_assembly.exceptions import ( - `pending` decisions await approval and raise on rejection. - Successful results are recorded with audit payload truncation at 2000 chars. - Agent identity resolves from `ctx.deps.assembly_agent_id` with process-level fallback. + +## Runtime modes + +`mode` accepts: + +- `"auto"`: prefer eBPF on Linux and fallback to proxy +- `"ebpf"`: force eBPF mode (raises on unsupported platforms) +- `"proxy"`: force proxy mode +- `"sdk-only"`: skip network layer startup diff --git a/docs/contents/document/quick-start/how-to-run.mdx b/docs/contents/document/quick-start/how-to-run.mdx index 9c8ef67..96fc464 100644 --- a/docs/contents/document/quick-start/how-to-run.mdx +++ b/docs/contents/document/quick-start/how-to-run.mdx @@ -15,19 +15,20 @@ from agent_assembly import init_assembly async def main() -> None: - client = init_assembly( + ctx = init_assembly( gateway_url="http://localhost:8080", + api_key="required-api-key", agent_id="my-agent-001", - api_key="optional-api-key", + mode="auto", ) try: - registration = await client.register_agent() - policy_result = await client.check_policy_compliance("tool.call") + registration = await ctx.client.register_agent() + policy_result = await ctx.client.check_policy_compliance("tool.call") print(registration) print(policy_result) finally: - client.close() + ctx.shutdown() asyncio.run(main()) diff --git a/examples/basic_usage.py b/examples/basic_usage.py index 87b8479..30fd063 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -10,8 +10,9 @@ # Initialize the assembly with your agent configuration assembly = init_assembly( gateway_url="https://gateway.agent-assembly.dev", + api_key="your-api-key-here", agent_id="my-agent-001", - api_key="your-api-key-here", # Optional, depending on your gateway configuration + mode="auto", ) # Create an agent configuration @@ -32,5 +33,5 @@ # - Check policy compliance before executing actions # - Log audit events -# Don't forget to close the client when done -assembly.close() +# Don't forget to shutdown the runtime when done +assembly.shutdown() diff --git a/test/integration/test_assembly_integration.py b/test/integration/test_assembly_integration.py index 5020ce5..c65fa73 100644 --- a/test/integration/test_assembly_integration.py +++ b/test/integration/test_assembly_integration.py @@ -15,13 +15,14 @@ def test_init_assembly_with_valid_config(): """Test that assembly initialization works with valid configuration.""" # This test requires a running gateway - assembly = init_assembly( + context = init_assembly( gateway_url="http://localhost:8080", + api_key="test-api-key", agent_id="test-agent-001", ) - assert assembly is not None - assert assembly.agent_id == "test-agent-001" - assembly.close() + assert context is not None + assert context.client.agent_id == "test-agent-001" + context.shutdown() @pytest.mark.integration @@ -30,26 +31,28 @@ def test_init_assembly_with_invalid_config(): with pytest.raises(ConfigurationError): init_assembly( gateway_url="", # Invalid: empty URL + api_key="test-api-key", agent_id="test-agent-001", ) with pytest.raises(ConfigurationError): init_assembly( gateway_url="http://localhost:8080", - agent_id="", # Invalid: empty agent ID + api_key="", # Invalid: empty API key ) @pytest.mark.integration def test_gateway_client_context_manager(): """Test that the gateway client works as a context manager.""" - assembly = init_assembly( + context = init_assembly( gateway_url="http://localhost:8080", + api_key="test-api-key", agent_id="test-agent-001", ) - with assembly: - assert assembly.client is not None + with context: + assert context.client.client is not None # Client should be closed after exiting context - assert assembly._client is None + assert context.client._client is None diff --git a/test/unit/adapters/crewai/test_patch.py b/test/unit/adapters/crewai/test_patch.py index 92e048e..399f85e 100644 --- a/test/unit/adapters/crewai/test_patch.py +++ b/test/unit/adapters/crewai/test_patch.py @@ -61,6 +61,23 @@ def test_apply_patches_crewai_run_and_is_idempotent(monkeypatch: pytest.MonkeyPa assert FakeTask.execute_sync is first_task_ref +def test_revert_restores_crewai_runtime_hooks(monkeypatch: pytest.MonkeyPatch) -> None: + FakeBaseTool, FakeTask = _install_fake_crewai_modules(monkeypatch) + original_run = FakeBaseTool.run + original_execute_sync = FakeTask.execute_sync + + patcher = crewai_patch.CrewAIPatch(_RecordingInterceptor()) + assert patcher.apply() is True + assert FakeBaseTool.run is not original_run + assert FakeTask.execute_sync is not original_execute_sync + + patcher.revert() + assert FakeBaseTool.run is original_run + assert FakeTask.execute_sync is original_execute_sync + assert getattr(FakeBaseTool, crewai_patch._TOOLS_PATCHED_FLAG, False) is False + assert getattr(FakeTask, crewai_patch._TASK_PATCHED_FLAG, False) is False + + def test_loader_edge_cases_and_apply_false_without_basetool( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/test/unit/adapters/langchain/test_langgraph_patch.py b/test/unit/adapters/langchain/test_langgraph_patch.py index f340fe3..f2c9fce 100644 --- a/test/unit/adapters/langchain/test_langgraph_patch.py +++ b/test/unit/adapters/langchain/test_langgraph_patch.py @@ -249,6 +249,28 @@ def compile(self) -> object: assert FakeStateGraph.compile is patched_compile +def test_langgraph_patch_revert_restores_original_compile( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeStateGraph: + def compile(self) -> str: + return "original" + + original_compile = FakeStateGraph.compile + monkeypatch.setattr( + "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module", + lambda name: SimpleNamespace(StateGraph=FakeStateGraph), + ) + + patcher = langgraph_patch.LangGraphPatch(GraphEventRecorder()) + assert patcher.apply() is True + assert FakeStateGraph.compile is not original_compile + + patcher.revert() + assert FakeStateGraph.compile is original_compile + assert getattr(FakeStateGraph, langgraph_patch._PATCHED_FLAG, False) is False + + def test_wrap_node_callable_records_metadata_and_preserves_config_passthrough() -> None: captured_events: list[tuple[str, dict[str, object]]] = [] captured_configs: list[object] = [] diff --git a/test/unit/adapters/langchain/test_runtime.py b/test/unit/adapters/langchain/test_runtime.py index 51707a9..21364e4 100644 --- a/test/unit/adapters/langchain/test_runtime.py +++ b/test/unit/adapters/langchain/test_runtime.py @@ -1,15 +1,27 @@ from __future__ import annotations +import pytest + from agent_assembly import init_assembly +from agent_assembly.core import assembly as core_assembly from agent_assembly.adapters.langchain.runtime import ( _reset_runtime_state_for_tests, auto_inject_callback_handler, get_active_callback_handler, ) +from agent_assembly.exceptions import ConfigurationError + + +def _reset_assembly_state() -> None: + active_context = core_assembly._ACTIVE_CONTEXT + if active_context is not None and not active_context.is_shutdown: + active_context.shutdown() + core_assembly._ACTIVE_CONTEXT = None def test_auto_inject_callback_handler_is_idempotent() -> None: _reset_runtime_state_for_tests() + _reset_assembly_state() first = auto_inject_callback_handler(interceptor=object()) second = auto_inject_callback_handler(interceptor=object()) @@ -18,25 +30,56 @@ def test_auto_inject_callback_handler_is_idempotent() -> None: assert get_active_callback_handler() is first -def test_init_assembly_auto_injects_callback_handler() -> None: +def test_init_assembly_auto_injects_callback_handler(monkeypatch) -> None: _reset_runtime_state_for_tests() + _reset_assembly_state() + monkeypatch.setattr( + core_assembly, + "_is_installed", + lambda package: package == "langchain", + ) - client = init_assembly(gateway_url="http://localhost:8080", agent_id="test-agent") + context = init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + agent_id="test-agent", + ) try: assert get_active_callback_handler() is not None finally: - client.close() + context.shutdown() -def test_init_assembly_reuses_existing_callback_handler() -> None: +def test_init_assembly_reuses_existing_callback_handler(monkeypatch) -> None: _reset_runtime_state_for_tests() + _reset_assembly_state() + monkeypatch.setattr( + core_assembly, + "_is_installed", + lambda package: package == "langchain", + ) - first_client = init_assembly(gateway_url="http://localhost:8080", agent_id="test-agent-a") - second_client = init_assembly(gateway_url="http://localhost:8080", agent_id="test-agent-b") + first_context = init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + agent_id="test-agent-a", + ) + second_context = init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + agent_id="test-agent-a", + ) try: first_handler = get_active_callback_handler() assert first_handler is not None assert get_active_callback_handler() is first_handler + assert first_context is second_context + with pytest.raises(ConfigurationError): + init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + agent_id="test-agent-b", + ) finally: - first_client.close() - second_client.close() + first_context.shutdown() + second_context.shutdown() diff --git a/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py b/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py index 58a4edd..d18d96e 100644 --- a/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py +++ b/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py @@ -61,6 +61,23 @@ async def test_apply_patches_tool_run_and_is_idempotent(monkeypatch: pytest.Monk assert FakeTool._run is first_run_ref +def test_revert_restores_tool_run_and_clears_process_agent_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeTool = _install_fake_pydantic_ai_modules(monkeypatch) + original_run = FakeTool._run + pydantic_ai_patch.set_process_agent_id("agent-before-revert") + + patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor()) + assert patcher.apply() is True + assert FakeTool._run is not original_run + + patcher.revert() + assert FakeTool._run is original_run + assert getattr(FakeTool, pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is False + assert pydantic_ai_patch._get_process_agent_id() is None + + def test_loader_edge_cases_and_apply_false_without_tool(monkeypatch: pytest.MonkeyPatch) -> None: def raise_import_error(module_name: str) -> object: raise ImportError(module_name) diff --git a/test/unit/adapters/test_optional_patches.py b/test/unit/adapters/test_optional_patches.py new file mode 100644 index 0000000..623a207 --- /dev/null +++ b/test/unit/adapters/test_optional_patches.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import pytest + +from agent_assembly.adapters.mcp.patch import MCPClientPatch +from agent_assembly.adapters.mcp import patch as mcp_patch +from agent_assembly.adapters.openai_agents.patch import OpenAIAgentsPatch +from agent_assembly.adapters.openai_agents import patch as openai_patch + + +def test_mcp_patch_apply_and_revert(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(mcp_patch.importlib.util, "find_spec", lambda package: object()) + patcher = MCPClientPatch(callback_handler=object()) + assert patcher.apply() is True + patcher.revert() + + +def test_mcp_patch_apply_returns_false_when_module_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(mcp_patch.importlib.util, "find_spec", lambda package: None) + patcher = MCPClientPatch(callback_handler=object()) + assert patcher.apply() is False + + +def test_openai_agents_patch_apply_and_revert(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(openai_patch.importlib.util, "find_spec", lambda package: object()) + patcher = OpenAIAgentsPatch(callback_handler=object()) + assert patcher.apply() is True + patcher.revert() + + +def test_openai_agents_patch_apply_returns_false_when_module_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(openai_patch.importlib.util, "find_spec", lambda package: None) + patcher = OpenAIAgentsPatch(callback_handler=object()) + assert patcher.apply() is False diff --git a/test/unit/test_assembly.py b/test/unit/test_assembly.py index da5667e..77afe9c 100644 --- a/test/unit/test_assembly.py +++ b/test/unit/test_assembly.py @@ -1,62 +1,486 @@ -""" -Unit tests for the Agent Assembly SDK. +from __future__ import annotations -These tests verify the basic SDK functionality without requiring external dependencies. -""" +from concurrent.futures import ThreadPoolExecutor +from threading import Event +from typing import Any import pytest from agent_assembly import init_assembly -from agent_assembly.exceptions import ConfigurationError +from agent_assembly.core import assembly as core_assembly +from agent_assembly.exceptions import AssemblyError, ConfigurationError -def test_init_assembly_with_valid_config(): - """Test that assembly initialization works with valid configuration.""" - assembly = init_assembly( +@pytest.fixture(autouse=True) +def cleanup_active_context() -> None: + active_context = core_assembly._ACTIVE_CONTEXT + if active_context is not None and not active_context.is_shutdown: + active_context.shutdown() + core_assembly._ACTIVE_CONTEXT = None + + +def test_init_assembly_with_valid_config_returns_context(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(core_assembly, "_apply_runtime_patches", lambda **kwargs: []) + monkeypatch.setattr( + core_assembly, + "_start_network_layer", + lambda **kwargs: ("sdk-only", core_assembly._noop_shutdown), + ) + + context = init_assembly( gateway_url="http://localhost:8080", + api_key="test-api-key", agent_id="test-agent-001", ) - assert assembly is not None - assert assembly.agent_id == "test-agent-001" - assert assembly.gateway_url == "http://localhost:8080" - assembly.close() + + try: + assert context.client.agent_id == "test-agent-001" + assert context.client.gateway_url == "http://localhost:8080" + assert context.client.api_key == "test-api-key" + assert context.network_mode == "sdk-only" + assert context.patches == [] + finally: + context.shutdown() -def test_init_assembly_with_invalid_config(): - """Test that assembly initialization fails with invalid configuration.""" +def test_init_assembly_with_invalid_config() -> None: with pytest.raises(ConfigurationError): init_assembly( - gateway_url="", # Invalid: empty URL + gateway_url="", + api_key="test-api-key", agent_id="test-agent-001", ) with pytest.raises(ConfigurationError): init_assembly( gateway_url="http://localhost:8080", - agent_id="", # Invalid: empty agent ID + api_key="", + agent_id="test-agent-001", ) + with pytest.raises(ConfigurationError): + init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + mode="invalid-mode", # type: ignore[arg-type] + ) + + +def test_is_installed_uses_find_spec(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + + def fake_find_spec(package: str) -> object | None: + calls.append(package) + if package == "installed_pkg": + return object() + return None + + monkeypatch.setattr(core_assembly.importlib.util, "find_spec", fake_find_spec) + + assert core_assembly._is_installed("installed_pkg") is True + assert core_assembly._is_installed("missing_pkg") is False + assert calls == ["installed_pkg", "missing_pkg"] + + +def test_is_installed_handles_find_spec_errors(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + core_assembly.importlib.util, + "find_spec", + lambda package: (_ for _ in ()).throw(ValueError(package)), + ) + assert core_assembly._is_installed("bad_pkg") is False + + +def test_has_agents_sdk_checks_openai_agents_module(monkeypatch: pytest.MonkeyPatch) -> None: + checked: list[str] = [] + monkeypatch.setattr( + core_assembly, + "_is_installed", + lambda package: checked.append(package) or True, + ) + assert core_assembly._has_agents_sdk() is True + assert checked == ["openai.agents"] + + +def test_build_patch_plan_langgraph_order_and_mcp_last( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[str] = [] + + class _FakePatch: + def __init__(self, name: str) -> None: + self.name = name + + def apply(self) -> bool: + return True + + def revert(self) -> None: + return None + + monkeypatch.setattr( + core_assembly, + "_is_installed", + lambda package: package + in {"langchain", "langgraph", "crewai", "pydantic_ai", "openai", "mcp"}, + ) + monkeypatch.setattr(core_assembly, "_has_agents_sdk", lambda: True) + monkeypatch.setattr(core_assembly, "get_active_callback_handler", lambda: object()) + + monkeypatch.setattr( + core_assembly, + "LangChainPatch", + lambda *args, **kwargs: created.append("langchain") or _FakePatch("langchain"), + ) + monkeypatch.setattr( + core_assembly, + "LangGraphPatch", + lambda *args, **kwargs: created.append("langgraph") or _FakePatch("langgraph"), + ) + monkeypatch.setattr( + core_assembly, + "CrewAIPatch", + lambda *args, **kwargs: created.append("crewai") or _FakePatch("crewai"), + ) + monkeypatch.setattr( + core_assembly, + "PydanticAIPatch", + lambda *args, **kwargs: created.append("pydantic_ai") or _FakePatch("pydantic_ai"), + ) + monkeypatch.setattr( + core_assembly, + "OpenAIAgentsPatch", + lambda *args, **kwargs: created.append("openai_agents") or _FakePatch("openai_agents"), + ) + monkeypatch.setattr( + core_assembly, + "MCPClientPatch", + lambda *args, **kwargs: created.append("mcp") or _FakePatch("mcp"), + ) + + patch_plan = core_assembly._build_patch_plan(client=object(), process_agent_id="agent-1") + + assert [patch.name for patch in patch_plan] == [ + "langchain", + "langgraph", + "crewai", + "pydantic_ai", + "openai_agents", + "mcp", + ] + assert created[-1] == "mcp" + + +def test_build_patch_plan_uses_langchain_bridge_for_langgraph_only( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[str] = [] + + class _FakePatch: + def __init__(self, name: str) -> None: + self.name = name + + def apply(self) -> bool: + return True + + def revert(self) -> None: + return None + + monkeypatch.setattr( + core_assembly, + "_is_installed", + lambda package: package in {"langgraph", "crewai", "pydantic_ai", "mcp"}, + ) + monkeypatch.setattr(core_assembly, "_has_agents_sdk", lambda: False) + monkeypatch.setattr(core_assembly, "get_active_callback_handler", lambda: None) + monkeypatch.setattr( + core_assembly, + "LangChainPatch", + lambda *args, **kwargs: created.append("langchain") or _FakePatch("langchain"), + ) + monkeypatch.setattr( + core_assembly, + "LangGraphPatch", + lambda *args, **kwargs: created.append("langgraph") or _FakePatch("langgraph"), + ) + monkeypatch.setattr( + core_assembly, + "CrewAIPatch", + lambda *args, **kwargs: created.append("crewai") or _FakePatch("crewai"), + ) + monkeypatch.setattr( + core_assembly, + "PydanticAIPatch", + lambda *args, **kwargs: created.append("pydantic_ai") or _FakePatch("pydantic_ai"), + ) + monkeypatch.setattr( + core_assembly, + "MCPClientPatch", + lambda *args, **kwargs: created.append("mcp") or _FakePatch("mcp"), + ) + + patch_plan = core_assembly._build_patch_plan(client=object(), process_agent_id="agent-1") + assert [patch.name for patch in patch_plan] == [ + "langchain", + "langgraph", + "crewai", + "pydantic_ai", + "mcp", + ] + + +def test_apply_runtime_patches_replaces_callback_targets( + monkeypatch: pytest.MonkeyPatch, +) -> None: + callback_targets: list[object] = [] + + class _FakePatch: + def __init__(self, name: str, *, callback_handler: object | None = None) -> None: + self.name = name + self.callback_handler = callback_handler + + def apply(self) -> bool: + callback_targets.append(self.callback_handler) + return True + + def revert(self) -> None: + return None + + patch_plan = [ + _FakePatch("langchain"), + _FakePatch("crewai", callback_handler="initial"), + _FakePatch("mcp", callback_handler="initial"), + ] + + monkeypatch.setattr(core_assembly, "_build_patch_plan", lambda **kwargs: patch_plan) + monkeypatch.setattr(core_assembly, "get_active_callback_handler", lambda: "runtime-callback") + + applied = core_assembly._apply_runtime_patches(client=object(), process_agent_id="agent-1") + assert applied == patch_plan + assert callback_targets == [None, "runtime-callback", "runtime-callback"] -def test_gateway_client_context_manager(): - """Test that the gateway client works as a context manager.""" - assembly = init_assembly( + +def test_mode_sdk_only_skips_network_layer() -> None: + network_mode, shutdown = core_assembly._start_network_layer(client=object(), mode="sdk-only") + assert network_mode == "sdk-only" + assert callable(shutdown) + + +def test_mode_auto_uses_proxy_when_ebpf_is_not_supported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(core_assembly, "_platform_supports_ebpf", lambda: False) + monkeypatch.setattr(core_assembly, "_start_mitm_proxy", lambda client: lambda: None) + + network_mode, shutdown = core_assembly._start_network_layer(client=object(), mode="auto") + + assert network_mode == "proxy" + assert callable(shutdown) + + +def test_mode_auto_uses_ebpf_when_supported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(core_assembly, "_platform_supports_ebpf", lambda: True) + monkeypatch.setattr(core_assembly, "_start_ebpf_probes", lambda client: lambda: None) + + network_mode, shutdown = core_assembly._start_network_layer(client=object(), mode="auto") + assert network_mode == "ebpf" + assert callable(shutdown) + + +def test_mode_proxy_forces_proxy_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(core_assembly, "_start_mitm_proxy", lambda client: lambda: None) + network_mode, shutdown = core_assembly._start_network_layer(client=object(), mode="proxy") + assert network_mode == "proxy" + assert callable(shutdown) + + +def test_mode_ebpf_raises_on_unsupported_platform( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(core_assembly, "_platform_supports_ebpf", lambda: False) + + with pytest.raises(ConfigurationError): + core_assembly._start_network_layer(client=object(), mode="ebpf") + + +def test_context_manager_shutdown_reverts_applied_patches( + monkeypatch: pytest.MonkeyPatch, +) -> None: + events: list[str] = [] + + class _Patch: + def __init__(self, name: str) -> None: + self.name = name + + def apply(self) -> bool: + events.append(f"apply:{self.name}") + return True + + def revert(self) -> None: + events.append(f"revert:{self.name}") + + monkeypatch.setattr( + core_assembly, + "_apply_runtime_patches", + lambda **kwargs: [_Patch("a"), _Patch("b")], + ) + monkeypatch.setattr( + core_assembly, + "_start_network_layer", + lambda **kwargs: ("sdk-only", core_assembly._noop_shutdown), + ) + + with init_assembly( gateway_url="http://localhost:8080", - agent_id="test-agent-001", + api_key="test-api-key", + ) as context: + assert context.is_shutdown is False + + assert events == ["revert:b", "revert:a"] + assert context.is_shutdown is True + + +def test_init_assembly_rejects_conflicting_reinit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(core_assembly, "_apply_runtime_patches", lambda **kwargs: []) + monkeypatch.setattr( + core_assembly, + "_start_network_layer", + lambda **kwargs: ("sdk-only", core_assembly._noop_shutdown), + ) + + context = init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + agent_id="agent-a", ) - - with assembly: - assert assembly.client is not None - - # Client should be closed after exiting context - assert assembly._client is None + try: + with pytest.raises(ConfigurationError, match="different agent_id"): + init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key", + agent_id="agent-b", + ) + finally: + context.shutdown() + +def test_init_assembly_rejects_conflicting_gateway_and_api_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(core_assembly, "_apply_runtime_patches", lambda **kwargs: []) + monkeypatch.setattr( + core_assembly, + "_start_network_layer", + lambda **kwargs: ("sdk-only", core_assembly._noop_shutdown), + ) -def test_gateway_client_with_api_key(): - """Test that the gateway client can be initialized with an API key.""" - assembly = init_assembly( + context = init_assembly( gateway_url="http://localhost:8080", - agent_id="test-agent-001", api_key="test-api-key", + agent_id="agent-a", ) - assert assembly.api_key == "test-api-key" - assembly.close() + try: + with pytest.raises(ConfigurationError, match="different gateway_url"): + init_assembly( + gateway_url="http://localhost:9090", + api_key="test-api-key", + agent_id="agent-a", + ) + with pytest.raises(ConfigurationError, match="different api_key"): + init_assembly( + gateway_url="http://localhost:8080", + api_key="test-api-key-2", + agent_id="agent-a", + ) + finally: + context.shutdown() + + +def test_context_shutdown_aggregates_errors() -> None: + class _FailingPatch: + def apply(self) -> bool: + return True + + def revert(self) -> None: + raise RuntimeError("patch failure") + + class _FailingClient: + gateway_url = "http://localhost:8080" + api_key = "test-api-key" + agent_id = "test-agent" + + def close(self) -> None: + raise RuntimeError("close failure") + + context = core_assembly.AssemblyContext( + client=_FailingClient(), # type: ignore[arg-type] + patches=[_FailingPatch()], + network_mode="sdk-only", + _network_shutdown=lambda: (_ for _ in ()).throw(RuntimeError("network failure")), + ) + + with pytest.raises(AssemblyError, match="network shutdown failed"): + context.shutdown() + + +def test_revert_patches_ignores_revert_failures() -> None: + class _PatchOk: + def apply(self) -> bool: + return True + + def revert(self) -> None: + return None + + class _PatchFails: + def apply(self) -> bool: + return True + + def revert(self) -> None: + raise RuntimeError("boom") + + core_assembly._revert_patches([_PatchOk(), _PatchFails(), _PatchOk()]) # no raise + + +def test_init_assembly_is_thread_safe_and_idempotent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + started = Event() + release = Event() + apply_call_count = 0 + + def fake_apply_runtime_patches(**kwargs: Any) -> list[Any]: + nonlocal apply_call_count + apply_call_count += 1 + started.set() + release.wait(timeout=2) + return [] + + monkeypatch.setattr(core_assembly, "_apply_runtime_patches", fake_apply_runtime_patches) + monkeypatch.setattr( + core_assembly, + "_start_network_layer", + lambda **kwargs: ("sdk-only", core_assembly._noop_shutdown), + ) + + def initialize() -> core_assembly.AssemblyContext: + return init_assembly(gateway_url="http://localhost:8080", api_key="test-api-key") + + with ThreadPoolExecutor(max_workers=2) as executor: + future_a = executor.submit(initialize) + started.wait(timeout=2) + future_b = executor.submit(initialize) + release.set() + + context_a = future_a.result() + context_b = future_b.result() + + try: + assert context_a is context_b + assert apply_call_count == 1 + finally: + context_a.shutdown()