Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ def _get_input_schema(agent: BaseAgent) -> Optional[type[BaseModel]]:
return None


def _ensure_load_artifacts_tool(agent: BaseAgent) -> None:
"""Recursively attaches `load_artifacts_tool` to `agent` and its sub-agents.

Only `LlmAgent` instances can hold tools, so non-LLM agents (e.g.
`ParallelAgent`, `SequentialAgent`) are skipped but their sub-agents are
still visited. This is a no-op if the tool is already present, so it is
safe to call on every `AgentTool.run_async()`.
"""
from ..agents.llm_agent import LlmAgent
from .load_artifacts_tool import load_artifacts_tool

if isinstance(agent, LlmAgent) and not any(
getattr(tool, 'name', None) == load_artifacts_tool.name
for tool in agent.tools
):
agent.tools.append(load_artifacts_tool)

for sub_agent in agent.sub_agents:
_ensure_load_artifacts_tool(sub_agent)


def _get_output_schema(agent: BaseAgent) -> Optional[SchemaType]:
"""Extracts the output_schema from an agent.

Expand Down Expand Up @@ -118,6 +139,14 @@ class AgentTool(BaseTool):
to the agent's runner. When True (default), the agent will inherit all
plugins from its parent. Set to False to run the agent with an isolated
plugin environment.
include_load_artifacts_tool: Whether to automatically attach the
`load_artifacts` tool to the wrapped agent and all of its sub-agents
(recursively). Artifacts saved by the parent agent are always
forwarded to sub-agents via `ForwardingArtifactService`, but the
bytes are only injected into an LLM agent's request when it calls
`load_artifacts`. Defaults to False to avoid sending large payloads
on every turn; set to True so sub-agents can see artifacts without
having to add the tool to each agent manually.
"""

def __init__(
Expand All @@ -127,11 +156,13 @@ def __init__(
*,
include_plugins: bool = True,
propagate_grounding_metadata: bool = False,
include_load_artifacts_tool: bool = False,
):
self.agent = agent
self.skip_summarization: bool = skip_summarization
self.include_plugins = include_plugins
self.propagate_grounding_metadata = propagate_grounding_metadata
self.include_load_artifacts_tool = include_load_artifacts_tool

super().__init__(name=agent.name, description=agent.description)

Expand Down Expand Up @@ -214,6 +245,9 @@ async def run_async(
if self.skip_summarization:
tool_context.actions.skip_summarization = True

if self.include_load_artifacts_tool:
_ensure_load_artifacts_tool(self.agent)

input_schema = _get_input_schema(self.agent)
if input_schema:
input_value = input_schema.model_validate(args)
Expand Down
93 changes: 93 additions & 0 deletions tests/unittests/tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import Agent
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.agents.run_config import RunConfig
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
Expand All @@ -35,6 +36,7 @@
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.agent_tool import AgentTool
from google.adk.tools.load_artifacts_tool import load_artifacts_tool
from google.adk.tools.tool_context import ToolContext
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
Expand Down Expand Up @@ -1148,6 +1150,97 @@ async def test_run_async_skips_thought_parts():
assert result == '42'


def test_include_load_artifacts_tool_default_false():
"""By default, load_artifacts is not added to the wrapped agent."""
mock_model = testing_utils.MockModel.create(
responses=[function_call_no_schema, 'response1', 'response2']
)
tool_agent = Agent(name='tool_agent', model=mock_model)
root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[AgentTool(agent=tool_agent)],
)

runner = testing_utils.InMemoryRunner(root_agent)
runner.run('test1')

assert all(tool.name != 'load_artifacts' for tool in tool_agent.tools)


def test_include_load_artifacts_tool_true_adds_to_wrapped_agent():
"""When enabled, load_artifacts is attached to the wrapped LlmAgent."""
mock_model = testing_utils.MockModel.create(
responses=[function_call_no_schema, 'response1', 'response2']
)
tool_agent = Agent(name='tool_agent', model=mock_model)
root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[AgentTool(agent=tool_agent, include_load_artifacts_tool=True)],
)

runner = testing_utils.InMemoryRunner(root_agent)
runner.run('test1')

assert any(tool.name == 'load_artifacts' for tool in tool_agent.tools)


def test_include_load_artifacts_tool_true_adds_to_sub_agents_recursively():
"""When enabled, load_artifacts is attached to sub-agents of a composite
wrapped agent (e.g. ParallelAgent), not just the top-level agent.
"""
sub_agent_1 = Agent(
name='sub_agent_1',
model=testing_utils.MockModel.create(responses=['sub_response_1']),
)
sub_agent_2 = Agent(
name='sub_agent_2',
model=testing_utils.MockModel.create(responses=['sub_response_2']),
)
parallel_agent = ParallelAgent(
name='parallel_tool_agent', sub_agents=[sub_agent_1, sub_agent_2]
)

function_call_for_parallel = Part.from_function_call(
name='parallel_tool_agent', args={'request': 'test1'}
)
mock_model_root = testing_utils.MockModel.create(
responses=[function_call_for_parallel, 'response2']
)
root_agent = Agent(
name='root_agent',
model=mock_model_root,
tools=[AgentTool(agent=parallel_agent, include_load_artifacts_tool=True)],
)

runner = testing_utils.InMemoryRunner(root_agent)
runner.run('test1')

assert any(tool.name == 'load_artifacts' for tool in sub_agent_1.tools)
assert any(tool.name == 'load_artifacts' for tool in sub_agent_2.tools)


def test_include_load_artifacts_tool_does_not_duplicate_existing_tool():
"""If the wrapped agent already has load_artifacts, it is not duplicated."""
mock_model = testing_utils.MockModel.create(
responses=[function_call_no_schema, 'response1', 'response2']
)
tool_agent = Agent(
name='tool_agent', model=mock_model, tools=[load_artifacts_tool]
)
root_agent = Agent(
name='root_agent',
model=mock_model,
tools=[AgentTool(agent=tool_agent, include_load_artifacts_tool=True)],
)

runner = testing_utils.InMemoryRunner(root_agent)
runner.run('test1')

assert sum(tool.name == 'load_artifacts' for tool in tool_agent.tools) == 1


class TestAgentToolWithCompositeAgents:
"""Tests for AgentTool wrapping composite agents (SequentialAgent, etc.)."""

Expand Down