diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 7db7bc89ee..2eef9af9dc 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -32,6 +32,7 @@ from a2a.client.client_factory import ClientFactory as A2AClientFactory from a2a.client.errors import A2AClientHTTPError from a2a.types import AgentCard +from a2a.types import Artifact as A2AArtifact from a2a.types import Message as A2AMessage from a2a.types import Part as A2APart from a2a.types import Role @@ -485,6 +486,97 @@ def _construct_message_parts_from_session( return message_parts, context_id + async def _save_a2a_artifacts_to_session( + self, + artifacts: Optional[list[A2AArtifact]], + event: Event, + ctx: InvocationContext, + part_converter: Optional[A2APartToGenAIPartConverter] = None, + ) -> None: + """Persists A2A artifacts into the orchestrator session's artifact service. + + When a remote A2A agent returns artifacts (e.g. files) in its response, + they are saved into the (parent/orchestrator) session's artifact service so + that downstream agents can load them via ``context.load_artifact(...)``. The + saved versions are recorded on the event's ``artifact_delta`` so the rest of + the ADK runtime is aware of them. + + This is best-effort: if no artifact service is configured, or an individual + artifact cannot be converted/saved, it is skipped without failing the + overall A2A response handling. + + Args: + artifacts: The A2A artifacts to persist. May be None or empty. + event: The ADK event produced for this A2A response. Its + ``actions.artifact_delta`` is updated with the saved filenames/versions. + ctx: The invocation context, providing the session artifact service. + part_converter: Optional A2A-to-GenAI part converter. Defaults to the + agent's configured ``a2a_part_converter``. + """ + if not artifacts or ctx.artifact_service is None: + return + + part_converter = part_converter or self._a2a_part_converter + + for artifact in artifacts: + if not artifact or not artifact.parts: + continue + + # Prefer the human-readable artifact name (this is the original filename + # when the remote is an ADK A2A server), falling back to the artifact id. + filename = artifact.name or artifact.artifact_id + if not filename: + logger.warning( + "Skipping A2A artifact without a name or id for agent %s", + self.name, + ) + continue + + for a2a_part in artifact.parts: + converted = part_converter(a2a_part) + if not isinstance(converted, list): + converted = [converted] if converted else [] + + # Only blob-like parts (files / inline data) are saved as artifacts. + genai_part = next( + ( + part + for part in converted + if part is not None + and (part.inline_data is not None or part.file_data is not None) + ), + None, + ) + if genai_part is None: + continue + + try: + version = await ctx.artifact_service.save_artifact( + app_name=ctx.app_name, + user_id=ctx.user_id, + session_id=ctx.session.id, + filename=filename, + artifact=genai_part, + ) + except Exception as e: # pylint: disable=broad-except + logger.warning( + "Failed to save A2A artifact %s for agent %s: %s", + filename, + self.name, + e, + ) + break + + event.actions.artifact_delta[filename] = version + logger.debug( + "Saved A2A artifact %s (version %s) to session for agent %s", + filename, + version, + self.name, + ) + # One artifact maps to a single saved file; ignore any extra parts. + break + async def _handle_a2a_response( self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext ) -> Optional[Event]: @@ -570,6 +662,17 @@ async def _handle_a2a_response( task.context_id ) + # Persist any artifacts returned by the remote agent into the + # orchestrator session so downstream agents can load them. Full task + # responses carry them on ``task.artifacts``; streaming artifact updates + # carry a single artifact on ``update.artifact``. + if isinstance(update, A2ATaskArtifactUpdateEvent): + update_artifact = getattr(update, "artifact", None) + artifacts = [update_artifact] if update_artifact else None + else: + artifacts = getattr(task, "artifacts", None) if task else None + await self._save_a2a_artifacts_to_session(artifacts, event, ctx) + # Otherwise, it's a regular A2AMessage for non-streaming responses. elif isinstance(a2a_response, A2AMessage): event = convert_a2a_message_to_event( @@ -642,6 +745,17 @@ async def _handle_a2a_response_v2( task.context_id ) + # Persist any artifacts returned by the remote agent into the + # orchestrator session so downstream agents can load them. + if isinstance(update, A2ATaskArtifactUpdateEvent): + update_artifact = getattr(update, "artifact", None) + artifacts = [update_artifact] if update_artifact else None + else: + artifacts = getattr(task, "artifacts", None) if task else None + await self._save_a2a_artifacts_to_session( + artifacts, event, ctx, self._config.a2a_part_converter + ) + # Otherwise, it's a regular A2AMessage. elif isinstance(a2a_response, A2AMessage): event = self._config.a2a_message_converter( diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 0a5c2ef75f..4c98a4d5e7 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import copy import json from pathlib import Path @@ -28,7 +29,11 @@ from a2a.types import AgentCard from a2a.types import AgentSkill from a2a.types import Artifact +from a2a.types import FilePart +from a2a.types import FileWithBytes +from a2a.types import FileWithUri from a2a.types import Message as A2AMessage +from a2a.types import Part as A2APart from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState @@ -46,6 +51,7 @@ from google.adk.agents.remote_a2a_agent import AgentCardResolutionError from google.adk.agents.remote_a2a_agent import RemoteA2aAgent import google.adk.agents.remote_a2a_agent as remote_a2a_agent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.events.event import Event from google.adk.sessions.session import Session from google.genai import types as genai_types @@ -552,6 +558,7 @@ def setup_method(self): self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" self.mock_context.branch = "main" + self.mock_context.artifact_service = None def test_create_a2a_request_for_user_function_response_no_function_call(self): """Test function response request creation when no function call exists.""" @@ -1377,6 +1384,7 @@ def setup_method(self): self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" self.mock_context.branch = "main" + self.mock_context.artifact_service = None def test_create_a2a_request_for_user_function_response_no_function_call(self): """Test function response request creation when no function call exists.""" @@ -1790,6 +1798,7 @@ def setup_method(self): self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" self.mock_context.branch = "main" + self.mock_context.artifact_service = None @pytest.mark.asyncio async def test_handle_a2a_response_impl_with_message(self): @@ -2036,6 +2045,7 @@ def setup_method(self): self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" self.mock_context.branch = "main" + self.mock_context.artifact_service = None @pytest.mark.asyncio async def test_run_async_impl_initialization_failure(self): @@ -2312,6 +2322,7 @@ def setup_method(self): self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" self.mock_context.branch = "main" + self.mock_context.artifact_service = None @pytest.mark.asyncio async def test_run_async_impl_initialization_failure(self): @@ -2989,3 +3000,228 @@ def test_deepcopy_config(self): copied_config.request_interceptors[0] is not config.request_interceptors[0] ) + + +def _make_file_artifact( + filename: str, + data: bytes = b"hello world", + mime_type: str = "text/plain", + artifact_id: str | None = None, +) -> Artifact: + """Builds an A2A Artifact carrying a single inline file part.""" + return Artifact( + artifact_id=artifact_id or f"{filename}_0", + name=filename, + parts=[ + A2APart( + root=FilePart( + file=FileWithBytes( + bytes=base64.b64encode(data).decode("utf-8"), + mime_type=mime_type, + name=filename, + ) + ) + ) + ], + ) + + +class TestRemoteA2aAgentArtifactPersistence: + """Tests that artifacts returned by a remote A2A agent are saved. + + These exercise the orchestrator-side persistence: when a RemoteA2aAgent + receives artifacts in an A2A response, they should be written into the parent + session's artifact service so downstream agents can load them. + """ + + def setup_method(self): + self.agent = RemoteA2aAgent( + name="remote_agent", + agent_card=create_test_agent_card(), + ) + self.artifact_service = InMemoryArtifactService() + + self.session = Mock(spec=Session) + self.session.id = "session-123" + self.session.events = [] + + self.ctx = Mock(spec=InvocationContext) + self.ctx.session = self.session + self.ctx.invocation_id = "invocation-123" + self.ctx.branch = "main" + self.ctx.app_name = "test_app" + self.ctx.user_id = "test_user" + self.ctx.artifact_service = self.artifact_service + + async def _load(self, filename: str, version: int = 0): + return await self.artifact_service.load_artifact( + app_name=self.ctx.app_name, + user_id=self.ctx.user_id, + session_id=self.session.id, + filename=filename, + version=version, + ) + + @pytest.mark.asyncio + async def test_non_streaming_task_artifact_saved_to_session(self): + """A full task response with artifacts saves them to the artifact service.""" + artifact = _make_file_artifact("report.txt", data=b"final report") + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.completed), + artifacts=[artifact], + ) + + event = await self.agent._handle_a2a_response((task, None), self.ctx) + + # The artifact is persisted to the orchestrator session. + saved = await self._load("report.txt") + assert saved is not None + assert saved.inline_data is not None + assert saved.inline_data.data == b"final report" + # And recorded as an artifact delta on the emitted event. + assert event.actions.artifact_delta == {"report.txt": 0} + + @pytest.mark.asyncio + async def test_streaming_artifact_update_saved_to_session(self): + """A streaming TaskArtifactUpdateEvent saves its artifact to the session.""" + artifact = _make_file_artifact( + "chart.png", data=b"\x89PNG", mime_type="image/png" + ) + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.working), + ) + update = TaskArtifactUpdateEvent( + task_id="task-123", + context_id="context-123", + artifact=artifact, + append=False, + last_chunk=True, + ) + + event = await self.agent._handle_a2a_response((task, update), self.ctx) + + saved = await self._load("chart.png") + assert saved is not None + assert saved.inline_data.data == b"\x89PNG" + assert saved.inline_data.mime_type == "image/png" + assert event.actions.artifact_delta == {"chart.png": 0} + + @pytest.mark.asyncio + async def test_multiple_artifacts_saved_to_session(self): + """All artifacts on a task are persisted with their respective filenames.""" + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.completed), + artifacts=[ + _make_file_artifact("a.txt", data=b"AAA"), + _make_file_artifact("b.txt", data=b"BBB"), + ], + ) + + event = await self.agent._handle_a2a_response((task, None), self.ctx) + + assert (await self._load("a.txt")).inline_data.data == b"AAA" + assert (await self._load("b.txt")).inline_data.data == b"BBB" + assert event.actions.artifact_delta == {"a.txt": 0, "b.txt": 0} + + @pytest.mark.asyncio + async def test_artifact_id_used_when_name_missing(self): + """When an artifact has no name, its artifact_id is used as the filename.""" + artifact = Artifact( + artifact_id="generated-id.bin", + parts=[ + A2APart( + root=FilePart( + file=FileWithBytes( + bytes=base64.b64encode(b"data").decode("utf-8"), + mime_type="application/octet-stream", + ) + ) + ) + ], + ) + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.completed), + artifacts=[artifact], + ) + + event = await self.agent._handle_a2a_response((task, None), self.ctx) + + assert (await self._load("generated-id.bin")) is not None + assert event.actions.artifact_delta == {"generated-id.bin": 0} + + @pytest.mark.asyncio + async def test_no_artifact_service_is_noop(self): + """Without an artifact service, artifacts are not saved and no error occurs.""" + self.ctx.artifact_service = None + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.completed), + artifacts=[_make_file_artifact("report.txt")], + ) + + event = await self.agent._handle_a2a_response((task, None), self.ctx) + + assert event is not None + assert event.actions.artifact_delta == {} + + @pytest.mark.asyncio + async def test_text_only_artifact_not_saved(self): + """Artifacts without blob-like parts (text only) are not saved as files.""" + artifact = Artifact( + artifact_id="note_0", + name="note.txt", + parts=[A2APart(root=TextPart(text="just text"))], + ) + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.completed), + artifacts=[artifact], + ) + + event = await self.agent._handle_a2a_response((task, None), self.ctx) + + assert await self._load("note.txt") is None + assert event.actions.artifact_delta == {} + + @pytest.mark.asyncio + async def test_file_with_uri_artifact_saved_as_file_data(self): + """File-by-URI artifacts are persisted as file_data parts.""" + artifact = Artifact( + artifact_id="remote_0", + name="remote.pdf", + parts=[ + A2APart( + root=FilePart( + file=FileWithUri( + uri="gs://bucket/remote.pdf", + mime_type="application/pdf", + name="remote.pdf", + ) + ) + ) + ], + ) + task = A2ATask( + id="task-123", + context_id="context-123", + status=A2ATaskStatus(state=TaskState.completed), + artifacts=[artifact], + ) + + event = await self.agent._handle_a2a_response((task, None), self.ctx) + + saved = await self._load("remote.pdf") + assert saved is not None + assert saved.file_data is not None + assert saved.file_data.file_uri == "gs://bucket/remote.pdf" + assert event.actions.artifact_delta == {"remote.pdf": 0}