Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion agent_assembly/adapters/google_adk/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

_ORIGINAL_TOOL_RUN_ASYNC = "_agent_assembly_original_google_adk_tool_run_async"
_TOOLS_PATCHED_FLAG = "_agent_assembly_google_adk_tools_patched"
_ORIGINAL_AGENT_RUN_ASYNC = "_agent_assembly_original_google_adk_agent_run_async"
_AGENT_PATCHED_FLAG = "_agent_assembly_google_adk_agent_patched"
_PROCESS_AGENT_ID: str | None = None
_MAX_AUDIT_RESULT_CHARS = 2000

Expand All @@ -37,10 +39,16 @@ def apply(self) -> bool:
if tool_cls is None:
return False
_apply_tool_run_async_patch(tool_cls, self.callback_handler)
agent_cls = _load_google_adk_base_agent_class()
if agent_cls is not None:
_apply_agent_run_async_patch(agent_cls, self.process_agent_id)
return True

def revert(self) -> None:
"""Revert Google ADK tool patch when available."""
"""Revert Google ADK tool and agent patches when available."""
agent_cls = _load_google_adk_base_agent_class()
if agent_cls is not None:
_revert_agent_run_async_patch(agent_cls)
tool_cls = _load_google_adk_base_tool_class()
if tool_cls is not None:
_revert_tool_run_async_patch(tool_cls)
Expand All @@ -60,11 +68,61 @@ def _load_google_adk_base_tool_class() -> type[Any] | None:
return None


def _load_google_adk_base_agent_class() -> type[Any] | None:
try:
module = importlib.import_module("google.adk.agents")
except ImportError:
return None

agent_cls = getattr(module, "BaseAgent", None)
if isinstance(agent_cls, type):
return agent_cls
return None


def _current_spawn_depth() -> int:
current = _SPAWN_CTX.get()
return (current.depth + 1) if current is not None else 1


def _apply_agent_run_async_patch(agent_cls: type[Any], process_agent_id: str | None) -> None:
if getattr(agent_cls, _AGENT_PATCHED_FLAG, False):
return None

original_run_async = agent_cls.run_async

@wraps(original_run_async)
async def patched_run_async(self: Any, *args: Any, **kwargs: Any) -> Any:
spawn_ctx = SpawnContext(
parent_agent_id=process_agent_id or "",
depth=_current_spawn_depth(),
spawned_by_tool="google_adk_agent",
)
with spawn_context_scope(spawn_ctx):
async for event in original_run_async(self, *args, **kwargs):
yield event

setattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC, original_run_async)
agent_cls.run_async = patched_run_async
setattr(agent_cls, _AGENT_PATCHED_FLAG, True)
return None


def _revert_agent_run_async_patch(agent_cls: type[Any]) -> None:
if not getattr(agent_cls, _AGENT_PATCHED_FLAG, False):
return None

original_run_async = getattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC, None)
if callable(original_run_async):
agent_cls.run_async = original_run_async

if hasattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC):
delattr(agent_cls, _ORIGINAL_AGENT_RUN_ASYNC)
if hasattr(agent_cls, _AGENT_PATCHED_FLAG):
delattr(agent_cls, _AGENT_PATCHED_FLAG)
return None


def _apply_tool_run_async_patch(tool_cls: type[Any], callback_handler: Any) -> None:
if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False):
return None
Expand Down
188 changes: 188 additions & 0 deletions test/unit/adapters/google_adk/test_google_adk_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from agent_assembly.adapters.google_adk import patch as google_adk_patch
from agent_assembly.core.spawn import _SPAWN_CTX, SpawnContext
from agent_assembly.exceptions import PolicyViolationError


Expand Down Expand Up @@ -174,3 +175,190 @@ def model_dump(self) -> dict[str, int]:
assert google_adk_patch._serialize_tool_args(_Pydanticish()) == {"a": 1}
assert google_adk_patch._serialize_tool_args({"b": 2}) == {"b": 2}
assert google_adk_patch._serialize_tool_args("scalar") == {"value": "scalar"}


def _install_fake_google_adk_agent_module(
monkeypatch: pytest.MonkeyPatch,
captured: list[SpawnContext | None] | None = None,
) -> type[Any]:
"""Install a fake `google.adk.agents` module whose BaseAgent.run_async is
an async generator yielding two events, optionally capturing the current
SpawnContext snapshot at each yield point.
"""

class FakeBaseAgent:
async def run_async(self, *args: Any, **kwargs: Any) -> Any:
del args, kwargs
if captured is not None:
captured.append(_SPAWN_CTX.get())
yield {"event": "e1"}
if captured is not None:
captured.append(_SPAWN_CTX.get())
yield {"event": "e2"}

fake_module = SimpleNamespace(BaseAgent=FakeBaseAgent)

def fake_import_module(module_name: str) -> object:
if module_name == "google.adk.agents":
return fake_module
raise ImportError(module_name)

monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module)
return FakeBaseAgent


@pytest.mark.asyncio
async def test_apply_agent_patches_run_async_and_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch)

google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1")
first_ref = FakeBaseAgent.run_async
assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is True

# Re-applying is a no-op.
google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1")
assert FakeBaseAgent.run_async is first_ref


def test_revert_agent_patch_restores_run_async_and_clears_flag(
monkeypatch: pytest.MonkeyPatch,
) -> None:
FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch)
original_run_async = FakeBaseAgent.run_async

google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1")
assert FakeBaseAgent.run_async is not original_run_async

google_adk_patch._revert_agent_run_async_patch(FakeBaseAgent)
assert FakeBaseAgent.run_async is original_run_async
assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is False


def test_revert_agent_patch_is_noop_when_not_patched(
monkeypatch: pytest.MonkeyPatch,
) -> None:
FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch)
original_run_async = FakeBaseAgent.run_async

# Never applied — revert should not raise or rebind.
google_adk_patch._revert_agent_run_async_patch(FakeBaseAgent)
assert FakeBaseAgent.run_async is original_run_async


@pytest.mark.asyncio
async def test_patched_run_async_sets_spawn_context_during_iteration(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured: list[SpawnContext | None] = []
FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch, captured=captured)

google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id="parent-1")

agent = FakeBaseAgent()
events = [event async for event in agent.run_async("ctx")]

assert events == [{"event": "e1"}, {"event": "e2"}]
assert len(captured) == 2
for snapshot in captured:
assert snapshot is not None
assert snapshot.spawned_by_tool == "google_adk_agent"
assert snapshot.parent_agent_id == "parent-1"
assert snapshot.depth == 1


@pytest.mark.asyncio
async def test_patched_run_async_yields_all_events_from_original(
monkeypatch: pytest.MonkeyPatch,
) -> None:
FakeBaseAgent = _install_fake_google_adk_agent_module(monkeypatch)

google_adk_patch._apply_agent_run_async_patch(FakeBaseAgent, process_agent_id=None)

agent = FakeBaseAgent()
events = [event async for event in agent.run_async("ctx")]
assert events == [{"event": "e1"}, {"event": "e2"}]


def test_load_base_agent_returns_none_when_module_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def raise_import_error(module_name: str) -> object:
raise ImportError(module_name)

monkeypatch.setattr(google_adk_patch.importlib, "import_module", raise_import_error)
assert google_adk_patch._load_google_adk_base_agent_class() is None


def test_load_base_agent_returns_none_when_attribute_not_type(
monkeypatch: pytest.MonkeyPatch,
) -> None:
fake_module = SimpleNamespace(BaseAgent=object())

def return_non_type(module_name: str) -> object:
if module_name == "google.adk.agents":
return fake_module
raise ImportError(module_name)

monkeypatch.setattr(google_adk_patch.importlib, "import_module", return_non_type)
assert google_adk_patch._load_google_adk_base_agent_class() is None


@pytest.mark.asyncio
async def test_apply_patches_both_tool_and_agent_when_both_available(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class FakeBaseTool:
name = "fake_tool"

async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]:
del args, tool_context, kwargs
return {"ran": True}

class FakeBaseAgent:
async def run_async(self, *args: Any, **kwargs: Any) -> Any:
del args, kwargs
yield {"event": "done"}

def fake_import_module(module_name: str) -> object:
if module_name == "google.adk.tools":
return SimpleNamespace(BaseTool=FakeBaseTool)
if module_name == "google.adk.agents":
return SimpleNamespace(BaseAgent=FakeBaseAgent)
raise ImportError(module_name)

monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module)

patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor(), process_agent_id="parent-1")
assert patcher.apply() is True
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True
assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is True

patcher.revert()
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False
assert getattr(FakeBaseAgent, google_adk_patch._AGENT_PATCHED_FLAG, False) is False


def test_apply_proceeds_with_only_tool_when_agent_module_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class FakeBaseTool:
name = "fake_tool"

async def run_async(self, *, args: Any, tool_context: Any, **kwargs: Any) -> dict[str, object]:
del args, tool_context, kwargs
return {"ran": True}

def fake_import_module(module_name: str) -> object:
if module_name == "google.adk.tools":
return SimpleNamespace(BaseTool=FakeBaseTool)
raise ImportError(module_name)

monkeypatch.setattr(google_adk_patch.importlib, "import_module", fake_import_module)

patcher = google_adk_patch.GoogleADKPatch(_AllowInterceptor())
# Tool present, agent missing — apply still succeeds via the tool branch.
assert patcher.apply() is True
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is True

patcher.revert()
assert getattr(FakeBaseTool, google_adk_patch._TOOLS_PATCHED_FLAG, False) is False