From 3741b74bcfd21f0d69ad69a1b4902b9bca5fada7 Mon Sep 17 00:00:00 2001 From: Varun Nuthalapati Date: Mon, 22 Jun 2026 23:14:22 -0700 Subject: [PATCH] feat(tools): add opt-in flag to auto-attach load_artifacts to AgentTool sub-agents Artifacts saved by a parent agent are forwarded to sub-agents via ForwardingArtifactService, but the bytes are only injected into an LLM agent's request when it explicitly calls the load_artifacts tool. This is easy to miss, especially when AgentTool wraps a composite agent like ParallelAgent with several sub-agents. Adds include_load_artifacts_tool (default False) to AgentTool that recursively attaches load_artifacts_tool to the wrapped agent and all of its sub-agents. Closes #3232 --- src/google/adk/tools/agent_tool.py | 34 +++++++++ tests/unittests/tools/test_agent_tool.py | 93 ++++++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 8054d0ff8c..436679784d 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -79,6 +79,27 @@ def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]: return None +def _ensure_load_artifacts_tool(agent: BaseAgent) -> None: + """Recursively attaches `load_artifacts_tool` to `agent` and its sub-agents. + + Only `LlmAgent` instances can hold tools, so non-LLM agents (e.g. + `ParallelAgent`, `SequentialAgent`) are skipped but their sub-agents are + still visited. This is a no-op if the tool is already present, so it is + safe to call on every `AgentTool.run_async()`. + """ + from ..agents.llm_agent import LlmAgent + from .load_artifacts_tool import load_artifacts_tool + + if isinstance(agent, LlmAgent) and not any( + getattr(tool, 'name', None) == load_artifacts_tool.name + for tool in agent.tools + ): + agent.tools.append(load_artifacts_tool) + + for sub_agent in agent.sub_agents: + _ensure_load_artifacts_tool(sub_agent) + + def _get_output_schema(agent: BaseAgent) -> Optional[SchemaType]: """Extracts the output_schema from an agent. @@ -118,6 +139,14 @@ class AgentTool(BaseTool): to the agent's runner. When True (default), the agent will inherit all plugins from its parent. Set to False to run the agent with an isolated plugin environment. + include_load_artifacts_tool: Whether to automatically attach the + `load_artifacts` tool to the wrapped agent and all of its sub-agents + (recursively). Artifacts saved by the parent agent are always + forwarded to sub-agents via `ForwardingArtifactService`, but the + bytes are only injected into an LLM agent's request when it calls + `load_artifacts`. Defaults to False to avoid sending large payloads + on every turn; set to True so sub-agents can see artifacts without + having to add the tool to each agent manually. """ def __init__( @@ -127,11 +156,13 @@ def __init__( *, include_plugins: bool = True, propagate_grounding_metadata: bool = False, + include_load_artifacts_tool: bool = False, ): self.agent = agent self.skip_summarization: bool = skip_summarization self.include_plugins = include_plugins self.propagate_grounding_metadata = propagate_grounding_metadata + self.include_load_artifacts_tool = include_load_artifacts_tool super().__init__(name=agent.name, description=agent.description) @@ -214,6 +245,9 @@ async def run_async( if self.skip_summarization: tool_context.actions.skip_summarization = True + if self.include_load_artifacts_tool: + _ensure_load_artifacts_tool(self.agent) + input_schema = _get_input_schema(self.agent) if input_schema: input_value = input_schema.model_validate(args) diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index b9c7d97daf..19e988e694 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -21,6 +21,7 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import Agent +from google.adk.agents.parallel_agent import ParallelAgent from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -35,6 +36,7 @@ from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.tools.agent_tool import AgentTool +from google.adk.tools.load_artifacts_tool import load_artifacts_tool from google.adk.tools.tool_context import ToolContext from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types @@ -1148,6 +1150,97 @@ async def test_run_async_skips_thought_parts(): assert result == '42' +def test_include_load_artifacts_tool_default_false(): + """By default, load_artifacts is not added to the wrapped agent.""" + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + tool_agent = Agent(name='tool_agent', model=mock_model) + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + runner.run('test1') + + assert all(tool.name != 'load_artifacts' for tool in tool_agent.tools) + + +def test_include_load_artifacts_tool_true_adds_to_wrapped_agent(): + """When enabled, load_artifacts is attached to the wrapped LlmAgent.""" + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + tool_agent = Agent(name='tool_agent', model=mock_model) + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_load_artifacts_tool=True)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + runner.run('test1') + + assert any(tool.name == 'load_artifacts' for tool in tool_agent.tools) + + +def test_include_load_artifacts_tool_true_adds_to_sub_agents_recursively(): + """When enabled, load_artifacts is attached to sub-agents of a composite + wrapped agent (e.g. ParallelAgent), not just the top-level agent. + """ + sub_agent_1 = Agent( + name='sub_agent_1', + model=testing_utils.MockModel.create(responses=['sub_response_1']), + ) + sub_agent_2 = Agent( + name='sub_agent_2', + model=testing_utils.MockModel.create(responses=['sub_response_2']), + ) + parallel_agent = ParallelAgent( + name='parallel_tool_agent', sub_agents=[sub_agent_1, sub_agent_2] + ) + + function_call_for_parallel = Part.from_function_call( + name='parallel_tool_agent', args={'request': 'test1'} + ) + mock_model_root = testing_utils.MockModel.create( + responses=[function_call_for_parallel, 'response2'] + ) + root_agent = Agent( + name='root_agent', + model=mock_model_root, + tools=[AgentTool(agent=parallel_agent, include_load_artifacts_tool=True)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + runner.run('test1') + + assert any(tool.name == 'load_artifacts' for tool in sub_agent_1.tools) + assert any(tool.name == 'load_artifacts' for tool in sub_agent_2.tools) + + +def test_include_load_artifacts_tool_does_not_duplicate_existing_tool(): + """If the wrapped agent already has load_artifacts, it is not duplicated.""" + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + tool_agent = Agent( + name='tool_agent', model=mock_model, tools=[load_artifacts_tool] + ) + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_load_artifacts_tool=True)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + runner.run('test1') + + assert sum(tool.name == 'load_artifacts' for tool in tool_agent.tools) == 1 + + class TestAgentToolWithCompositeAgents: """Tests for AgentTool wrapping composite agents (SequentialAgent, etc.)."""