diff --git a/agent_assembly/__init__.py b/agent_assembly/__init__.py index eaf5edd..071f5cb 100644 --- a/agent_assembly/__init__.py +++ b/agent_assembly/__init__.py @@ -9,6 +9,7 @@ ConfigurationError, GatewayError, PolicyError, + ToolExecutionBlockedError, ) __version__ = "0.0.0" @@ -24,4 +25,5 @@ "GatewayError", "ConfigurationError", "AdapterValidationError", + "ToolExecutionBlockedError", ] diff --git a/agent_assembly/adapters/langchain/__init__.py b/agent_assembly/adapters/langchain/__init__.py new file mode 100644 index 0000000..b2a365f --- /dev/null +++ b/agent_assembly/adapters/langchain/__init__.py @@ -0,0 +1,15 @@ +"""LangChain adapter package.""" + +from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler +from agent_assembly.adapters.langchain.langgraph_patch import patch_stategraph_compile +from agent_assembly.adapters.langchain.runtime import ( + auto_inject_callback_handler, + get_active_callback_handler, +) + +__all__ = [ + "AssemblyCallbackHandler", + "patch_stategraph_compile", + "auto_inject_callback_handler", + "get_active_callback_handler", +] diff --git a/agent_assembly/adapters/langchain/callback_handler.py b/agent_assembly/adapters/langchain/callback_handler.py new file mode 100644 index 0000000..751feec --- /dev/null +++ b/agent_assembly/adapters/langchain/callback_handler.py @@ -0,0 +1,292 @@ +"""LangChain callback handler module.""" + +from __future__ import annotations + +import importlib +import inspect +from typing import Any, Literal, Mapping, cast +from uuid import UUID + +from agent_assembly.exceptions import ToolExecutionBlockedError + + +class _FallbackBaseCallbackHandler: + """Fallback base type when langchain-core is not installed.""" + + pass + + +_CallbackHandlerBase: type[object] = _FallbackBaseCallbackHandler +try: # pragma: no cover - import availability depends on installed extras. 
+ callbacks_module = importlib.import_module("langchain_core.callbacks") + maybe_base = getattr(callbacks_module, "BaseCallbackHandler", _FallbackBaseCallbackHandler) + if isinstance(maybe_base, type): + _CallbackHandlerBase = cast(type[object], maybe_base) +except ImportError: # pragma: no cover - fallback keeps runtime import-safe. + pass + + +class AssemblyCallbackHandler(_CallbackHandlerBase): # type: ignore[valid-type,misc] + """Callback handler that delegates runtime events to governance interception.""" + + def __init__(self, interceptor: Any) -> None: + self._interceptor = interceptor + + def _normalize_decision( + self, + decision: object, + ) -> tuple[Literal["allow", "deny", "pending"], str | None]: + if isinstance(decision, str): + normalized = decision.strip().lower() + if normalized == "allow": + return "allow", None + if normalized == "deny": + return "deny", None + if normalized == "pending": + return "pending", None + return "allow", None + + if isinstance(decision, Mapping): + raw_status = str(decision.get("status", "allow")).strip().lower() + if raw_status == "allow": + status: Literal["allow", "deny", "pending"] = "allow" + elif raw_status == "deny": + status = "deny" + elif raw_status == "pending": + status = "pending" + else: + status = "allow" + + reason_value = decision.get("reason") + reason = str(reason_value) if reason_value is not None else None + return status, reason + + return "allow", None + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + **kwargs: Any, + ) -> None: + method = getattr(self._interceptor, "check_tool_start", None) + if not callable(method): + return None + + decision = method( + serialized=serialized, + input_str=input_str, + run_id=run_id, + **kwargs, + ) + status, reason = self._normalize_decision(decision) + if status == "deny": + raise ToolExecutionBlockedError(reason or "Tool execution blocked by governance.") + if status == "pending": + approval = 
def _resolve_pending_approval(
    self,
    *,
    serialized: dict[str, Any],
    input_str: str,
    run_id: UUID,
    **kwargs: Any,
) -> object:
    """Ask the interceptor to settle a ``pending`` tool decision.

    Forwards the tool-start arguments to the interceptor's
    ``wait_for_tool_approval`` and hands back whatever it returns (possibly an
    awaitable; the caller decides how to consume it). When the interceptor
    offers no such callable the verdict is the literal ``"deny"``.
    """
    waiter = getattr(self._interceptor, "wait_for_tool_approval", None)
    if callable(waiter):
        return waiter(
            serialized=serialized,
            input_str=input_str,
            run_id=run_id,
            **kwargs,
        )
    # Nobody can approve the request, so it stays blocked.
    return "deny"
def on_tool_end(
    self,
    output: Any,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Forward a finished tool run to the interceptor's ``on_tool_end``, if any."""
    hook = getattr(self._interceptor, "on_tool_end", None)
    if callable(hook):
        hook(output=output, run_id=run_id, **kwargs)
    return None


async def aon_tool_end(
    self,
    output: Any,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Async twin of ``on_tool_end``; awaits the hook result when it is awaitable."""
    hook = getattr(self._interceptor, "on_tool_end", None)
    if callable(hook):
        outcome = hook(output=output, run_id=run_id, **kwargs)
        if inspect.isawaitable(outcome):
            await outcome
    return None


def on_llm_start(
    self,
    serialized: dict[str, Any],
    prompts: list[str],
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Pass the prompts, untouched, to the interceptor's ``on_llm_start_scan``."""
    hook = getattr(self._interceptor, "on_llm_start_scan", None)
    if callable(hook):
        hook(serialized=serialized, prompts=prompts, run_id=run_id, **kwargs)
    return None


async def aon_llm_start(
    self,
    serialized: dict[str, Any],
    prompts: list[str],
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Async twin of ``on_llm_start``; awaits the hook result when it is awaitable."""
    hook = getattr(self._interceptor, "on_llm_start_scan", None)
    if callable(hook):
        outcome = hook(serialized=serialized, prompts=prompts, run_id=run_id, **kwargs)
        if inspect.isawaitable(outcome):
            await outcome
    return None


def on_llm_end(
    self,
    response: Any,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Forward a completed LLM response to the interceptor's ``on_llm_end``, if any."""
    hook = getattr(self._interceptor, "on_llm_end", None)
    if callable(hook):
        hook(response=response, run_id=run_id, **kwargs)
    return None


async def aon_llm_end(
    self,
    response: Any,
    *,
    run_id: UUID,
    **kwargs: Any,
) -> None:
    """Async twin of ``on_llm_end``; awaits the hook result when it is awaitable."""
    hook = getattr(self._interceptor, "on_llm_end", None)
    if callable(hook):
        outcome = hook(response=response, run_id=run_id, **kwargs)
        if inspect.isawaitable(outcome):
            await outcome
    return None
import functools
import importlib
import inspect
import logging
from typing import Any

_LOGGER = logging.getLogger(__name__)

_PATCHED_FLAG = "_agent_assembly_compile_patched"
_ORIGINAL_COMPILE = "_agent_assembly_original_compile"
_NODE_WRAPPED_FLAG = "_agent_assembly_node_wrapped"
_INVOKE_WRAPPED_FLAG = "_agent_assembly_invoke_wrapped"
# Class attribute on ``StateGraph`` naming the handler that graphs compiled from
# now on will notify; re-patching updates it instead of being ignored.
_ACTIVE_HANDLER = "_agent_assembly_callback_handler"


def _extract_state(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
    """Return the first positional argument, else the ``state`` keyword, else ``None``."""
    if args:
        return args[0]
    return kwargs.get("state")


def _call_hook(callback_handler: Any, hook_name: str, payload: dict[str, Any]) -> Any:
    """Call ``callback_handler.<hook_name>(**payload)`` if it exists; return its result."""
    method = getattr(callback_handler, hook_name, None)
    if not callable(method):
        return None
    return method(**payload)


def _run_hook_sync(callback_handler: Any, hook_name: str, payload: dict[str, Any]) -> None:
    """Run a hook from synchronous code.

    An awaitable result cannot be awaited here. It is reported through the
    module logger and, when it is a coroutine, closed so the interpreter does
    not additionally warn that it was never awaited.
    """
    outcome = _call_hook(callback_handler, hook_name, payload)
    if not inspect.isawaitable(outcome):
        return None
    _LOGGER.warning(
        "Governance hook %s returned an awaitable in a synchronous context; "
        "it was not awaited.",
        hook_name,
    )
    if inspect.iscoroutine(outcome):
        outcome.close()
    return None


async def _run_hook_async(callback_handler: Any, hook_name: str, payload: dict[str, Any]) -> None:
    """Run a hook from asynchronous code, awaiting the result when it is awaitable."""
    outcome = _call_hook(callback_handler, hook_name, payload)
    if inspect.isawaitable(outcome):
        await outcome
    return None


def _invoke_pre_node_hook(callback_handler: Any, node_name: str, state: Any) -> None:
    """Synchronously notify ``on_graph_node_start`` (no-op when the hook is absent)."""
    _run_hook_sync(
        callback_handler,
        "on_graph_node_start",
        {"node_name": node_name, "state": state},
    )
    return None


def _invoke_post_node_hook(callback_handler: Any, node_name: str, state: Any, result: Any) -> None:
    """Synchronously notify ``on_graph_node_end`` (no-op when the hook is absent)."""
    _run_hook_sync(
        callback_handler,
        "on_graph_node_end",
        {"node_name": node_name, "state": state, "result": result},
    )
    return None


def _wrap_with_hooks(hook_label: str, target: Any, callback_handler: Any, flag: str) -> Any:
    """Return ``target`` wrapped so the start/end hooks surround every call.

    ``flag`` marks the wrapper so a second wrapping attempt returns it as-is.
    Coroutine functions get an ``async def`` wrapper, which keeps
    ``inspect.iscoroutinefunction`` true and lets awaitable hooks be awaited.
    ``functools.wraps`` preserves the name, docstring and signature of the
    wrapped callable. The end hook is skipped when the wrapped call raises.
    """
    if getattr(target, flag, False):
        return target

    if inspect.iscoroutinefunction(target):

        @functools.wraps(target)
        async def async_wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
            state = _extract_state(call_args, call_kwargs)
            await _run_hook_async(
                callback_handler,
                "on_graph_node_start",
                {"node_name": hook_label, "state": state},
            )
            resolved = await target(*call_args, **call_kwargs)
            await _run_hook_async(
                callback_handler,
                "on_graph_node_end",
                {"node_name": hook_label, "state": state, "result": resolved},
            )
            return resolved

        wrapper: Any = async_wrapper
    else:

        @functools.wraps(target)
        def sync_wrapper(*call_args: Any, **call_kwargs: Any) -> Any:
            state = _extract_state(call_args, call_kwargs)
            _invoke_pre_node_hook(callback_handler, node_name=hook_label, state=state)

            outcome = target(*call_args, **call_kwargs)
            if inspect.isawaitable(outcome):
                # A plain callable that hands back an awaitable: defer the end
                # hook until the caller has awaited the real result.
                async def finish() -> Any:
                    resolved = await outcome
                    await _run_hook_async(
                        callback_handler,
                        "on_graph_node_end",
                        {"node_name": hook_label, "state": state, "result": resolved},
                    )
                    return resolved

                return finish()

            _invoke_post_node_hook(
                callback_handler,
                node_name=hook_label,
                state=state,
                result=outcome,
            )
            return outcome

        wrapper = sync_wrapper

    setattr(wrapper, flag, True)
    return wrapper


def _wrap_node_callable(node_name: str, node_callable: Any, callback_handler: Any) -> Any:
    """Wrap one node callable with the graph node start/end hooks (idempotent)."""
    return _wrap_with_hooks(node_name, node_callable, callback_handler, _NODE_WRAPPED_FLAG)


def _wrap_node_map(node_map: Any, callback_handler: Any) -> bool:
    """Instrument every executor of a mapping-like ``node_map`` in place.

    Callable executors are replaced inside the mapping. Other executors get
    their ``invoke`` / ``ainvoke`` attributes replaced. Targets that refuse
    the assignment are skipped instead of aborting the surrounding
    ``compile()`` call.

    Returns:
        ``True`` when at least one executor is instrumented after the call
        (including executors that were already instrumented), else ``False``.
    """
    items_method = getattr(node_map, "items", None)
    if not callable(items_method):
        return False

    wrapped_any = False
    # Snapshot the items: the mapping is mutated while we walk it.
    for node_name, node_executor in list(items_method()):
        label = str(node_name)

        if callable(node_executor):
            wrapped_executor = _wrap_node_callable(label, node_executor, callback_handler)
            if wrapped_executor is node_executor:
                wrapped_any = True  # already instrumented by an earlier pass
                continue
            try:
                node_map[node_name] = wrapped_executor
            except Exception:  # best-effort: containers raise assorted error types
                _LOGGER.debug("Could not replace graph node %r", label, exc_info=True)
                continue
            wrapped_any = True
            continue

        for attr_name in ("invoke", "ainvoke"):
            member = getattr(node_executor, attr_name, None)
            if not callable(member):
                continue
            wrapped_member = _wrap_node_callable(label, member, callback_handler)
            if wrapped_member is not member:
                try:
                    setattr(node_executor, attr_name, wrapped_member)
                except Exception:  # slotted / validating executors reject setattr
                    _LOGGER.debug(
                        "Could not replace %s on graph node %r",
                        attr_name,
                        label,
                        exc_info=True,
                    )
                    continue
            wrapped_any = True

    return wrapped_any


def _wrap_compiled_graph_nodes(compiled_graph: Any, callback_handler: Any) -> bool:
    """Instrument the node maps reachable from a compiled graph.

    Looks at ``nodes`` and ``_nodes`` on the graph itself and on its
    ``pregel`` (or, failing that, ``_pregel``) attribute. Each distinct map
    object is processed once.

    Returns:
        ``True`` when any node map reported instrumented executors.
    """
    owners = [compiled_graph]
    pregel = getattr(compiled_graph, "pregel", None)
    if pregel is None:
        pregel = getattr(compiled_graph, "_pregel", None)
    if pregel is not None:
        owners.append(pregel)

    wrapped_any = False
    visited: set[int] = set()
    for owner in owners:
        for attr_name in ("nodes", "_nodes"):
            node_map = getattr(owner, attr_name, None)
            if node_map is None or id(node_map) in visited:
                continue
            visited.add(id(node_map))
            if _wrap_node_map(node_map, callback_handler):
                wrapped_any = True

    return wrapped_any


def _wrap_graph_invoke(compiled_graph: Any, callback_handler: Any) -> None:
    """Fallback: surround ``compiled_graph.invoke`` with hooks named ``graph.invoke``."""
    invoke = getattr(compiled_graph, "invoke", None)
    if not callable(invoke):
        return None
    wrapped_invoke = _wrap_with_hooks(
        "graph.invoke", invoke, callback_handler, _INVOKE_WRAPPED_FLAG
    )
    if wrapped_invoke is invoke:
        return None
    try:
        setattr(compiled_graph, "invoke", wrapped_invoke)
    except Exception:  # never let instrumentation break compile()
        _LOGGER.warning(
            "Could not attach governance hooks to the compiled graph.",
            exc_info=True,
        )
    return None


def patch_stategraph_compile(callback_handler: Any) -> bool:
    """Patch `StateGraph.compile()` to attach runtime governance hooks.

    The patch is installed once per ``StateGraph`` class. Calling this
    function again only rebinds the handler, so graphs compiled afterwards
    notify the most recently supplied handler.

    Args:
        callback_handler: Object whose optional ``on_graph_node_start`` /
            ``on_graph_node_end`` methods receive the node events.

    Returns:
        ``False`` when ``langgraph.graph.state`` cannot be imported or has no
        ``StateGraph``; ``True`` otherwise.
    """
    try:
        module = importlib.import_module("langgraph.graph.state")
    except ImportError:
        return False

    state_graph_cls = getattr(module, "StateGraph", None)
    if state_graph_cls is None:
        return False

    # Rebind on every call so a replaced handler is not shadowed by the one
    # captured when the patch was first installed.
    setattr(state_graph_cls, _ACTIVE_HANDLER, callback_handler)
    if getattr(state_graph_cls, _PATCHED_FLAG, False):
        return True

    original_compile = state_graph_cls.compile

    @functools.wraps(original_compile)
    def patched_compile(self: Any, *args: Any, **kwargs: Any) -> Any:
        compiled_graph = original_compile(self, *args, **kwargs)
        handler = getattr(state_graph_cls, _ACTIVE_HANDLER, callback_handler)
        if not _wrap_compiled_graph_nodes(compiled_graph, handler):
            # No node map could be instrumented: hook the whole run instead.
            _wrap_graph_invoke(compiled_graph, handler)
        return compiled_graph

    setattr(state_graph_cls, _ORIGINAL_COMPILE, original_compile)
    setattr(state_graph_cls, "compile", patched_compile)
    setattr(state_graph_cls, _PATCHED_FLAG, True)
    return True
class ToolExecutionBlockedError(AssemblyError):
    """Signal that governance blocked a tool run before it could execute."""
class GraphInterceptor:
    """Test double: denies the tool named ``blocked_tool`` and logs graph node events."""

    def __init__(self) -> None:
        # Chronological "start:<node>" / "end:<node>" markers.
        self.events: list[str] = []

    def check_tool_start(self, **kwargs: object) -> dict[str, str]:
        """Deny when ``serialized`` is a dict naming ``blocked_tool``; allow otherwise."""
        serialized = kwargs.get("serialized")
        is_blocked = isinstance(serialized, dict) and serialized.get("name") == "blocked_tool"
        if not is_blocked:
            return {"status": "allow"}
        return {"status": "deny", "reason": "blocked by policy"}

    def on_graph_node_start(self, **kwargs: object) -> None:
        """Record that a node is about to run."""
        node_name = kwargs.get("node_name")
        self.events.append(f"start:{node_name}")

    def on_graph_node_end(self, **kwargs: object) -> None:
        """Record that a node finished."""
        node_name = kwargs.get("node_name")
        self.events.append(f"end:{node_name}")
input_str="{}", + run_id=uuid4(), + ) + return {**state, "node_a": "ok"} + + def _node_b(self, state: dict[str, object]) -> dict[str, object]: + handler.on_tool_start( + serialized={"name": "blocked_tool"}, + input_str="{}", + run_id=uuid4(), + ) + return {**state, "node_b": "should-not-complete"} + + def invoke(self, state: dict[str, object]) -> dict[str, object]: + current_state = state + for node_name in ("node_a", "node_b"): + current_state = self.nodes[node_name](current_state) + return current_state + + class FakeStateGraph: + def compile(self) -> FakeCompiledGraph: + return FakeCompiledGraph() + + fake_module = SimpleNamespace(StateGraph=FakeStateGraph) + + def fake_import_module(module_name: str) -> object: + if module_name == "langgraph.graph.state": + return fake_module + raise ImportError(module_name) + + monkeypatch.setattr( + "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module", + fake_import_module, + ) + + patched = patch_stategraph_compile(handler) + assert patched is True + + compiled = FakeStateGraph().compile() + with pytest.raises(ToolExecutionBlockedError): + compiled.invoke({"step": "run"}) + + assert interceptor.events == [ + "start:node_a", + "end:node_a", + "start:node_b", + ] diff --git a/test/unit/adapters/langchain/test_callback_handler_async.py b/test/unit/adapters/langchain/test_callback_handler_async.py new file mode 100644 index 0000000..89ffeef --- /dev/null +++ b/test/unit/adapters/langchain/test_callback_handler_async.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from agent_assembly.adapters.langchain import AssemblyCallbackHandler +from agent_assembly.exceptions import ToolExecutionBlockedError + + +class AsyncInterceptor: + def __init__(self) -> None: + self.tool_end_calls = 0 + self.llm_scan_calls = 0 + self.llm_end_calls = 0 + self.pending_wait_calls = 0 + + async def check_tool_start(self, **kwargs: object) -> object: + return 
kwargs.get("decision", {"status": "allow"}) + + async def wait_for_tool_approval(self, **kwargs: object) -> object: + self.pending_wait_calls += 1 + return kwargs.get("approval_decision", {"status": "allow"}) + + async def on_tool_end(self, **kwargs: object) -> None: + self.tool_end_calls += 1 + + async def on_llm_start_scan(self, **kwargs: object) -> None: + self.llm_scan_calls += 1 + + async def on_llm_end(self, **kwargs: object) -> None: + self.llm_end_calls += 1 + + +@pytest.mark.asyncio +async def test_aon_tool_start_raises_when_governance_denies() -> None: + handler = AssemblyCallbackHandler(AsyncInterceptor()) + + with pytest.raises(ToolExecutionBlockedError): + await handler.aon_tool_start( + serialized={"name": "web_search"}, + input_str="query", + run_id=uuid4(), + decision={"status": "deny", "reason": "blocked"}, + ) + + +@pytest.mark.asyncio +async def test_aon_tool_start_waits_for_pending_approval() -> None: + interceptor = AsyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + await handler.aon_tool_start( + serialized={"name": "calendar_write"}, + input_str="create event", + run_id=uuid4(), + decision={"status": "pending"}, + approval_decision={"status": "allow"}, + ) + + assert interceptor.pending_wait_calls == 1 + + +@pytest.mark.asyncio +async def test_aon_tool_end_delegates_to_interceptor() -> None: + interceptor = AsyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + await handler.aon_tool_end(output={"ok": True}, run_id=uuid4()) + + assert interceptor.tool_end_calls == 1 + + +@pytest.mark.asyncio +async def test_aon_llm_start_delegates_to_interceptor() -> None: + interceptor = AsyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + await handler.aon_llm_start( + serialized={"name": "gpt"}, + prompts=["hello", "world"], + run_id=uuid4(), + ) + + assert interceptor.llm_scan_calls == 1 + + +@pytest.mark.asyncio +async def test_aon_llm_end_delegates_to_interceptor() -> None: + interceptor = 
class SyncInterceptor:
    """Synchronous test double that counts hook calls and echoes scripted decisions."""

    def __init__(self) -> None:
        self.tool_end_calls = 0
        self.llm_scan_calls = 0
        self.llm_end_calls = 0
        self.pending_wait_calls = 0
        # Last list object received by ``on_llm_start_scan`` (identity preserved).
        self.last_prompts: list[str] | None = None

    def check_tool_start(self, **kwargs: object) -> object:
        """Return the caller-supplied ``decision`` keyword, defaulting to allow."""
        try:
            return kwargs["decision"]
        except KeyError:
            return {"status": "allow"}

    def wait_for_tool_approval(self, **kwargs: object) -> object:
        """Count the wait and return ``approval_decision``, defaulting to allow."""
        self.pending_wait_calls = self.pending_wait_calls + 1
        try:
            return kwargs["approval_decision"]
        except KeyError:
            return {"status": "allow"}

    def on_tool_end(self, **kwargs: object) -> None:
        """Count a tool completion."""
        self.tool_end_calls = self.tool_end_calls + 1

    def on_llm_start_scan(self, **kwargs: object) -> None:
        """Count a scan and remember the prompts list when one was passed."""
        self.llm_scan_calls = self.llm_scan_calls + 1
        candidate = kwargs.get("prompts")
        if not isinstance(candidate, list):
            return
        self.last_prompts = candidate

    def on_llm_end(self, **kwargs: object) -> None:
        """Count an LLM completion."""
        self.llm_end_calls = self.llm_end_calls + 1
handler.on_tool_start( + serialized={"name": "web_search"}, + input_str="query", + run_id=uuid4(), + decision={"status": "allow"}, + ) + + assert interceptor.pending_wait_calls == 0 + + +def test_on_tool_start_waits_for_pending_approval() -> None: + interceptor = SyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + handler.on_tool_start( + serialized={"name": "calendar_write"}, + input_str="create event", + run_id=uuid4(), + decision={"status": "pending"}, + approval_decision={"status": "allow"}, + ) + + assert interceptor.pending_wait_calls == 1 + + +def test_on_tool_start_blocks_when_pending_never_approved() -> None: + handler = AssemblyCallbackHandler(SyncInterceptor()) + + with pytest.raises(ToolExecutionBlockedError): + handler.on_tool_start( + serialized={"name": "calendar_write"}, + input_str="create event", + run_id=uuid4(), + decision={"status": "pending"}, + approval_decision={"status": "deny"}, + ) + + +def test_on_tool_end_delegates_to_interceptor() -> None: + interceptor = SyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + handler.on_tool_end(output={"ok": True}, run_id=uuid4()) + + assert interceptor.tool_end_calls == 1 + + +def test_on_llm_start_scans_without_mutating_prompts() -> None: + interceptor = SyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + prompts = ["hello", "world"] + + handler.on_llm_start( + serialized={"name": "gpt"}, + prompts=prompts, + run_id=uuid4(), + ) + + assert interceptor.llm_scan_calls == 1 + assert interceptor.last_prompts is prompts + assert prompts == ["hello", "world"] + + +def test_on_llm_end_delegates_to_interceptor() -> None: + interceptor = SyncInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + handler.on_llm_end(response={"text": "done"}, run_id=uuid4()) + + assert interceptor.llm_end_calls == 1 diff --git a/test/unit/adapters/langchain/test_langgraph_patch.py b/test/unit/adapters/langchain/test_langgraph_patch.py new file mode 100644 index 
class _TestAwaitable:
    """Minimal awaitable that completes immediately with ``None``."""

    def __await__(self):  # type: ignore[no-untyped-def]
        # Delegating to an empty iterable makes this a generator that yields
        # nothing and finishes with ``None``.
        yield from ()


class GraphEventRecorder:
    """Hook target that logs node events, optionally answering with an awaitable."""

    def __init__(self, await_callbacks: bool = False) -> None:
        self.await_callbacks = await_callbacks
        # ("start", node, state) and ("end", node, result) tuples in call order.
        self.events: list[tuple[str, str, object]] = []

    def on_graph_node_start(self, *, node_name: str, state: object) -> object:
        """Log a start event; return an awaitable when ``await_callbacks`` is set."""
        self.events.append(("start", node_name, state))
        return _TestAwaitable() if self.await_callbacks else None

    def on_graph_node_end(self, *, node_name: str, state: object, result: object) -> object:
        """Log an end event carrying ``result``; ``state`` is accepted but not stored."""
        self.events.append(("end", node_name, result))
        return _TestAwaitable() if self.await_callbacks else None
+    langgraph_patch._invoke_pre_node_hook(object(), "n1", {"state": 1})
+    langgraph_patch._invoke_post_node_hook(object(), "n1", {"state": 1}, {"result": 1})
+
+    recorder = GraphEventRecorder(await_callbacks=True)
+    langgraph_patch._invoke_pre_node_hook(recorder, "n2", {"state": 2})
+    langgraph_patch._invoke_post_node_hook(recorder, "n2", {"state": 2}, {"result": 2})
+
+    assert recorder.events == [
+        ("start", "n2", {"state": 2}),
+        ("end", "n2", {"result": 2}),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_wrap_node_callable_handles_already_wrapped_and_async_results() -> None:
+    """Wrap an async node and leave an already-flagged callable untouched."""
+    recorder = GraphEventRecorder()
+
+    async def async_node(state: dict[str, object]) -> dict[str, object]:
+        return {"ok": state}
+
+    # The wrapper of an async node must hand back something awaitable.
+    wrapped_async = langgraph_patch._wrap_node_callable("async_node", async_node, recorder)
+    async_result = wrapped_async({"v": 1})
+    assert async_result is not None
+    assert await async_result == {"ok": {"v": 1}}
+
+    def already_wrapped(state: dict[str, object]) -> dict[str, object]:
+        return {"state": state}
+
+    # A callable carrying the wrapped flag is returned as the very same object.
+    setattr(already_wrapped, langgraph_patch._NODE_WRAPPED_FLAG, True)
+    assert langgraph_patch._wrap_node_callable("wrapped", already_wrapped, recorder) is already_wrapped
+
+    # Only the async node was invoked, so only its events are recorded.
+    assert recorder.events == [
+        ("start", "async_node", {"v": 1}),
+        ("end", "async_node", {"ok": {"v": 1}}),
+    ]
+
+
+def test_wrap_node_map_covers_non_mapping_and_assignment_failure() -> None:
+    """Return False for a non-mapping object and when item assignment raises."""
+    assert langgraph_patch._wrap_node_map(object(), GraphEventRecorder()) is False
+
+    class FailingNodeMap:
+        """Exposes items() like a mapping but rejects every item assignment."""
+
+        def __init__(self) -> None:
+            self._items = {"node": lambda state: state}
+
+        def items(self) -> Any:
+            return self._items.items()
+
+        def __setitem__(self, key: object, value: object) -> None:
+            del key, value
+            raise RuntimeError("cannot assign")
+
+    assert langgraph_patch._wrap_node_map(FailingNodeMap(), GraphEventRecorder()) is False
+
+
+@pytest.mark.asyncio
+async def test_wrap_node_map_wraps_invoke_and_ainvoke_members() -> None:
+    """Instrument both invoke and ainvoke of a node object stored in the map."""
+    recorder = GraphEventRecorder()
+
+    class InvokeNode:
+        def invoke(self, state: dict[str, object]) -> dict[str, object]:
+            return {"invoke": state}
+
+        async def ainvoke(self, state: dict[str, object]) -> dict[str, object]:
+            return {"ainvoke": state}
+
+    node = InvokeNode()
+    node_map: dict[str, object] = {"node": node}
+    assert langgraph_patch._wrap_node_map(node_map, recorder) is True
+
+    # Calls go through the original `node` reference, so the events below show
+    # that the object itself was instrumented.
+    invoke_result = node.invoke({"v": 1})
+    ainvoke_result = await node.ainvoke({"v": 2})
+
+    assert invoke_result == {"invoke": {"v": 1}}
+    assert ainvoke_result == {"ainvoke": {"v": 2}}
+    assert recorder.events == [
+        ("start", "node", {"v": 1}),
+        ("end", "node", {"invoke": {"v": 1}}),
+        ("start", "node", {"v": 2}),
+        ("end", "node", {"ainvoke": {"v": 2}}),
+    ]
+
+
+def test_wrap_compiled_graph_nodes_supports_pregel_fallback() -> None:
+    """Wrap nodes found under `_pregel.nodes` when the other attributes are None."""
+    recorder = GraphEventRecorder()
+
+    class FakeCompiledGraph:
+        def __init__(self) -> None:
+            # Every other candidate attribute is None; only `_pregel.nodes`
+            # holds a node map.
+            self.nodes = None
+            self._nodes = None
+            self.pregel = None
+            self._pregel = SimpleNamespace(
+                nodes={"pregel_node": lambda state: {"pregel": state}},
+                _nodes=None,
+            )
+
+    compiled_graph = FakeCompiledGraph()
+    assert langgraph_patch._wrap_compiled_graph_nodes(compiled_graph, recorder) is True
+
+    result = compiled_graph._pregel.nodes["pregel_node"]({"k": 1})
+    assert result == {"pregel": {"k": 1}}
+    assert recorder.events == [
+        ("start", "pregel_node", {"k": 1}),
+        ("end", "pregel_node", {"pregel": {"k": 1}}),
+    ]
+
+
+def test_patch_stategraph_compile_handles_import_and_stategraph_edge_cases(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Cover import failure, a module without StateGraph, and an already patched class."""
+
+    def _raise_import_error(name: str) -> None:
+        raise ImportError(name)
+
+    # Import failure is reported as False rather than propagated.
+    monkeypatch.setattr(
+        "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module",
+        _raise_import_error,
+    )
+    assert langgraph_patch.patch_stategraph_compile(GraphEventRecorder()) is False
+
+    # The import succeeds but the module has no StateGraph attribute.
+    monkeypatch.setattr(
+        "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module",
+        lambda name: SimpleNamespace(),
+    )
+    assert langgraph_patch.patch_stategraph_compile(GraphEventRecorder()) is False
+
+    class AlreadyPatchedStateGraph:
+        def compile(self) -> object:
+            return object()
+
+    # A class that already carries the patched flag is reported as success.
+    setattr(AlreadyPatchedStateGraph, langgraph_patch._PATCHED_FLAG, True)
+    monkeypatch.setattr(
+        "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module",
+        lambda name: SimpleNamespace(StateGraph=AlreadyPatchedStateGraph),
+    )
+    assert langgraph_patch.patch_stategraph_compile(GraphEventRecorder()) is True
+
+
+def test_patch_stategraph_compile_fallback_wraps_sync_invoke(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Record `graph.invoke` events for a compiled graph exposing only a sync invoke."""
+    recorder = GraphEventRecorder()
+
+    class FallbackCompiledGraph:
+        # `state` is keyword-only, so the call below exercises keyword passing.
+        def invoke(self, *, state: dict[str, object]) -> dict[str, object]:
+            return {"invoke": state}
+
+    class StateGraphWithFallbackInvoke:
+        def compile(self) -> FallbackCompiledGraph:
+            return FallbackCompiledGraph()
+
+    fake_module = SimpleNamespace(StateGraph=StateGraphWithFallbackInvoke)
+    monkeypatch.setattr(
+        "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module",
+        lambda name: fake_module,
+    )
+
+    assert langgraph_patch.patch_stategraph_compile(recorder) is True
+    compiled_graph = StateGraphWithFallbackInvoke().compile()
+    result = compiled_graph.invoke(state={"n": 1})
+
+    assert result == {"invoke": {"n": 1}}
+    assert recorder.events == [
+        ("start", "graph.invoke", {"n": 1}),
+        ("end", "graph.invoke", {"invoke": {"n": 1}}),
+    ]
+
+
+@pytest.mark.asyncio
+async def test_patch_stategraph_compile_fallback_wraps_async_invoke(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Record `graph.invoke` events when the compiled graph's invoke is a coroutine."""
+    recorder = GraphEventRecorder()
+
+    class AsyncFallbackCompiledGraph:
+        async def invoke(self, state: dict[str, object]) -> dict[str, object]:
+            return {"invoke": state}
+
+    class StateGraphWithAsyncFallbackInvoke:
+        def compile(self) -> AsyncFallbackCompiledGraph:
+            return AsyncFallbackCompiledGraph()
+
+    fake_module = SimpleNamespace(StateGraph=StateGraphWithAsyncFallbackInvoke)
+    monkeypatch.setattr(
+        "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module",
+        lambda name: fake_module,
+    )
+
+    assert langgraph_patch.patch_stategraph_compile(recorder) is True
+    compiled_graph = StateGraphWithAsyncFallbackInvoke().compile()
+    result = await compiled_graph.invoke({"n": 2})
+
+    assert result == {"invoke": {"n": 2}}
+    assert recorder.events == [
+        ("start", "graph.invoke", {"n": 2}),
+        ("end", "graph.invoke", {"invoke": {"n": 2}}),
+    ]
diff --git a/test/unit/adapters/langchain/test_runtime.py b/test/unit/adapters/langchain/test_runtime.py
new file mode 100644
index 0000000..51707a9
--- /dev/null
+++ b/test/unit/adapters/langchain/test_runtime.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from agent_assembly import init_assembly
+from agent_assembly.adapters.langchain.runtime import (
+    _reset_runtime_state_for_tests,
+    auto_inject_callback_handler,
+    get_active_callback_handler,
+)
+
+
+def test_auto_inject_callback_handler_is_idempotent() -> None:
+    """Repeated auto-injection returns the handler created by the first call."""
+    _reset_runtime_state_for_tests()
+
+    first = auto_inject_callback_handler(interceptor=object())
+    second = auto_inject_callback_handler(interceptor=object())
+
+    assert first is second
+    assert get_active_callback_handler() is first
+
+
+def test_init_assembly_auto_injects_callback_handler() -> None:
+    """After init_assembly an active callback handler is available."""
+    _reset_runtime_state_for_tests()
+
+    client = init_assembly(gateway_url="http://localhost:8080", agent_id="test-agent")
+    try:
+        assert get_active_callback_handler() is not None
+    finally:
+        client.close()
+
+
+def test_init_assembly_reuses_existing_callback_handler() -> None:
+    """A second init_assembly call keeps the handler installed by the first."""
+    _reset_runtime_state_for_tests()
+
+    first_client = init_assembly(gateway_url="http://localhost:8080", agent_id="test-agent-a")
+    try:
+        # Capture the handler BEFORE the second init; reading it afterwards
+        # would compare the handler with itself and prove nothing about reuse.
+        first_handler = get_active_callback_handler()
+        assert first_handler is not None
+
+        second_client = init_assembly(
+            gateway_url="http://localhost:8080",
+            agent_id="test-agent-b",
+        )
+        try:
+            assert get_active_callback_handler() is first_handler
+        finally:
+            second_client.close()
+    finally:
+        # Runs even if the second init raises, so the first client never leaks.
+        first_client.close()