diff --git a/agent_assembly/adapters/google_adk/patch.py b/agent_assembly/adapters/google_adk/patch.py index a0a7296..01514fe 100644 --- a/agent_assembly/adapters/google_adk/patch.py +++ b/agent_assembly/adapters/google_adk/patch.py @@ -19,6 +19,8 @@ _ORIGINAL_TOOL_RUN_ASYNC = "_agent_assembly_original_google_adk_tool_run_async" _TOOLS_PATCHED_FLAG = "_agent_assembly_google_adk_tools_patched" +_ORIGINAL_AGENT_RUN_ASYNC = "_agent_assembly_original_google_adk_agent_run_async" +_AGENT_PATCHED_FLAG = "_agent_assembly_google_adk_agent_patched" _PROCESS_AGENT_ID: str | None = None _MAX_AUDIT_RESULT_CHARS = 2000 @@ -37,10 +39,16 @@ def apply(self) -> bool: if tool_cls is None: return False _apply_tool_run_async_patch(tool_cls, self.callback_handler) + agent_cls = _load_google_adk_base_agent_class() + if agent_cls is not None: + _apply_agent_run_async_patch(agent_cls, self.process_agent_id) return True def revert(self) -> None: - """Revert Google ADK tool patch when available.""" + """Revert Google ADK tool and agent patches when available.""" + agent_cls = _load_google_adk_base_agent_class() + if agent_cls is not None: + _revert_agent_run_async_patch(agent_cls) tool_cls = _load_google_adk_base_tool_class() if tool_cls is not None: _revert_tool_run_async_patch(tool_cls) @@ -60,11 +68,61 @@ def _load_google_adk_base_tool_class() -> type[Any] | None: return None +def _load_google_adk_base_agent_class() -> type[Any] | None: + try: + module = importlib.import_module("google.adk.agents") + except ImportError: + return None + + agent_cls = getattr(module, "BaseAgent", None) + if isinstance(agent_cls, type): + return agent_cls + return None + + def _current_spawn_depth() -> int: current = _SPAWN_CTX.get() return (current.depth + 1) if current is not None else 1 +def _apply_agent_run_async_patch(agent_cls: type[Any], process_agent_id: str | None) -> None: + if getattr(agent_cls, _AGENT_PATCHED_FLAG, False): + return None + + original_run_async = agent_cls.run_async + + @wraps(original_run_async) + async def patched_run_async(self: Any, *args: Any, **kwargs: Any) -> Any: + spawn_ctx = SpawnContext( + parent_agent_id=process_agent_id or "", + depth=_current_spawn_depth(), + spawned_by_tool="google_adk_agent", + ) + with spawn_context_scope(spawn_ctx): + async for event in original_run_async(self, *args, **kwargs): + yield event + + setattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC, original_run_async) + agent_cls.run_async = patched_run_async + setattr(agent_cls, _AGENT_PATCHED_FLAG, True) + return None + + +def _revert_agent_run_async_patch(agent_cls: type[Any]) -> None: + if not getattr(agent_cls, _AGENT_PATCHED_FLAG, False): + return None + + original_run_async = getattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC, None) + if callable(original_run_async): + agent_cls.run_async = original_run_async + + if hasattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC): + delattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC) + if hasattr(agent_cls, _AGENT_PATCHED_FLAG): + delattr(agent_cls, _AGENT_PATCHED_FLAG) + return None + + def _apply_tool_run_async_patch(tool_cls: type[Any], callback_handler: Any) -> None: if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False): return None diff --git a/test/unit/adapters/google_adk/test_google_adk_patch.py b/test/unit/adapters/google_adk/test_google_adk_patch.py index f20fc10..eb77d3c 100644 --- a/test/unit/adapters/google_adk/test_google_adk_patch.py +++ b/test/unit/adapters/google_adk/test_google_adk_patch.py @@ -6,6 +6,7 @@ import pytest from agent_assembly.adapters.google_adk import patch as google_adk_patch +from agent_assembly.core.spawn import _SPAWN_CTX, SpawnContext from agent_assembly.exceptions import PolicyViolationError @@ -174,3 +175,190 @@ def model_dump(self) -> dict[str, int]: assert google_adk_patch._serialize_tool_args(_Pydanticish()) == {"a": 1} assert google_adk_patch._serialize_tool_args({"b": 2}) == {"b": 2} assert google_adk_patch._serialize_tool_args("scalar") == {"value": "scalar"} + + +def _install_fake_google_adk_agent_module( + monkeypatch: pytest.MonkeyPatch, + captured: list[SpawnContext | None] | None = None, +) -> type[Any]: + """Install a fake `google.adk.agents` module whose BaseAgent.run_async is + an async generator yielding two events, optionally capturing the current + SpawnContext snapshot at each yield point. + """ + + class FakeBaseAgent: + async def run_async(self, *args: Any, **kwargs: Any) -> Any: + del args, kwargs + if captured is not None: + captured.append(_SPAWN_CTX.get()) + yield {"event": "e1"} + if captured is not None: + captured.append(_SPAWN_CTX.get()) + yield {"event": "e2"} + + fake_module = SimpleNamespace(BaseAgent=FakeBaseAgent) + + def fake_import_module(module_name: str) -> object: + if module_name == "google.adk.agents": + return fake_module + raise ImportError(module_name) + + monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module) + return FakeBaseAgent + + +@pytest.mark.asyncio +async def test_apply_agent_patches_run_async_and_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None: + FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch) + + google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1") + first_ref = FakeBaseAgent.run_async + assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is True + + # Re-applying is a no-op. + google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1") + assert FakeBaseAgent.run_async is first_ref + + +def test_revert_agent_patch_restores_run_async_and_clears_flag( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch) + original_run_async = FakeBaseAgent.run_async + + google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1") + assert FakeBaseAgent.run_async is not original_run_async + + google_adk_patch._revert_agent_run_async_patch(FakeBaseAgent) + assert FakeBaseAgent.run_async is original_run_async + assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is False + + +def test_revert_agent_patch_is_noop_when_not_patched( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch) + original_run_async = FakeBaseAgent.run_async + + # Never applied — revert should not raise or rebind. + google_adk_patch._revert_agent_run_async_patch(FakeBaseAgent) + assert FakeBaseAgent.run_async is original_run_async + + +@pytest.mark.asyncio +async def test_patched_run_async_sets_spawn_context_during_iteration( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: list[SpawnContext | None] = [] + FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch, captured=captured) + + google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1") + + agent = FakeBaseAgent() + events = [event async for event in agent.run_async("ctx")] + + assert events == [{"event": "e1"}, {"event": "e2"}] + assert len(captured) == 2 + for snapshot in captured: + assert snapshot is not None + assert snapshot.spawned_by_tool == "google_adk_agent" + assert snapshot.parent_agent_id == "parent-1" + assert snapshot.depth == 1 + + +@pytest.mark.asyncio +async def test_patched_run_async_yields_all_events_from_original( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch) + + google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id=None) + + agent = FakeBaseAgent() + events = [event async for event in agent.run_async("ctx")] + assert events == [{"event": "e1"}, {"event": "e2"}] + + +def test_load_base_agent_returns_none_when_module_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def raise_import_error(module_name: str) -> object: + raise ImportError(module_name) + + monkeypatch.setattr(google_adk_patch.importlib, "import_module", raise_import_error) + assert google_adk_patch._load_google_adk_base_agent_class() is None + + +def test_load_base_agent_returns_none_when_attribute_not_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_module = SimpleNamespace(BaseAgent=object()) + + def return_non_type(module_name: str) -> object: + if module_name == "google.adk.agents": + return fake_module + raise ImportError(module_name) + + monkeypatch.setattr(google_adk_patch.importlib, "import_module", return_non_type) + assert google_adk_patch._load_google_adk_base_agent_class() is None + + +@pytest.mark.asyncio +async def test_apply_patches_both_tool_and_agent_when_both_available( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeBaseTool: + name = "fake_tool" + + async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]: + del args, tool_context, kwargs + return {"ran": True} + + class FakeBaseAgent: + async def run_async(self, *args: Any, **kwargs: Any) -> Any: + del args, kwargs + yield {"event": "done"} + + def fake_import_module(module_name: str) -> object: + if module_name == "google.adk.tools": + return SimpleNamespace(BaseTool=FakeBaseTool) + if module_name == "google.adk.agents": + return SimpleNamespace(BaseAgent=FakeBaseAgent) + raise ImportError(module_name) + + monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module) + + patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor(), process_agent_id="parent-1") + assert patcher.apply() is True + assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True + assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is True + + patcher.revert() + assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False + assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is False + + +def test_apply_proceeds_with_only_tool_when_agent_module_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeBaseTool: + name = "fake_tool" + + async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]: + del args, tool_context, kwargs + return {"ran": True} + + def fake_import_module(module_name: str) -> object: + if module_name == "google.adk.tools": + return SimpleNamespace(BaseTool=FakeBaseTool) + raise ImportError(module_name) + + monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module) + + patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor()) + # Tool present, agent missing — apply still succeeds via the tool branch. + assert patcher.apply() is True + assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True + + patcher.revert() + assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False