diff --git a/.github/workflows/rw_run_all_test_and_record.yaml b/.github/workflows/rw_run_all_test_and_record.yaml index dfeb223..8809e34 100644 --- a/.github/workflows/rw_run_all_test_and_record.yaml +++ b/.github/workflows/rw_run_all_test_and_record.yaml @@ -164,6 +164,6 @@ jobs: - name: SonarCloud Scan uses: SonarSource/sonarqube-scan-action@v7.1.0 env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_TOKEN: ${{ github.token }} SONAR_TOKEN: ${{ secrets.sonar_token }} SONAR_HOST_URL: https://sonarcloud.io diff --git a/agent_assembly/adapters/langchain/__init__.py b/agent_assembly/adapters/langchain/__init__.py index b2a365f..fa6cf78 100644 --- a/agent_assembly/adapters/langchain/__init__.py +++ b/agent_assembly/adapters/langchain/__init__.py @@ -1,6 +1,7 @@ """LangChain adapter package.""" from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler +from agent_assembly.adapters.langgraph import LangGraphPatch from agent_assembly.adapters.langchain.langgraph_patch import patch_stategraph_compile from agent_assembly.adapters.langchain.runtime import ( auto_inject_callback_handler, @@ -9,6 +10,7 @@ __all__ = [ "AssemblyCallbackHandler", + "LangGraphPatch", "patch_stategraph_compile", "auto_inject_callback_handler", "get_active_callback_handler", diff --git a/agent_assembly/adapters/langchain/callback_handler.py b/agent_assembly/adapters/langchain/callback_handler.py index 751feec..b015224 100644 --- a/agent_assembly/adapters/langchain/callback_handler.py +++ b/agent_assembly/adapters/langchain/callback_handler.py @@ -277,16 +277,46 @@ async def aon_llm_end( await result return None - def on_graph_node_start(self, node_name: str, state: Any) -> None: + def on_graph_node_start( + self, + node_name: str, + state: Any, + *, + agent_id: str | None = None, + state_keys: list[str] | None = None, + config: Any = None, + ) -> None: method = getattr(self._interceptor, "on_graph_node_start", None) if not callable(method): return None - method(node_name=node_name, state=state) + method( + node_name=node_name, + agent_id=agent_id, + state=state, + state_keys=state_keys, + config=config, + ) return None - def on_graph_node_end(self, node_name: str, state: Any, result: Any) -> None: + def on_graph_node_end( + self, + node_name: str, + state: Any, + result: Any, + *, + agent_id: str | None = None, + state_delta: dict[str, Any] | None = None, + config: Any = None, + ) -> None: method = getattr(self._interceptor, "on_graph_node_end", None) if not callable(method): return None - method(node_name=node_name, state=state, result=result) + method( + node_name=node_name, + agent_id=agent_id, + state=state, + result=result, + state_delta=state_delta, + config=config, + ) return None diff --git a/agent_assembly/adapters/langchain/langgraph_patch.py b/agent_assembly/adapters/langchain/langgraph_patch.py index 80da5d8..161995c 100644 --- a/agent_assembly/adapters/langchain/langgraph_patch.py +++ b/agent_assembly/adapters/langchain/langgraph_patch.py @@ -1,200 +1,33 @@ -"""LangGraph compile-time patching for governance interception.""" - -from __future__ import annotations +"""Backward-compatible shim for LangGraph patch utilities.""" import importlib -import inspect -from typing import Any - -_PATCHED_FLAG = "_agent_assembly_compile_patched" -_ORIGINAL_COMPILE = "_agent_assembly_original_compile" -_NODE_WRAPPED_FLAG = "_agent_assembly_node_wrapped" -_INVOKE_WRAPPED_FLAG = "_agent_assembly_invoke_wrapped" - - -def _extract_state(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: - if args: - return args[0] - return kwargs.get("state") - - -def _invoke_pre_node_hook(callback_handler: Any, node_name: str, state: Any) -> None: - method = getattr(callback_handler, "on_graph_node_start", None) - if not callable(method): - return None - - result = method(node_name=node_name, state=state) - if inspect.isawaitable(result): - return None - - return None - - -def _invoke_post_node_hook(callback_handler: Any, node_name: str, state: Any, result: Any) -> None: - method = getattr(callback_handler, "on_graph_node_end", None) - if not callable(method): - return None - - callback_result = method(node_name=node_name, state=state, result=result) - if inspect.isawaitable(callback_result): - return None - - return None - - -def _wrap_node_callable(node_name: str, node_callable: Any, callback_handler: Any) -> Any: - if getattr(node_callable, _NODE_WRAPPED_FLAG, False): - return node_callable - - def wrapped_node(*node_args: Any, **node_kwargs: Any) -> Any: - state = _extract_state(node_args, node_kwargs) - _invoke_pre_node_hook(callback_handler, node_name=node_name, state=state) - - node_result = node_callable(*node_args, **node_kwargs) - if inspect.isawaitable(node_result): - async def awaited_node_result() -> Any: - resolved_result = await node_result - _invoke_post_node_hook( - callback_handler, - node_name=node_name, - state=state, - result=resolved_result, - ) - return resolved_result - - return awaited_node_result() - - _invoke_post_node_hook( - callback_handler, - node_name=node_name, - state=state, - result=node_result, - ) - return node_result - - setattr(wrapped_node, _NODE_WRAPPED_FLAG, True) - return wrapped_node - - -def _wrap_node_map(node_map: Any, callback_handler: Any) -> bool: - items_method = getattr(node_map, "items", None) - if not callable(items_method): - return False - - wrapped_any = False - for node_name, node_executor in list(items_method()): - if callable(node_executor): - wrapped_executor = _wrap_node_callable(str(node_name), node_executor, callback_handler) - if wrapped_executor is node_executor: - continue - try: - node_map[node_name] = wrapped_executor - except Exception: - continue - wrapped_any = True - continue - - invoke = getattr(node_executor, "invoke", None) - if callable(invoke): - setattr( - node_executor, - "invoke", - _wrap_node_callable(str(node_name), invoke, callback_handler), - ) - wrapped_any = True - - ainvoke = getattr(node_executor, "ainvoke", None) - if callable(ainvoke): - setattr( - node_executor, - "ainvoke", - _wrap_node_callable(str(node_name), ainvoke, callback_handler), - ) - wrapped_any = True - - return wrapped_any - - -def _wrap_compiled_graph_nodes(compiled_graph: Any, callback_handler: Any) -> bool: - candidate_maps = [ - getattr(compiled_graph, "nodes", None), - getattr(compiled_graph, "_nodes", None), - ] - - pregel = getattr(compiled_graph, "pregel", None) - if pregel is None: - pregel = getattr(compiled_graph, "_pregel", None) - if pregel is not None: - candidate_maps.extend( - [ - getattr(pregel, "nodes", None), - getattr(pregel, "_nodes", None), - ] - ) - - wrapped_any = False - for node_map in candidate_maps: - if node_map is None: - continue - if _wrap_node_map(node_map, callback_handler): - wrapped_any = True - - return wrapped_any - - -def patch_stategraph_compile(callback_handler: Any) -> bool: - """Patch `StateGraph.compile()` to attach runtime governance hooks.""" - try: - module = importlib.import_module("langgraph.graph.state") - except ImportError: - return False - - state_graph_cls = getattr(module, "StateGraph", None) - if state_graph_cls is None: - return False - - if getattr(state_graph_cls, _PATCHED_FLAG, False): - return True - - original_compile = state_graph_cls.compile - - def patched_compile(self: Any, *args: Any, **kwargs: Any) -> Any: - compiled_graph = original_compile(self, *args, **kwargs) - nodes_wrapped = _wrap_compiled_graph_nodes(compiled_graph, callback_handler) - if not nodes_wrapped: - invoke = getattr(compiled_graph, "invoke", None) - if callable(invoke) and not getattr(invoke, _INVOKE_WRAPPED_FLAG, False): - def wrapped_invoke(*invoke_args: Any, **invoke_kwargs: Any) -> Any: - state = _extract_state(invoke_args, invoke_kwargs) - _invoke_pre_node_hook(callback_handler, node_name="graph.invoke", state=state) - - invoke_result = invoke(*invoke_args, **invoke_kwargs) - if inspect.isawaitable(invoke_result): - async def awaited_invoke_result() -> Any: - resolved_result = await invoke_result - _invoke_post_node_hook( - callback_handler, - node_name="graph.invoke", - state=state, - result=resolved_result, - ) - return resolved_result - - return awaited_invoke_result() - - _invoke_post_node_hook( - callback_handler, - node_name="graph.invoke", - state=state, - result=invoke_result, - ) - return invoke_result - - setattr(wrapped_invoke, _INVOKE_WRAPPED_FLAG, True) - setattr(compiled_graph, "invoke", wrapped_invoke) - return compiled_graph - setattr(state_graph_cls, _ORIGINAL_COMPILE, original_compile) - setattr(state_graph_cls, "compile", patched_compile) - setattr(state_graph_cls, _PATCHED_FLAG, True) - return True +from agent_assembly.adapters.langgraph import patch as _impl + +LangGraphPatch = _impl.LangGraphPatch +patch_stategraph_compile = _impl.patch_stategraph_compile + +_PATCHED_FLAG = _impl._PATCHED_FLAG +_ORIGINAL_COMPILE = _impl._ORIGINAL_COMPILE +_NODE_WRAPPED_FLAG = _impl._NODE_WRAPPED_FLAG +_INVOKE_WRAPPED_FLAG = _impl._INVOKE_WRAPPED_FLAG + +_extract_state = _impl._extract_state +_extract_config = _impl._extract_config +_extract_agent_id = _impl._extract_agent_id +_summarize_state_keys = _impl._summarize_state_keys +_compute_state_delta = _impl._compute_state_delta +_invoke_pre_node_hook = _impl._invoke_pre_node_hook +_invoke_post_node_hook = _impl._invoke_post_node_hook +_wrap_node_callable = _impl._wrap_node_callable +_make_sync_node_wrapper = _impl._make_sync_node_wrapper +_make_async_node_wrapper = _impl._make_async_node_wrapper +_make_assembly_node_wrapper = _impl._make_assembly_node_wrapper +_wrap_node_map = _impl._wrap_node_map +_wrap_compiled_graph_nodes = _impl._wrap_compiled_graph_nodes +_discover_compiled_graph_node_maps = _impl._discover_compiled_graph_node_maps +_apply_stategraph_compile_patch = _impl._apply_stategraph_compile_patch +_wrap_graph_invoke_fallback = _impl._wrap_graph_invoke_fallback +_record_node_enter = _impl._record_node_enter +_record_node_exit = _impl._record_node_exit +_load_stategraph_class = _impl._load_stategraph_class diff --git a/agent_assembly/adapters/langchain/runtime.py b/agent_assembly/adapters/langchain/runtime.py index 81c2e1a..50bb20f 100644 --- a/agent_assembly/adapters/langchain/runtime.py +++ b/agent_assembly/adapters/langchain/runtime.py @@ -6,7 +6,7 @@ from typing import Any from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler -from agent_assembly.adapters.langchain.langgraph_patch import patch_stategraph_compile +from agent_assembly.adapters.langgraph import LangGraphPatch _ACTIVE_CALLBACK_HANDLER: AssemblyCallbackHandler | None = None _RUNTIME_LOCK = Lock() @@ -18,12 +18,12 @@ def auto_inject_callback_handler(interceptor: Any) -> AssemblyCallbackHandler: with _RUNTIME_LOCK: if _ACTIVE_CALLBACK_HANDLER is not None: - patch_stategraph_compile(_ACTIVE_CALLBACK_HANDLER) + LangGraphPatch(_ACTIVE_CALLBACK_HANDLER).apply() return _ACTIVE_CALLBACK_HANDLER handler = AssemblyCallbackHandler(interceptor) _ACTIVE_CALLBACK_HANDLER = handler - patch_stategraph_compile(handler) + LangGraphPatch(handler).apply() return handler diff --git a/agent_assembly/adapters/langgraph/__init__.py b/agent_assembly/adapters/langgraph/__init__.py new file mode 100644 index 0000000..6f7f2db --- /dev/null +++ b/agent_assembly/adapters/langgraph/__init__.py @@ -0,0 +1,5 @@ +"""LangGraph adapter package.""" + +from agent_assembly.adapters.langgraph.patch import LangGraphPatch + +__all__ = ["LangGraphPatch"] diff --git a/agent_assembly/adapters/langgraph/patch.py b/agent_assembly/adapters/langgraph/patch.py new file mode 100644 index 0000000..51aa7db --- /dev/null +++ b/agent_assembly/adapters/langgraph/patch.py @@ -0,0 +1,375 @@ +"""LangGraph patch module.""" + +from __future__ import annotations + +from dataclasses import dataclass +import importlib +import inspect +from typing import Any + +_PATCHED_FLAG = "_agent_assembly_compile_patched" +_ORIGINAL_COMPILE = "_agent_assembly_original_compile" +_NODE_WRAPPED_FLAG = "_agent_assembly_node_wrapped" +_INVOKE_WRAPPED_FLAG = "_agent_assembly_invoke_wrapped" + + +@dataclass(slots=True) +class LangGraphPatch: + """Applies LangGraph runtime monkey-patching for node-level governance hooks.""" + + callback_handler: Any + + def apply(self) -> bool: + """Apply patching once and return whether patch wiring is active.""" + state_graph_cls = _load_stategraph_class() + if state_graph_cls is None: + return False + if getattr(state_graph_cls, _PATCHED_FLAG, False): + return True + _apply_stategraph_compile_patch(state_graph_cls, self.callback_handler) + return True + + +def patch_stategraph_compile(callback_handler: Any) -> bool: + """Backward-compatible helper used by existing runtime wiring.""" + return LangGraphPatch(callback_handler=callback_handler).apply() + + +def _invoke_pre_node_hook(callback_handler: Any, node_name: str, state: object) -> None: + """Backward-compatible pre-node hook helper.""" + _record_node_enter( + callback_handler, + node_name=node_name, + state=state, + config=None, + ) + return None + + +def _invoke_post_node_hook( + callback_handler: Any, + node_name: str, + state: object, + result: object, +) -> None: + """Backward-compatible post-node hook helper.""" + _record_node_exit( + callback_handler, + node_name=node_name, + previous_state=state, + next_state=result, + config=None, + ) + return None + + +def _wrap_node_callable(node_name: str, node_func: Any, callback_handler: Any) -> Any: + """Backward-compatible node wrapper helper.""" + return _make_assembly_node_wrapper(node_name, node_func, callback_handler) + + +def _extract_state(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + if args: + return args[0] + return kwargs.get("state") + + +def _extract_config(args: tuple[Any, ...], kwargs: dict[str, Any]) -> object: + if "config" in kwargs: + return kwargs["config"] + if len(args) >= 2: + return args[1] + return None + + +def _extract_agent_id(config: object) -> str | None: + if not isinstance(config, dict): + return None + + direct_agent_id = config.get("agent_id") + if isinstance(direct_agent_id, str) and direct_agent_id: + return direct_agent_id + + configurable = config.get("configurable") + if isinstance(configurable, dict): + nested_agent_id = configurable.get("agent_id") + if isinstance(nested_agent_id, str) and nested_agent_id: + return nested_agent_id + + metadata = config.get("metadata") + if isinstance(metadata, dict): + metadata_agent_id = metadata.get("agent_id") + if isinstance(metadata_agent_id, str) and metadata_agent_id: + return metadata_agent_id + + return None + + +def _summarize_state_keys(state: object) -> list[str]: + if not isinstance(state, dict): + return [] + return [str(key) for key in state.keys()] + + +def _compute_state_delta(previous_state: object, next_state: object) -> dict[str, object]: + if not isinstance(previous_state, dict) or not isinstance(next_state, dict): + return {"changed_keys": [], "new_values": {}, "removed_keys": []} + + changed_keys: list[str] = [] + new_values: dict[str, object] = {} + for key, value in next_state.items(): + if key not in previous_state or previous_state[key] != value: + key_str = str(key) + changed_keys.append(key_str) + new_values[key_str] = value + + removed_keys = [str(key) for key in previous_state.keys() if key not in next_state] + + return { + "changed_keys": changed_keys, + "new_values": new_values, + "removed_keys": removed_keys, + } + + +def _make_sync_node_wrapper(node_name: str, original_func: Any, callback_handler: Any) -> Any: + def wrapped_node(*node_args: Any, **node_kwargs: Any) -> Any: + state = _extract_state(node_args, node_kwargs) + config = _extract_config(node_args, node_kwargs) + _record_node_enter(callback_handler, node_name=node_name, state=state, config=config) + result = original_func(*node_args, **node_kwargs) + _record_node_exit( + callback_handler, + node_name=node_name, + previous_state=state, + next_state=result, + config=config, + ) + return result + + return wrapped_node + + +def _make_async_node_wrapper(node_name: str, original_func: Any, callback_handler: Any) -> Any: + async def wrapped_node(*node_args: Any, **node_kwargs: Any) -> Any: + state = _extract_state(node_args, node_kwargs) + config = _extract_config(node_args, node_kwargs) + _record_node_enter(callback_handler, node_name=node_name, state=state, config=config) + result = await original_func(*node_args, **node_kwargs) + _record_node_exit( + callback_handler, + node_name=node_name, + previous_state=state, + next_state=result, + config=config, + ) + return result + + return wrapped_node + + +def _make_assembly_node_wrapper(node_name: str, original_func: Any, callback_handler: Any) -> Any: + if getattr(original_func, _NODE_WRAPPED_FLAG, False): + return original_func + + if inspect.iscoroutinefunction(original_func): + wrapped = _make_async_node_wrapper(node_name, original_func, callback_handler) + else: + wrapped = _make_sync_node_wrapper(node_name, original_func, callback_handler) + + setattr(wrapped, _NODE_WRAPPED_FLAG, True) + return wrapped + + +def _wrap_node_map(node_map: Any, callback_handler: Any) -> bool: + items_method = getattr(node_map, "items", None) + if not callable(items_method): + return False + + wrapped_any = False + for node_name, node_executor in list(items_method()): + if callable(node_executor): + wrapped_executor = _make_assembly_node_wrapper( + str(node_name), node_executor, callback_handler + ) + if wrapped_executor is node_executor: + continue + try: + node_map[node_name] = wrapped_executor + except Exception: + continue + wrapped_any = True + continue + + invoke = getattr(node_executor, "invoke", None) + if callable(invoke): + wrapped_invoke = _make_assembly_node_wrapper(str(node_name), invoke, callback_handler) + setattr(node_executor, "invoke", wrapped_invoke) + wrapped_any = True + + ainvoke = getattr(node_executor, "ainvoke", None) + if callable(ainvoke): + wrapped_ainvoke = _make_assembly_node_wrapper(str(node_name), ainvoke, callback_handler) + setattr(node_executor, "ainvoke", wrapped_ainvoke) + wrapped_any = True + + return wrapped_any + + +def _wrap_compiled_graph_nodes(compiled_graph: Any, callback_handler: Any) -> bool: + wrapped_any = False + for node_map in _discover_compiled_graph_node_maps(compiled_graph): + if _wrap_node_map(node_map, callback_handler): + wrapped_any = True + return wrapped_any + + +def _apply_stategraph_compile_patch(state_graph_cls: type[Any], callback_handler: Any) -> None: + original_compile = state_graph_cls.compile + + def patched_compile(self: Any, *args: Any, **kwargs: Any) -> Any: + compiled_graph = original_compile(self, *args, **kwargs) + nodes_wrapped = _wrap_compiled_graph_nodes(compiled_graph, callback_handler) + if not nodes_wrapped: + _wrap_graph_invoke_fallback(compiled_graph, callback_handler) + return compiled_graph + + setattr(state_graph_cls, _ORIGINAL_COMPILE, original_compile) + setattr(state_graph_cls, "compile", patched_compile) + setattr(state_graph_cls, _PATCHED_FLAG, True) + + +def _wrap_graph_invoke_fallback(compiled_graph: Any, callback_handler: Any) -> None: + invoke = getattr(compiled_graph, "invoke", None) + if not callable(invoke) or getattr(invoke, _INVOKE_WRAPPED_FLAG, False): + return None + + if inspect.iscoroutinefunction(invoke): + async def wrapped_async_invoke(*invoke_args: Any, **invoke_kwargs: Any) -> Any: + state = _extract_state(invoke_args, invoke_kwargs) + config = _extract_config(invoke_args, invoke_kwargs) + _record_node_enter(callback_handler, node_name="graph.invoke", state=state, config=config) + result = await invoke(*invoke_args, **invoke_kwargs) + _record_node_exit( + callback_handler, + node_name="graph.invoke", + previous_state=state, + next_state=result, + config=config, + ) + return result + wrapped_invoke: Any = wrapped_async_invoke + else: + def wrapped_sync_invoke(*invoke_args: Any, **invoke_kwargs: Any) -> Any: + state = _extract_state(invoke_args, invoke_kwargs) + config = _extract_config(invoke_args, invoke_kwargs) + _record_node_enter(callback_handler, node_name="graph.invoke", state=state, config=config) + result = invoke(*invoke_args, **invoke_kwargs) + _record_node_exit( + callback_handler, + node_name="graph.invoke", + previous_state=state, + next_state=result, + config=config, + ) + return result + wrapped_invoke = wrapped_sync_invoke + + setattr(wrapped_invoke, _INVOKE_WRAPPED_FLAG, True) + setattr(compiled_graph, "invoke", wrapped_invoke) + + +def _record_node_enter(callback_handler: Any, *, node_name: str, state: object, config: object) -> None: + method = getattr(callback_handler, "on_graph_node_start", None) + if not callable(method): + return None + + hook_kwargs = { + "node_name": node_name, + "agent_id": _extract_agent_id(config), + "state": state, + "state_keys": _summarize_state_keys(state), + "config": config, + } + try: + method(**hook_kwargs) + except TypeError as exc: + if not _is_call_signature_type_error(exc): + raise + method(node_name=node_name, state=state) + return None + + +def _record_node_exit( + callback_handler: Any, + *, + node_name: str, + previous_state: object, + next_state: object, + config: object, +) -> None: + method = getattr(callback_handler, "on_graph_node_end", None) + if not callable(method): + return None + + hook_kwargs = { + "node_name": node_name, + "agent_id": _extract_agent_id(config), + "state": previous_state, + "result": next_state, + "state_delta": _compute_state_delta(previous_state, next_state), + "config": config, + } + try: + method(**hook_kwargs) + except TypeError as exc: + if not _is_call_signature_type_error(exc): + raise + method(node_name=node_name, state=previous_state, result=next_state) + return None + + +def _is_call_signature_type_error(error: TypeError) -> bool: + message = str(error) + signature_markers = ( + "unexpected keyword argument", + "missing 1 required positional argument", + "missing required positional argument", + "takes ", + "positional arguments but", + "keyword-only argument", + ) + return any(marker in message for marker in signature_markers) + + +def _discover_compiled_graph_node_maps(compiled_graph: Any) -> list[Any]: + candidate_maps = [ + getattr(compiled_graph, "nodes", None), + getattr(compiled_graph, "_nodes", None), + ] + + pregel = getattr(compiled_graph, "pregel", None) + if pregel is None: + pregel = getattr(compiled_graph, "_pregel", None) + if pregel is not None: + candidate_maps.extend( + [ + getattr(pregel, "nodes", None), + getattr(pregel, "_nodes", None), + ] + ) + + return [node_map for node_map in candidate_maps if node_map is not None] + + +def _load_stategraph_class() -> type[Any] | None: + try: + module = importlib.import_module("langgraph.graph.state") + except ImportError: + return None + + state_graph_cls = getattr(module, "StateGraph", None) + if isinstance(state_graph_cls, type): + return state_graph_cls + + return None diff --git a/test/integration/test_langgraph_interception_integration.py b/test/integration/test_langgraph_interception_integration.py index 606af97..c90f185 100644 --- a/test/integration/test_langgraph_interception_integration.py +++ b/test/integration/test_langgraph_interception_integration.py @@ -90,3 +90,73 @@ def fake_import_module(module_name: str) -> object: "end:node_a", "start:node_b", ] + + +@pytest.mark.integration +def test_langgraph_compile_patch_allows_downstream_node_after_blocked_tool_handled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + interceptor = GraphInterceptor() + handler = AssemblyCallbackHandler(interceptor) + + class FakeCompiledGraph: + def __init__(self) -> None: + self.nodes = { + "node_a": self._node_a, + "node_b": self._node_b, + "node_c": self._node_c, + } + + def _node_a(self, state: dict[str, object]) -> dict[str, object]: + return {**state, "node_a": "ok"} + + def _node_b(self, state: dict[str, object]) -> dict[str, object]: + try: + handler.on_tool_start( + serialized={"name": "blocked_tool"}, + input_str="{}", + run_id=uuid4(), + ) + except ToolExecutionBlockedError: + return {**state, "node_b": "blocked"} + return {**state, "node_b": "unexpected-allow"} + + def _node_c(self, state: dict[str, object]) -> dict[str, object]: + return {**state, "node_c": "ok"} + + def invoke(self, state: dict[str, object]) -> dict[str, object]: + current_state = state + for node_name in ("node_a", "node_b", "node_c"): + current_state = self.nodes[node_name](current_state) + return current_state + + class FakeStateGraph: + def compile(self) -> FakeCompiledGraph: + return FakeCompiledGraph() + + fake_module = SimpleNamespace(StateGraph=FakeStateGraph) + monkeypatch.setattr( + "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module", + lambda module_name: fake_module + if module_name == "langgraph.graph.state" + else (_ for _ in ()).throw(ImportError(module_name)), + ) + + assert patch_stategraph_compile(handler) is True + compiled = FakeStateGraph().compile() + result = compiled.invoke({"step": "run"}) + + assert result == { + "step": "run", + "node_a": "ok", + "node_b": "blocked", + "node_c": "ok", + } + assert interceptor.events == [ + "start:node_a", + "end:node_a", + "start:node_b", + "end:node_b", + "start:node_c", + "end:node_c", + ] diff --git a/test/unit/adapters/langchain/test_langgraph_patch.py b/test/unit/adapters/langchain/test_langgraph_patch.py index 2c735e6..f340fe3 100644 --- a/test/unit/adapters/langchain/test_langgraph_patch.py +++ b/test/unit/adapters/langchain/test_langgraph_patch.py @@ -57,6 +57,33 @@ def test_invoke_hooks_handle_missing_methods_and_awaitables() -> None: ] +def test_invoke_hooks_only_fallback_on_signature_mismatch() -> None: + class SignatureMismatchRecorder: + def __init__(self) -> None: + self.events: list[tuple[str, object]] = [] + + def on_graph_node_start(self, *, node_name: str, state: object) -> None: + self.events.append(("start", state)) + + def on_graph_node_end(self, *, node_name: str, state: object, result: object) -> None: + self.events.append(("end", result)) + + recorder = SignatureMismatchRecorder() + langgraph_patch._invoke_pre_node_hook(recorder, "n3", {"state": 3}) + langgraph_patch._invoke_post_node_hook(recorder, "n3", {"state": 3}, {"result": 3}) + assert recorder.events == [("start", {"state": 3}), ("end", {"result": 3})] + + +def test_invoke_hooks_reraise_internal_typeerror() -> None: + class InternalTypeErrorRecorder: + def on_graph_node_start(self, **kwargs: object) -> None: + del kwargs + raise TypeError("internal callback failure") + + with pytest.raises(TypeError, match="internal callback failure"): + langgraph_patch._invoke_pre_node_hook(InternalTypeErrorRecorder(), "n4", {"state": 4}) + + @pytest.mark.asyncio async def test_wrap_node_callable_handles_already_wrapped_and_async_results() -> None: recorder = GraphEventRecorder() @@ -206,6 +233,70 @@ def compile(self) -> FallbackCompiledGraph: ] +def test_patch_stategraph_compile_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeStateGraph: + def compile(self) -> object: + return object() + + monkeypatch.setattr( + "agent_assembly.adapters.langchain.langgraph_patch.importlib.import_module", + lambda name: SimpleNamespace(StateGraph=FakeStateGraph), + ) + + assert langgraph_patch.patch_stategraph_compile(GraphEventRecorder()) is True + patched_compile = FakeStateGraph.compile + assert langgraph_patch.patch_stategraph_compile(GraphEventRecorder()) is True + assert FakeStateGraph.compile is patched_compile + + +def test_wrap_node_callable_records_metadata_and_preserves_config_passthrough() -> None: + captured_events: list[tuple[str, dict[str, object]]] = [] + captured_configs: list[object] = [] + + class Recorder: + def on_graph_node_start(self, **kwargs: object) -> None: + captured_events.append(("start", dict(kwargs))) + + def on_graph_node_end(self, **kwargs: object) -> None: + captured_events.append(("end", dict(kwargs))) + + def node(state: dict[str, object], config: object) -> dict[str, object]: + captured_configs.append(config) + return {**state, "node_done": True} + + wrapped = langgraph_patch._wrap_node_callable("node_x", node, Recorder()) + config = {"configurable": {"agent_id": "agent-007"}} + result = wrapped({"step": "run"}, config) + + assert result == {"step": "run", "node_done": True} + assert captured_configs == [config] + assert captured_events[0] == ( + "start", + { + "node_name": "node_x", + "agent_id": "agent-007", + "state": {"step": "run"}, + "state_keys": ["step"], + "config": config, + }, + ) + assert captured_events[1] == ( + "end", + { + "node_name": "node_x", + "agent_id": "agent-007", + "state": {"step": "run"}, + "result": {"step": "run", "node_done": True}, + "state_delta": { + "changed_keys": ["node_done"], + "new_values": {"node_done": True}, + "removed_keys": [], + }, + "config": config, + }, + ) + + @pytest.mark.asyncio async def test_patch_stategraph_compile_fallback_wraps_async_invoke( monkeypatch: pytest.MonkeyPatch,