diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index a016962c..27b5cb4c 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -47,9 +47,14 @@ initdb inmemory INR isready +jku JPY JSONRPCt +jwk +jwks JWS +jws +kid kwarg langgraph lifecycles diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index bdd4c5b8..97bba6b6 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -12,7 +12,7 @@ jobs: if: github.repository == 'a2aproject/a2a-python' steps: - name: Checkout Code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index decb3b1d..c6e6da0f 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install uv uses: astral-sh/setup-uv@v7 @@ -26,7 +26,7 @@ jobs: run: uv build - name: Upload distributions - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: release-dists path: dist/ @@ -40,7 +40,7 @@ jobs: steps: - name: Retrieve release distributions - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: name: release-dists path: dist/ diff --git a/.github/workflows/stale.yaml b/.github/workflows/stale.yaml index 3f9c6fe9..7c8cb0dc 100644 --- a/.github/workflows/stale.yaml +++ b/.github/workflows/stale.yaml @@ -7,7 +7,7 @@ name: Mark stale issues and pull requests on: schedule: - # Scheduled to run at 10.30PM UTC everyday (1530PDT/1430PST) + # Scheduled to run at 10.30PM UTC every day (1530PDT/1430PST) - cron: "30 22 * * *" workflow_dispatch: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 16052ba1..eb5b3d1f 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -39,7 +39,7 @@ jobs: python-version: ['3.10', '3.13'] steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up test environment variables run: | echo "POSTGRES_TEST_DSN=postgresql+asyncpg://a2a:a2a_password@localhost:5432/a2a_test" >> $GITHUB_ENV diff --git a/.github/workflows/update-a2a-types.yml b/.github/workflows/update-a2a-types.yml index c019afeb..e1adbd34 100644 --- a/.github/workflows/update-a2a-types.yml +++ b/.github/workflows/update-a2a-types.yml @@ -12,7 +12,7 @@ jobs: pull-requests: write steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 with: @@ -42,7 +42,7 @@ jobs: uv run scripts/grpc_gen_post_processor.py echo "Buf generate finished." - name: Create Pull Request with Updates - uses: peter-evans/create-pull-request@v7 + uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.A2A_BOT_PAT }} committer: a2a-bot diff --git a/CHANGELOG.md b/CHANGELOG.md index f5e6048d..cfbedf4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,58 @@ - # Changelog +# Changelog + +## [0.3.22](https://github.com/a2aproject/a2a-python/compare/v0.3.21...v0.3.22) (2025-12-16) + + +### Features + +* Add custom ID generators to SimpleRequestContextBuilder ([#594](https://github.com/a2aproject/a2a-python/issues/594)) ([04bcafc](https://github.com/a2aproject/a2a-python/commit/04bcafc737cf426d9975c76e346335ff992363e2)) + + +### Code Refactoring + +* Move agent card signature verification into `A2ACardResolver` ([6fa6a6c](https://github.com/a2aproject/a2a-python/commit/6fa6a6cf3875bdf7bfc51fb1a541a3f3e8381dc0)) + +## [0.3.21](https://github.com/a2aproject/a2a-python/compare/v0.3.20...v0.3.21) (2025-12-12) + + +### Documentation + +* Fixing typos ([#586](https://github.com/a2aproject/a2a-python/issues/586)) ([5fea21f](https://github.com/a2aproject/a2a-python/commit/5fea21fb34ecea55e588eb10139b5d47020a76cb)) + +## [0.3.20](https://github.com/a2aproject/a2a-python/compare/v0.3.19...v0.3.20) (2025-12-03) + + +### Bug Fixes + +* Improve streaming errors handling ([#576](https://github.com/a2aproject/a2a-python/issues/576)) ([7ea7475](https://github.com/a2aproject/a2a-python/commit/7ea7475091df2ee40d3035ef1bc34ee2f86524ee)) + +## [0.3.19](https://github.com/a2aproject/a2a-python/compare/v0.3.18...v0.3.19) (2025-11-25) + + +### Bug Fixes + +* **jsonrpc, rest:** `extensions` support in `get_card` methods in `json-rpc` and `rest` transports ([#564](https://github.com/a2aproject/a2a-python/issues/564)) ([847f18e](https://github.com/a2aproject/a2a-python/commit/847f18eff59985f447c39a8e5efde87818b68d15)) + +## [0.3.18](https://github.com/a2aproject/a2a-python/compare/v0.3.17...v0.3.18) (2025-11-24) + + +### Bug Fixes + +* return updated `agent_card` in `JsonRpcTransport.get_card()` ([#552](https://github.com/a2aproject/a2a-python/issues/552)) ([0ce239e](https://github.com/a2aproject/a2a-python/commit/0ce239e98f67ccbf154f2edcdbcee43f3b080ead)) + +## [0.3.17](https://github.com/a2aproject/a2a-python/compare/v0.3.16...v0.3.17) (2025-11-24) + + +### Features + +* **client:** allow specifying `history_length` via call-site `MessageSendConfiguration` in `BaseClient.send_message` ([53bbf7a](https://github.com/a2aproject/a2a-python/commit/53bbf7ae3ad58fb0c10b14da05cf07c0a7bd9651)) + +## [0.3.16](https://github.com/a2aproject/a2a-python/compare/v0.3.15...v0.3.16) (2025-11-21) + + +### Bug Fixes + +* Ensure metadata propagation for `Task` `ToProto` and `FromProto` conversion ([#557](https://github.com/a2aproject/a2a-python/issues/557)) ([fc31d03](https://github.com/a2aproject/a2a-python/commit/fc31d03e8c6acb68660f6d1924262e16933c5d50)) ## [0.3.15](https://github.com/a2aproject/a2a-python/compare/v0.3.14...v0.3.15) (2025-11-19) @@ -66,7 +120,7 @@ ### Bug Fixes * apply `history_length` for `message/send` requests ([#498](https://github.com/a2aproject/a2a-python/issues/498)) ([a49f94e](https://github.com/a2aproject/a2a-python/commit/a49f94ef23d81b8375e409b1c1e51afaf1da1956)) -* **client:** `A2ACardResolver.get_agent_card` will auto-populate with `agent_card_path` when `relative_card_path` is empty ([#508](https://github.com/a2aproject/a2a-python/issues/508)) ([ba24ead](https://github.com/a2aproject/a2a-python/commit/ba24eadb5b6fcd056a008e4cbcef03b3f72a37c3)) +* **client:** `A2ACardResolver.get_agent_card` will autopopulate with `agent_card_path` when `relative_card_path` is empty ([#508](https://github.com/a2aproject/a2a-python/issues/508)) ([ba24ead](https://github.com/a2aproject/a2a-python/commit/ba24eadb5b6fcd056a008e4cbcef03b3f72a37c3)) ### Documentation @@ -403,8 +457,8 @@ * Event consumer should stop on input_required ([#167](https://github.com/a2aproject/a2a-python/issues/167)) ([51c2d8a](https://github.com/a2aproject/a2a-python/commit/51c2d8addf9e89a86a6834e16deb9f4ac0e05cc3)) * Fix Release Version ([#161](https://github.com/a2aproject/a2a-python/issues/161)) ([011d632](https://github.com/a2aproject/a2a-python/commit/011d632b27b201193813ce24cf25e28d1335d18e)) * generate StrEnum types for enums ([#134](https://github.com/a2aproject/a2a-python/issues/134)) ([0c49dab](https://github.com/a2aproject/a2a-python/commit/0c49dabcdb9d62de49fda53d7ce5c691b8c1591c)) -* library should released as 0.2.6 ([d8187e8](https://github.com/a2aproject/a2a-python/commit/d8187e812d6ac01caedf61d4edaca522e583d7da)) -* remove error types from enqueable events ([#138](https://github.com/a2aproject/a2a-python/issues/138)) ([511992f](https://github.com/a2aproject/a2a-python/commit/511992fe585bd15e956921daeab4046dc4a50a0a)) +* library should be released as 0.2.6 ([d8187e8](https://github.com/a2aproject/a2a-python/commit/d8187e812d6ac01caedf61d4edaca522e583d7da)) +* remove error types from enqueueable events ([#138](https://github.com/a2aproject/a2a-python/issues/138)) ([511992f](https://github.com/a2aproject/a2a-python/commit/511992fe585bd15e956921daeab4046dc4a50a0a)) * **stream:** don't block event loop in EventQueue ([#151](https://github.com/a2aproject/a2a-python/issues/151)) ([efd9080](https://github.com/a2aproject/a2a-python/commit/efd9080b917c51d6e945572fd123b07f20974a64)) * **task_updater:** fix potential duplicate artifact_id from default v… ([#156](https://github.com/a2aproject/a2a-python/issues/156)) ([1f0a769](https://github.com/a2aproject/a2a-python/commit/1f0a769c1027797b2f252e4c894352f9f78257ca)) diff --git a/Gemini.md b/Gemini.md index d4367c37..7f52d33f 100644 --- a/Gemini.md +++ b/Gemini.md @@ -4,7 +4,7 @@ - uv as package manager ## How to run all tests -1. If dependencies are not installed install them using following command +1. If dependencies are not installed, install them using the following command ``` uv sync --all-extras ``` diff --git a/pyproject.toml b/pyproject.toml index 46f7400a..561a5a45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] +signing = ["PyJWT>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] sql = ["a2a-sdk[postgresql,mysql,sqlite]"] @@ -45,6 +46,7 @@ all = [ "a2a-sdk[encryption]", "a2a-sdk[grpc]", "a2a-sdk[telemetry]", + "a2a-sdk[signing]", ] [project.urls] @@ -86,6 +88,7 @@ style = "pep440" dev = [ "datamodel-code-generator>=0.30.0", "mypy>=1.15.0", + "PyJWT>=2.0.0", "pytest>=8.3.5", "pytest-asyncio>=0.26.0", "pytest-cov>=6.1.1", diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 5719bc1b..c870f329 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from typing import Any from a2a.client.client import ( @@ -47,6 +47,7 @@ async def send_message( self, request: Message, *, + configuration: MessageSendConfiguration | None = None, context: ClientCallContext | None = None, request_metadata: dict[str, Any] | None = None, extensions: list[str] | None = None, @@ -59,6 +60,7 @@ async def send_message( Args: request: The message to send to the agent. + configuration: Optional per-call overrides for message sending behavior. context: The client call context. request_metadata: Extensions Metadata attached to the request. extensions: List of extensions to be activated. @@ -66,7 +68,7 @@ async def send_message( Yields: An async iterator of `ClientEvent` or a final `Message` response. """ - config = MessageSendConfiguration( + base_config = MessageSendConfiguration( accepted_output_modes=self._config.accepted_output_modes, blocking=not self._config.polling, push_notification_config=( @@ -75,6 +77,15 @@ async def send_message( else None ), ) + if configuration is not None: + update_data = configuration.model_dump( + exclude_unset=True, + by_alias=False, + ) + config = base_config.model_copy(update=update_data) + else: + config = base_config + params = MessageSendParams( message=request, configuration=config, metadata=request_metadata ) @@ -250,6 +261,7 @@ async def get_card( *, context: ClientCallContext | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the agent's card. @@ -259,12 +271,15 @@ async def get_card( Args: context: The client call context. extensions: List of extensions to be activated. + signature_verifier: A callable used to verify the agent card's signatures. Returns: The `AgentCard` for the agent. """ card = await self._transport.get_card( - context=context, extensions=extensions + context=context, + extensions=extensions, + signature_verifier=signature_verifier, ) self._card = card return card diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index f13fe3ab..adb3c5ae 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -1,6 +1,7 @@ import json import logging +from collections.abc import Callable from typing import Any import httpx @@ -44,6 +45,7 @@ async def get_agent_card( self, relative_card_path: str | None = None, http_kwargs: dict[str, Any] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Fetches an agent card from a specified path relative to the base_url. @@ -56,6 +58,7 @@ async def get_agent_card( agent card path. Use `'/'` for an empty path. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.get request. + signature_verifier: A callable used to verify the agent card's signatures. Returns: An `AgentCard` object representing the agent's capabilities. @@ -86,6 +89,8 @@ async def get_agent_card( agent_card_data, ) agent_card = AgentCard.model_validate(agent_card_data) + if signature_verifier: + signature_verifier(agent_card) except httpx.HTTPStatusError as e: raise A2AClientHTTPError( e.response.status_code, diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index fd97b4d1..286641a7 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -185,6 +185,7 @@ async def get_card( *, context: ClientCallContext | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index fabd7270..c3d5762e 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -116,6 +116,7 @@ async def connect( # noqa: PLR0913 resolver_http_kwargs: dict[str, Any] | None = None, extra_transports: dict[str, TransportProducer] | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> Client: """Convenience method for constructing a client. @@ -146,6 +147,7 @@ async def connect( # noqa: PLR0913 extra_transports: Additional transport protocols to enable when constructing the client. extensions: List of extensions to be activated. + signature_verifier: A callable used to verify the agent card's signatures. Returns: A `Client` object. @@ -158,12 +160,14 @@ async def connect( # noqa: PLR0913 card = await resolver.get_agent_card( relative_card_path=relative_card_path, http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, ) else: resolver = A2ACardResolver(client_config.httpx_client, agent) card = await resolver.get_agent_card( relative_card_path=relative_card_path, http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, ) else: card = agent @@ -256,7 +260,7 @@ def minimal_agent_card( """Generates a minimal card to simplify bootstrapping client creation. This minimal card is not viable itself to interact with the remote agent. - Instead this is a short hand way to take a known url and transport option + Instead this is a shorthand way to take a known url and transport option and interact with the get card endpoint of the agent server to get the correct agent card. This pattern is necessary for gRPC based card access as typically these servers won't expose a well known path card. diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 8f114d95..0c54a28d 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from a2a.client.middleware import ClientCallContext from a2a.types import ( @@ -103,6 +103,7 @@ async def get_card( *, context: ClientCallContext | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the AgentCard.""" diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 4e27953a..6a8b16f9 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable try: @@ -223,6 +223,7 @@ async def get_card( *, context: ClientCallContext | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -236,6 +237,9 @@ async def get_card( metadata=self._get_grpc_metadata(extensions), ) card = proto_utils.FromProto.agent_card(card_pb) + if signature_verifier: + signature_verifier(card) + self.agent_card = card self._needs_extended_card = False return card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index d8011cf4..a565e640 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -1,7 +1,7 @@ import json import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from typing import Any from uuid import uuid4 @@ -174,6 +174,7 @@ async def send_message_streaming( **modified_kwargs, ) as event_source: try: + event_source.response.raise_for_status() async for sse in event_source.aiter_sse(): response = SendStreamingMessageResponse.model_validate( json.loads(sse.data) @@ -181,6 +182,8 @@ async def send_message_streaming( if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) yield response.root.result + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e except SSEError as e: raise A2AClientHTTPError( 400, f'Invalid SSE response or protocol error: {e}' @@ -376,13 +379,20 @@ async def get_card( *, context: ClientCallContext | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) card = self.agent_card + if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card( - http_kwargs=self._get_http_args(context) + http_kwargs=modified_kwargs, + signature_verifier=signature_verifier, ) self._needs_extended_card = ( card.supports_authenticated_extended_card @@ -393,10 +403,6 @@ async def get_card( return card request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) - modified_kwargs = update_extension_header( - self._get_http_args(context), - extensions if extensions is not None else self.extensions, - ) payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), @@ -412,7 +418,11 @@ async def get_card( ) if isinstance(response.root, JSONRPCErrorResponse): raise A2AClientJSONRPCError(response.root) - self.agent_card = response.root.result + card = response.root.result + if signature_verifier: + signature_verifier(card) + + self.agent_card = card self._needs_extended_card = False return card diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 83c26787..afc9dd08 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -1,7 +1,7 @@ import json import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from typing import Any import httpx @@ -152,10 +152,13 @@ async def send_message_streaming( **modified_kwargs, ) as event_source: try: + event_source.response.raise_for_status() async for sse in event_source.aiter_sse(): event = a2a_pb2.StreamResponse() Parse(sse.data, event) yield proto_utils.FromProto.stream_response(event) + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e except SSEError as e: raise A2AClientHTTPError( 400, f'Invalid SSE response or protocol error: {e}' @@ -368,13 +371,20 @@ async def get_card( *, context: ClientCallContext | None = None, extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) card = self.agent_card + if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card( - http_kwargs=self._get_http_args(context) + http_kwargs=modified_kwargs, + signature_verifier=signature_verifier, ) self._needs_extended_card = ( card.supports_authenticated_extended_card @@ -384,10 +394,6 @@ async def get_card( if not self._needs_extended_card: return card - modified_kwargs = update_extension_header( - self._get_http_args(context), - extensions if extensions is not None else self.extensions, - ) _, modified_kwargs = await self._apply_interceptors( {}, modified_kwargs, @@ -397,6 +403,9 @@ async def get_card( '/v1/card', {}, modified_kwargs ) card = AgentCard.model_validate(response_data) + if signature_verifier: + signature_verifier(card) + self.agent_card = card self._needs_extended_card = False return card diff --git a/src/a2a/server/agent_execution/simple_request_context_builder.py b/src/a2a/server/agent_execution/simple_request_context_builder.py index 3eca4435..876b6561 100644 --- a/src/a2a/server/agent_execution/simple_request_context_builder.py +++ b/src/a2a/server/agent_execution/simple_request_context_builder.py @@ -2,6 +2,7 @@ from a2a.server.agent_execution import RequestContext, RequestContextBuilder from a2a.server.context import ServerCallContext +from a2a.server.id_generator import IDGenerator from a2a.server.tasks import TaskStore from a2a.types import MessageSendParams, Task @@ -13,6 +14,8 @@ def __init__( self, should_populate_referred_tasks: bool = False, task_store: TaskStore | None = None, + task_id_generator: IDGenerator | None = None, + context_id_generator: IDGenerator | None = None, ) -> None: """Initializes the SimpleRequestContextBuilder. @@ -22,9 +25,13 @@ def __init__( `related_tasks` field in the RequestContext. Defaults to False. task_store: The TaskStore instance to use for fetching referred tasks. Required if `should_populate_referred_tasks` is True. + task_id_generator: ID generator for new task IDs. Defaults to None. + context_id_generator: ID generator for new context IDs. Defaults to None. """ self._task_store = task_store self._should_populate_referred_tasks = should_populate_referred_tasks + self._task_id_generator = task_id_generator + self._context_id_generator = context_id_generator async def build( self, @@ -74,4 +81,6 @@ async def build( task=task, related_tasks=related_tasks, call_context=context, + task_id_generator=self._task_id_generator, + context_id_generator=self._context_id_generator, ) diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index f6599cca..357fcb02 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -73,7 +73,7 @@ async def dequeue_event(self, no_wait: bool = False) -> Event: closed but when there are no events on the queue. Two ways to avoid this are to call this with no_wait = True which won't block, but is the callers responsibility to retry as appropriate. Alternatively, one can - use a async Task management solution to cancel the get task if the queue + use an async Task management solution to cancel the get task if the queue has closed or some other condition is met. The implementation of the EventConsumer uses an async.wait with a timeout to abort the dequeue_event call and retry, when it will return with a closed error. diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index d13c5e50..53cdb9f5 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -117,12 +117,12 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: ', Data=' + str(error.data) if error.data else '', ) # Since the stream has started, we can't return a JSONResponse. - # Instead, we runt the error handling logic (provides logging) + # Instead, we run the error handling logic (provides logging) # and reraise the error and let server framework manage raise e except Exception as e: # Since the stream has started, we can't return a JSONResponse. - # Instead, we runt the error handling logic (provides logging) + # Instead, we run the error handling logic (provides logging) # and reraise the error and let server framework manage raise e diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 96c1646a..96acdc1e 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -2,6 +2,7 @@ import functools import inspect +import json import logging from collections.abc import Callable @@ -9,6 +10,7 @@ from uuid import uuid4 from a2a.types import ( + AgentCard, Artifact, MessageSendParams, Part, @@ -340,3 +342,29 @@ def are_modalities_compatible( return True return any(x in server_output_modes for x in client_output_modes) + + +def _clean_empty(d: Any) -> Any: + """Recursively remove empty strings, lists and dicts from a dictionary.""" + if isinstance(d, dict): + cleaned_dict: dict[Any, Any] = { + k: _clean_empty(v) for k, v in d.items() + } + return {k: v for k, v in cleaned_dict.items() if v} + if isinstance(d, list): + cleaned_list: list[Any] = [_clean_empty(v) for v in d] + return [v for v in cleaned_list if v] + return d if d not in ['', [], {}] else None + + +def canonicalize_agent_card(agent_card: AgentCard) -> str: + """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" + card_dict = agent_card.model_dump( + exclude={'signatures'}, + exclude_defaults=True, + exclude_none=True, + by_alias=True, + ) + # Recursively remove empty values + cleaned_dict = _clean_empty(card_dict) + return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index d077d62b..14ac098d 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -203,6 +203,7 @@ def task(cls, task: types.Task) -> a2a_pb2.Task: if task.history else None ), + metadata=cls.metadata(task.metadata), ) @classmethod @@ -396,6 +397,21 @@ def agent_card( ] if card.additional_interfaces else None, + signatures=[cls.agent_card_signature(x) for x in card.signatures] + if card.signatures + else None, + ) + + @classmethod + def agent_card_signature( + cls, signature: types.AgentCardSignature + ) -> a2a_pb2.AgentCardSignature: + return a2a_pb2.AgentCardSignature( + protected=signature.protected, + signature=signature.signature, + header=dict_to_struct(signature.header) + if signature.header is not None + else None, ) @classmethod @@ -660,6 +676,7 @@ def task(cls, task: a2a_pb2.Task) -> types.Task: status=cls.task_status(task.status), artifacts=[cls.artifact(a) for a in task.artifacts], history=[cls.message(h) for h in task.history], + metadata=cls.metadata(task.metadata), ) @classmethod @@ -863,6 +880,19 @@ def agent_card( ] if card.additional_interfaces else None, + signatures=[cls.agent_card_signature(x) for x in card.signatures] + if card.signatures + else None, + ) + + @classmethod + def agent_card_signature( + cls, signature: a2a_pb2.AgentCardSignature + ) -> types.AgentCardSignature: + return types.AgentCardSignature( + protected=signature.protected, + signature=signature.signature, + header=json_format.MessageToDict(signature.header), ) @classmethod diff --git a/src/a2a/utils/signing.py b/src/a2a/utils/signing.py new file mode 100644 index 00000000..6ea8c21b --- /dev/null +++ b/src/a2a/utils/signing.py @@ -0,0 +1,152 @@ +import json + +from collections.abc import Callable +from typing import Any, TypedDict + +from a2a.utils.helpers import canonicalize_agent_card + + +try: + import jwt + + from jwt.api_jwk import PyJWK + from jwt.exceptions import PyJWTError + from jwt.utils import base64url_decode, base64url_encode +except ImportError as e: + raise ImportError( + 'A2A Signing requires PyJWT to be installed. ' + 'Install with: ' + "'pip install a2a-sdk[signing]'" + ) from e + +from a2a.types import AgentCard, AgentCardSignature + + +class SignatureVerificationError(Exception): + """Base exception for signature verification errors.""" + + +class NoSignatureError(SignatureVerificationError): + """Exception raised when no signature is found on an AgentCard.""" + + +class InvalidSignaturesError(SignatureVerificationError): + """Exception raised when all signatures are invalid.""" + + +class ProtectedHeader(TypedDict): + """Protected header parameters for JWS (JSON Web Signature).""" + + kid: str + """ Key identifier. """ + alg: str | None + """ Algorithm used for signing. """ + jku: str | None + """ JSON Web Key Set URL. """ + typ: str | None + """ Token type. + + Best practice: SHOULD be "JOSE" for JWS tokens. + """ + + +def create_agent_card_signer( + signing_key: PyJWK | str | bytes, + protected_header: ProtectedHeader, + header: dict[str, Any] | None = None, +) -> Callable[[AgentCard], AgentCard]: + """Creates a function that signs an AgentCard and adds the signature. + + Args: + signing_key: The private key for signing. + protected_header: The protected header parameters. + header: Unprotected header parameters. + + Returns: + A callable that takes an AgentCard and returns the modified AgentCard with a signature. + """ + + def agent_card_signer(agent_card: AgentCard) -> AgentCard: + """Signs agent card.""" + canonical_payload = canonicalize_agent_card(agent_card) + payload_dict = json.loads(canonical_payload) + + jws_string = jwt.encode( + payload=payload_dict, + key=signing_key, + algorithm=protected_header.get('alg', 'HS256'), + headers=dict(protected_header), + ) + + # The result of jwt.encode is a compact serialization: HEADER.PAYLOAD.SIGNATURE + protected, _, signature = jws_string.split('.') + + agent_card_signature = AgentCardSignature( + header=header, + protected=protected, + signature=signature, + ) + + agent_card.signatures = (agent_card.signatures or []) + [ + agent_card_signature + ] + return agent_card + + return agent_card_signer + + +def create_signature_verifier( + key_provider: Callable[[str | None, str | None], PyJWK | str | bytes], + algorithms: list[str], +) -> Callable[[AgentCard], None]: + """Creates a function that verifies the signatures on an AgentCard. + + The verifier succeeds if at least one signature is valid. Otherwise, it raises an error. + + Args: + key_provider: A callable that accepts a key ID (kid) and a JWK Set URL (jku) and returns the verification key. + This function is responsible for fetching the correct key for a given signature. + algorithms: A list of acceptable algorithms (e.g., ['ES256', 'RS256']) for verification used to prevent algorithm confusion attacks. + + Returns: + A function that takes an AgentCard as input, and raises an error if none of the signatures are valid. + """ + + def signature_verifier( + agent_card: AgentCard, + ) -> None: + """Verifies agent card signatures.""" + if not agent_card.signatures: + raise NoSignatureError('AgentCard has no signatures to verify.') + + for agent_card_signature in agent_card.signatures: + try: + # get verification key + protected_header_json = base64url_decode( + agent_card_signature.protected.encode('utf-8') + ).decode('utf-8') + protected_header = json.loads(protected_header_json) + kid = protected_header.get('kid') + jku = protected_header.get('jku') + verification_key = key_provider(kid, jku) + + canonical_payload = canonicalize_agent_card(agent_card) + encoded_payload = base64url_encode( + canonical_payload.encode('utf-8') + ).decode('utf-8') + + token = f'{agent_card_signature.protected}.{encoded_payload}.{agent_card_signature.signature}' + jwt.decode( + jwt=token, + key=verification_key, + algorithms=algorithms, + ) + # Found a valid signature, exit the loop and function + break + except PyJWTError: + continue + else: + # This block runs only if the loop completes without a break + raise InvalidSignaturesError('No valid signature found') + + return signature_verifier diff --git a/tests/README.md b/tests/README.md index d89f3bec..872ac723 100644 --- a/tests/README.md +++ b/tests/README.md @@ -5,7 +5,7 @@ uv run pytest -v -s client/test_client_factory.py ``` -In case of failures, you can cleanup the cache: +In case of failures, you can clean up the cache: 1. `uv clean` 2. `rm -fR .pytest_cache .venv __pycache__` diff --git a/tests/auth/test_user.py b/tests/auth/test_user.py index 5cc479ce..e3bbe2e6 100644 --- a/tests/auth/test_user.py +++ b/tests/auth/test_user.py @@ -1,9 +1,19 @@ import unittest -from a2a.auth.user import UnauthenticatedUser +from inspect import isabstract + +from a2a.auth.user import UnauthenticatedUser, User + + +class TestUser(unittest.TestCase): + def test_is_abstract(self): + self.assertTrue(isabstract(User)) class TestUnauthenticatedUser(unittest.TestCase): + def test_is_user_subclass(self): + self.assertTrue(issubclass(UnauthenticatedUser, User)) + def test_is_authenticated_returns_false(self): user = UnauthenticatedUser() self.assertFalse(user.is_authenticated) diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index f5ab2543..7aa47902 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -9,6 +9,7 @@ AgentCapabilities, AgentCard, Message, + MessageSendConfiguration, Part, Role, Task, @@ -125,3 +126,78 @@ async def test_send_message_non_streaming_agent_capability_false( assert not mock_transport.send_message_streaming.called assert len(events) == 1 assert events[0][0].id == 'task-789' + + +@pytest.mark.asyncio +async def test_send_message_callsite_config_overrides_non_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = False + mock_transport.send_message.return_value = Task( + id='task-cfg-ns-1', + context_id='ctx-cfg-ns-1', + status=TaskStatus(state=TaskState.completed), + ) + + cfg = MessageSendConfiguration( + history_length=2, + blocking=False, + accepted_output_modes=['application/json'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-ns-1' + + params = mock_transport.send_message.call_args[0][0] + assert params.configuration.history_length == 2 + assert params.configuration.blocking is False + assert params.configuration.accepted_output_modes == ['application/json'] + + +@pytest.mark.asyncio +async def test_send_message_callsite_config_overrides_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = True + base_client._card.capabilities.streaming = True + + async def create_stream(*args, **kwargs): + yield Task( + id='task-cfg-s-1', + context_id='ctx-cfg-s-1', + status=TaskStatus(state=TaskState.completed), + ) + + mock_transport.send_message_streaming.return_value = create_stream() + + cfg = MessageSendConfiguration( + history_length=0, + blocking=True, + accepted_output_modes=['text/plain'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message_streaming.assert_called_once() + assert not mock_transport.send_message.called + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-s-1' + + params = mock_transport.send_message_streaming.call_args[0][0] + assert params.configuration.history_length == 0 + assert params.configuration.blocking is True + assert params.configuration.accepted_output_modes == ['text/plain'] diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py new file mode 100644 index 00000000..26f3f106 --- /dev/null +++ b/tests/client/test_card_resolver.py @@ -0,0 +1,400 @@ +import json +import logging + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import httpx +import pytest + +from a2a.client import A2ACardResolver, A2AClientHTTPError, A2AClientJSONError +from a2a.types import AgentCard +from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH + + +@pytest.fixture +def mock_httpx_client(): + """Fixture providing a mocked async httpx client.""" + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def base_url(): + """Fixture providing a test base URL.""" + return 'https://example.com' + + +@pytest.fixture +def resolver(mock_httpx_client, base_url): + """Fixture providing an A2ACardResolver instance.""" + return A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=base_url, + ) + + +@pytest.fixture +def mock_response(): + """Fixture providing a mock httpx Response.""" + response = Mock(spec=httpx.Response) + response.raise_for_status = Mock() + return response + + +@pytest.fixture +def valid_agent_card_data(): + """Fixture providing valid agent card data.""" + return { + 'name': 'TestAgent', + 'description': 'A test agent', + 'version': '1.0.0', + 'url': 'https://example.com/a2a', + 'capabilities': {}, + 'default_input_modes': ['text/plain'], + 'default_output_modes': ['text/plain'], + 'skills': [ + { + 'id': 'test-skill', + 'name': 'Test Skill', + 'description': 'A skill for testing', + 'tags': ['test'], + } + ], + } + + +class TestA2ACardResolverInit: + """Tests for A2ACardResolver initialization.""" + + def test_init_with_defaults(self, mock_httpx_client, base_url): + """Test initialization with default agent_card_path.""" + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=base_url, + ) + assert resolver.base_url == base_url + assert resolver.agent_card_path == AGENT_CARD_WELL_KNOWN_PATH[1:] + assert resolver.httpx_client == mock_httpx_client + + def test_init_with_custom_path(self, mock_httpx_client, base_url): + """Test initialization with custom agent_card_path.""" + custom_path = '/custom/agent/card' + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=base_url, + agent_card_path=custom_path, + ) + assert resolver.base_url == base_url + assert resolver.agent_card_path == custom_path[1:] + + def test_init_strips_leading_slash_from_agent_card_path( + self, mock_httpx_client, base_url + ): + """Test that leading slash is stripped from agent_card_path.""" + agent_card_path = '/well-known/agent' + resolver = A2ACardResolver( + httpx_client=mock_httpx_client, + base_url=base_url, + agent_card_path=agent_card_path, + ) + assert resolver.agent_card_path == agent_card_path[1:] + + +class TestGetAgentCard: + """Tests for get_agent_card methods.""" + + @pytest.mark.asyncio + async def test_get_agent_card_success_default_path( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test successful agent card fetch using default path.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ) as mock_validate: + result = await resolver.get_agent_card() + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}', + ) + mock_response.raise_for_status.assert_called_once() + mock_response.json.assert_called_once() + mock_validate.assert_called_once_with(valid_agent_card_data) + assert result is not None + + @pytest.mark.asyncio + async def test_get_agent_card_success_custom_path( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test successful agent card fetch using custom relative path.""" + custom_path = 'custom/path/card' + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ): + await resolver.get_agent_card(relative_card_path=custom_path) + + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{custom_path}', + ) + + @pytest.mark.asyncio + async def test_get_agent_card_strips_leading_slash_from_relative_path( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test successful agent card fetch using custom path with leading slash.""" + custom_path = '/custom/path/card' + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ): + await resolver.get_agent_card(relative_card_path=custom_path) + + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{custom_path[1:]}', + ) + + @pytest.mark.asyncio + async def test_get_agent_card_with_http_kwargs( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test that http_kwargs are passed to httpx.get.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + http_kwargs = { + 'timeout': 30, + 'headers': {'Authorization': 'Bearer token'}, + } + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ): + await resolver.get_agent_card(http_kwargs=http_kwargs) + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}', + timeout=30, + headers={'Authorization': 'Bearer token'}, + ) + + @pytest.mark.asyncio + async def test_get_agent_card_root_path( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test fetching agent card from root path.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ): + await resolver.get_agent_card(relative_card_path='/') + mock_httpx_client.get.assert_called_once_with(f'{base_url}/') + + @pytest.mark.asyncio + async def test_get_agent_card_http_status_error( + self, resolver, mock_httpx_client + ): + """Test A2AClientHTTPError raised on HTTP status error.""" + status_code = 404 + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = status_code + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + 'Not Found', request=Mock(), response=mock_response + ) + mock_httpx_client.get.return_value = mock_response + + with pytest.raises(A2AClientHTTPError) as exc_info: + await resolver.get_agent_card() + + assert exc_info.value.status_code == status_code + assert 'Failed to fetch agent card' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_agent_card_json_decode_error( + self, resolver, mock_httpx_client, mock_response + ): + """Test A2AClientJSONError raised on JSON decode error.""" + mock_response.json.side_effect = json.JSONDecodeError( + 'Invalid JSON', '', 0 + ) + mock_httpx_client.get.return_value = mock_response + with pytest.raises(A2AClientJSONError) as exc_info: + await resolver.get_agent_card() + assert 'Failed to parse JSON' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_agent_card_request_error( + self, resolver, mock_httpx_client + ): + """Test A2AClientHTTPError raised on network request error.""" + mock_httpx_client.get.side_effect = httpx.RequestError( + 'Connection timeout', request=Mock() + ) + with pytest.raises(A2AClientHTTPError) as exc_info: + await resolver.get_agent_card() + assert exc_info.value.status_code == 503 + assert 'Network communication error' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_agent_card_validation_error( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test A2AClientJSONError is raised on agent card validation error.""" + return_json = {'invalid': 'data'} + mock_response.json.return_value = return_json + mock_httpx_client.get.return_value = mock_response + with pytest.raises(A2AClientJSONError) as exc_info: + await resolver.get_agent_card() + assert ( + f'Failed to validate agent card structure from {base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}' + in exc_info.value.message + ) + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}', + ) + + @pytest.mark.asyncio + async def test_get_agent_card_logs_success( # noqa: PLR0913 + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + caplog, + ): + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + with ( + patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ), + caplog.at_level(logging.INFO), + ): + await resolver.get_agent_card() + assert ( + f'Successfully fetched agent card data from {base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}' + in caplog.text + ) + + @pytest.mark.asyncio + async def test_get_agent_card_none_relative_path( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test that None relative_card_path uses default path.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ): + await resolver.get_agent_card(relative_card_path=None) + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}', + ) + + @pytest.mark.asyncio + async def test_get_agent_card_empty_string_relative_path( + self, + base_url, + resolver, + mock_httpx_client, + mock_response, + valid_agent_card_data, + ): + """Test that empty string relative_card_path uses default path.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + + with patch.object( + AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + ): + await resolver.get_agent_card(relative_card_path='') + + mock_httpx_client.get.assert_called_once_with( + f'{base_url}/{AGENT_CARD_WELL_KNOWN_PATH[1:]}', + ) + + @pytest.mark.parametrize('status_code', [400, 401, 403, 500, 502]) + @pytest.mark.asyncio + async def test_get_agent_card_different_status_codes( + self, resolver, mock_httpx_client, status_code + ): + """Test different HTTP status codes raise appropriate errors.""" + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = status_code + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + f'Status {status_code}', request=Mock(), response=mock_response + ) + mock_httpx_client.get.return_value = mock_response + with pytest.raises(A2AClientHTTPError) as exc_info: + await resolver.get_agent_card() + assert exc_info.value.status_code == status_code + + @pytest.mark.asyncio + async def test_get_agent_card_returns_agent_card_instance( + self, resolver, mock_httpx_client, mock_response, valid_agent_card_data + ): + """Test that get_agent_card returns an AgentCard instance.""" + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + mock_agent_card = Mock(spec=AgentCard) + + with patch.object( + AgentCard, 'model_validate', return_value=mock_agent_card + ): + result = await resolver.get_agent_card() + assert result == mock_agent_card + mock_response.raise_for_status.assert_called_once() + + @pytest.mark.asyncio + async def test_get_agent_card_with_signature_verifier( + self, resolver, mock_httpx_client, valid_agent_card_data + ): + """Test that the signature verifier is called if provided.""" + mock_verifier = MagicMock() + + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = valid_agent_card_data + mock_httpx_client.get.return_value = mock_response + + agent_card = await resolver.get_agent_card( + signature_verifier=mock_verifier + ) + + mock_verifier.assert_called_once_with(agent_card) diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 16a1433f..c388974b 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -190,6 +190,7 @@ async def test_client_factory_connect_with_resolver_args( mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, + signature_verifier=None, ) @@ -216,6 +217,7 @@ async def test_client_factory_connect_resolver_args_without_client( mock_resolver.return_value.get_agent_card.assert_awaited_once_with( relative_card_path=relative_path, http_kwargs=http_kwargs, + signature_verifier=None, ) diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index bd705d93..edbcd6c7 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -114,6 +114,14 @@ async def async_iterable_from_list( yield item +def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions = {e.strip() for e in header_value.split(',')} + assert actual_extensions == expected_extensions + + class TestA2ACardResolver: BASE_URL = 'http://example.com' AGENT_CARD_PATH = AGENT_CARD_WELL_KNOWN_PATH @@ -774,7 +782,7 @@ async def test_get_card_with_extended_card_support( mock_send_request.return_value = rpc_response card = await client.get_card() - assert card == agent_card + assert card == AGENT_CARD_EXTENDED mock_send_request.assert_called_once() sent_payload = mock_send_request.call_args.args[0] assert sent_payload['method'] == 'agent/getAuthenticatedExtendedCard' @@ -823,18 +831,13 @@ async def test_send_message_with_default_extensions( mock_httpx_client.post.assert_called_once() _, mock_kwargs = mock_httpx_client.post.call_args - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = headers[HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) @pytest.mark.asyncio @patch('a2a.client.transports.jsonrpc.aconnect_sse') @@ -870,8 +873,121 @@ async def test_send_message_streaming_with_new_extensions( mock_aconnect_sse.assert_called_once() _, kwargs = mock_aconnect_sse.call_args - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - assert ( - headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + _assert_extensions_header( + kwargs, + { + 'https://example.com/test-ext/v2', + }, + ) + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_server_error_propagates( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test that send_message_streaming propagates server errors (e.g., 403, 500) directly.""" + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Error stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 403 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + 'Forbidden', + request=httpx.Request('POST', 'http://test.url'), + response=mock_response, + ) + mock_event_source.response = mock_response + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientHTTPError) as exc_info: + async for _ in client.send_message_streaming(request=params): + pass + + assert exc_info.value.status_code == 403 + mock_aconnect_sse.assert_called_once() + + @pytest.mark.asyncio + async def test_get_card_no_card_provided_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions set in Client when no card is initially provided. + Tests that the extensions are added to the HTTP GET request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + url=TestJsonRpcTransport.AGENT_URL, + extensions=extensions, + ) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_httpx_client.get.return_value = mock_response + + await client.get_card() + + mock_httpx_client.get.assert_called_once() + _, mock_kwargs = mock_httpx_client.get.call_args + + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) + + @pytest.mark.asyncio + async def test_get_card_with_extended_card_support_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions passed to get_card call when extended card support is enabled. + Tests that the extensions are added to the RPC request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + agent_card = AGENT_CARD.model_copy( + update={'supports_authenticated_extended_card': True} + ) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + extensions=extensions, + ) + + rpc_response = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': AGENT_CARD_EXTENDED.model_dump(mode='json'), + } + with patch.object( + client, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = rpc_response + await client.get_card(extensions=extensions) + + mock_send_request.assert_called_once() + _, mock_kwargs = mock_send_request.call_args[0] + + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 04bd1036..cd68b443 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -7,9 +7,14 @@ from httpx_sse import EventSource, ServerSentEvent from a2a.client import create_text_message_object +from a2a.client.errors import A2AClientHTTPError from a2a.client.transports.rest import RestTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.types import AgentCard, MessageSendParams, Role +from a2a.types import ( + AgentCapabilities, + AgentCard, + MessageSendParams, +) @pytest.fixture @@ -32,6 +37,14 @@ async def async_iterable_from_list( yield item +def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions = {e.strip() for e in header_value.split(',')} + assert actual_extensions == expected_extensions + + class TestRestTransportExtensions: @pytest.mark.asyncio async def test_send_message_with_default_extensions( @@ -67,18 +80,13 @@ async def test_send_message_with_default_extensions( mock_build_request.assert_called_once() _, kwargs = mock_build_request.call_args - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - header_value = kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'https://example.com/test-ext/v1', - 'https://example.com/test-ext/v2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions + _assert_extensions_header( + kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) @pytest.mark.asyncio @patch('a2a.client.transports.rest.aconnect_sse') @@ -114,8 +122,141 @@ async def test_send_message_streaming_with_new_extensions( mock_aconnect_sse.assert_called_once() _, kwargs = mock_aconnect_sse.call_args - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - assert ( - headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + _assert_extensions_header( + kwargs, + { + 'https://example.com/test-ext/v2', + }, + ) + + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_send_message_streaming_server_error_propagates( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test that send_message_streaming propagates server errors (e.g., 403, 500) directly.""" + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Error stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 403 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + 'Forbidden', + request=httpx.Request('POST', 'http://test.url'), + response=mock_response, + ) + mock_event_source.response = mock_response + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + with pytest.raises(A2AClientHTTPError) as exc_info: + async for _ in client.send_message_streaming(request=params): + pass + + assert exc_info.value.status_code == 403 + + mock_aconnect_sse.assert_called_once() + + @pytest.mark.asyncio + async def test_get_card_no_card_provided_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions set in Client when no card is initially provided. + Tests that the extensions are added to the HTTP GET request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = RestTransport( + httpx_client=mock_httpx_client, + url='http://agent.example.com/api', + extensions=extensions, + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = { + 'name': 'Test Agent', + 'description': 'Test Agent Description', + 'url': 'http://agent.example.com/api', + 'version': '1.0.0', + 'default_input_modes': ['text'], + 'default_output_modes': ['text'], + 'capabilities': AgentCapabilities().model_dump(), + 'skills': [], + } + mock_httpx_client.get.return_value = mock_response + + await client.get_card() + + mock_httpx_client.get.assert_called_once() + _, mock_kwargs = mock_httpx_client.get.call_args + + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, + ) + + @pytest.mark.asyncio + async def test_get_card_with_extended_card_support_with_extensions( + self, mock_httpx_client: AsyncMock + ): + """Test get_card with extensions passed to get_card call when extended card support is enabled. + Tests that the extensions are added to the GET request.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + agent_card = AgentCard( + name='Test Agent', + description='Test Agent Description', + url='http://agent.example.com/api', + version='1.0.0', + default_input_modes=['text'], + default_output_modes=['text'], + capabilities=AgentCapabilities(), + skills=[], + supports_authenticated_extended_card=True, + ) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + ) + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = agent_card.model_dump(mode='json') + mock_httpx_client.send.return_value = mock_response + + with patch.object( + client, '_send_get_request', new_callable=AsyncMock + ) as mock_send_get_request: + mock_send_get_request.return_value = agent_card.model_dump( + mode='json' + ) + await client.get_card(extensions=extensions) + + mock_send_get_request.assert_called_once() + _, _, mock_kwargs = mock_send_get_request.call_args[0] + + _assert_extensions_header( + mock_kwargs, + { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + }, ) diff --git a/tests/e2e/push_notifications/notifications_app.py b/tests/e2e/push_notifications/notifications_app.py index ed032dcb..c12e9809 100644 --- a/tests/e2e/push_notifications/notifications_app.py +++ b/tests/e2e/push_notifications/notifications_app.py @@ -23,7 +23,7 @@ def create_notifications_app() -> FastAPI: @app.post('/notifications') async def add_notification(request: Request): - """Endpoint for injesting notifications from agents. It receives a JSON + """Endpoint for ingesting notifications from agents. It receives a JSON payload and stores it in-memory. """ token = request.headers.get('x-a2a-notification-token') @@ -56,7 +56,7 @@ async def list_notifications_by_task( str, Path(title='The ID of the task to list the notifications for.') ], ): - """Helper endpoint for retrieving injested notifications for a given task.""" + """Helper endpoint for retrieving ingested notifications for a given task.""" async with store_lock: notifications = store.get(task_id, []) return {'notifications': notifications} diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index 775bd7fb..d7364b84 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -35,7 +35,7 @@ @pytest.fixture(scope='module') def notifications_server(): """ - Starts a simple push notifications injesting server and yields its URL. + Starts a simple push notifications ingesting server and yields its URL. """ host = '127.0.0.1' port = find_free_port() @@ -148,7 +148,7 @@ async def test_notification_triggering_after_config_change_e2e( notifications_server: str, agent_server: str, http_client: httpx.AsyncClient ): """ - Tests notification triggering after setting the push notificaiton config in a seperate call. + Tests notification triggering after setting the push notification config in a separate call. """ # Configure an A2A client without a push notification config. a2a_client = ClientFactory( diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e0a564ee..e6552fcb 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,6 +1,6 @@ import asyncio from collections.abc import AsyncGenerator -from typing import NamedTuple +from typing import NamedTuple, Any from unittest.mock import ANY, AsyncMock, patch import grpc @@ -9,6 +9,7 @@ import pytest_asyncio from grpc.aio import Channel +from jwt.api_jwk import PyJWK from a2a.client import ClientConfig from a2a.client.base_client import BaseClient from a2a.client.transports import JsonRpcTransport, RestTransport @@ -17,6 +18,10 @@ from a2a.grpc import a2a_pb2_grpc from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication from a2a.server.request_handlers import GrpcHandler, RequestHandler +from a2a.utils.signing import ( + create_agent_card_signer, + create_signature_verifier, +) from a2a.types import ( AgentCapabilities, AgentCard, @@ -37,6 +42,7 @@ TextPart, TransportProtocol, ) +from cryptography.hazmat.primitives import asymmetric # --- Test Constants --- @@ -83,6 +89,15 @@ ) +def create_key_provider(verification_key: PyJWK | str | bytes): + """Creates a key provider function for testing.""" + + def key_provider(kid: str | None, jku: str | None): + return verification_key + + return key_provider + + # --- Test Fixtures --- @@ -739,6 +754,7 @@ async def test_http_transport_get_authenticated_card( transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card) result = await transport.get_card() assert result.name == extended_agent_card.name + assert transport.agent_card is not None assert transport.agent_card.name == extended_agent_card.name assert transport._needs_extended_card is False @@ -761,6 +777,7 @@ def channel_factory(address: str) -> Channel: transport = GrpcTransport(channel=channel, agent_card=agent_card) # The transport starts with a minimal card, get_card() fetches the full one + assert transport.agent_card is not None transport.agent_card.supports_authenticated_extended_card = True result = await transport.get_card() @@ -772,7 +789,7 @@ def channel_factory(address: str) -> Channel: @pytest.mark.asyncio -async def test_base_client_sends_message_with_extensions( +async def test_json_transport_base_client_send_message_with_extensions( jsonrpc_setup: TransportSetup, agent_card: AgentCard ) -> None: """ @@ -827,3 +844,300 @@ async def test_base_client_sends_message_with_extensions( if hasattr(transport, 'close'): await transport.close() + + +@pytest.mark.asyncio +async def test_json_transport_get_signed_base_card( + jsonrpc_setup: TransportSetup, agent_card: AgentCard +) -> None: + """Tests fetching and verifying a symmetrically signed AgentCard via JSON-RPC. + + The client transport is initialized without a card, forcing it to fetch + the base card from the server. The server signs the card using HS384. + The client then verifies the signature. + """ + mock_request_handler = jsonrpc_setup.handler + agent_card.supports_authenticated_extended_card = False + + # Setup signing on the server side + key = 'key12345' + signer = create_agent_card_signer( + signing_key=key, + protected_header={ + 'alg': 'HS384', + 'kid': 'testkey', + 'jku': None, + 'typ': 'JOSE', + }, + ) + + app_builder = A2AFastAPIApplication( + agent_card, + mock_request_handler, + card_modifier=signer, # Sign the base card + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + + transport = JsonRpcTransport( + httpx_client=httpx_client, + url=agent_card.url, + agent_card=None, + ) + + # Get the card, this will trigger verification in get_card + signature_verifier = create_signature_verifier( + create_key_provider(key), ['HS384'] + ) + result = await transport.get_card(signature_verifier=signature_verifier) + assert result.name == agent_card.name + assert result.signatures is not None + assert len(result.signatures) == 1 + assert transport.agent_card is not None + assert transport.agent_card.name == agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_json_transport_get_signed_extended_card( + jsonrpc_setup: TransportSetup, agent_card: AgentCard +) -> None: + """Tests fetching and verifying an asymmetrically signed extended AgentCard via JSON-RPC. + + The client has a base card and fetches the extended card, which is signed + by the server using ES256. The client verifies the signature on the + received extended card. + """ + mock_request_handler = jsonrpc_setup.handler + agent_card.supports_authenticated_extended_card = True + extended_agent_card = agent_card.model_copy(deep=True) + extended_agent_card.name = 'Extended Agent Card' + + # Setup signing on the server side + private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1()) + public_key = private_key.public_key() + signer = create_agent_card_signer( + signing_key=private_key, + protected_header={ + 'alg': 'ES256', + 'kid': 'testkey', + 'jku': None, + 'typ': 'JOSE', + }, + ) + + app_builder = A2AFastAPIApplication( + agent_card, + mock_request_handler, + extended_agent_card=extended_agent_card, + extended_card_modifier=lambda card, ctx: signer( + card + ), # Sign the extended card + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + + transport = JsonRpcTransport( + httpx_client=httpx_client, agent_card=agent_card + ) + + # Get the card, this will trigger verification in get_card + signature_verifier = create_signature_verifier( + create_key_provider(public_key), ['HS384', 'ES256'] + ) + result = await transport.get_card(signature_verifier=signature_verifier) + assert result.name == extended_agent_card.name + assert result.signatures is not None + assert len(result.signatures) == 1 + assert transport.agent_card is not None + assert transport.agent_card.name == extended_agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_json_transport_get_signed_base_and_extended_cards( + jsonrpc_setup: TransportSetup, agent_card: AgentCard +) -> None: + """Tests fetching and verifying both base and extended cards via JSON-RPC when no card is initially provided. + + The client starts with no card. It first fetches the base card, which is + signed. It then fetches the extended card, which is also signed. Both signatures + are verified independently upon retrieval. + """ + mock_request_handler = jsonrpc_setup.handler + assert agent_card.signatures is None + agent_card.supports_authenticated_extended_card = True + extended_agent_card = agent_card.model_copy(deep=True) + extended_agent_card.name = 'Extended Agent Card' + + # Setup signing on the server side + private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1()) + public_key = private_key.public_key() + signer = create_agent_card_signer( + signing_key=private_key, + protected_header={ + 'alg': 'ES256', + 'kid': 'testkey', + 'jku': None, + 'typ': 'JOSE', + }, + ) + + app_builder = A2AFastAPIApplication( + agent_card, + mock_request_handler, + extended_agent_card=extended_agent_card, + card_modifier=signer, # Sign the base card + extended_card_modifier=lambda card, ctx: signer( + card + ), # Sign the extended card + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + + transport = JsonRpcTransport( + httpx_client=httpx_client, + url=agent_card.url, + agent_card=None, + ) + + # Get the card, this will trigger verification in get_card + signature_verifier = create_signature_verifier( + create_key_provider(public_key), ['HS384', 'ES256', 'RS256'] + ) + result = await transport.get_card(signature_verifier=signature_verifier) + assert result.name == extended_agent_card.name + assert result.signatures is not None + assert len(result.signatures) == 1 + assert transport.agent_card is not None + assert transport.agent_card.name == extended_agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_rest_transport_get_signed_card( + rest_setup: TransportSetup, agent_card: AgentCard +) -> None: + """Tests fetching and verifying signed base and extended cards via REST. + + The client starts with no card. It first fetches the base card, which is + signed. It then fetches the extended card, which is also signed. Both signatures + are verified independently upon retrieval. + """ + mock_request_handler = rest_setup.handler + agent_card.supports_authenticated_extended_card = True + extended_agent_card = agent_card.model_copy(deep=True) + extended_agent_card.name = 'Extended Agent Card' + + # Setup signing on the server side + private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1()) + public_key = private_key.public_key() + signer = create_agent_card_signer( + signing_key=private_key, + protected_header={ + 'alg': 'ES256', + 'kid': 'testkey', + 'jku': None, + 'typ': 'JOSE', + }, + ) + + app_builder = A2ARESTFastAPIApplication( + agent_card, + mock_request_handler, + extended_agent_card=extended_agent_card, + card_modifier=signer, # Sign the base card + extended_card_modifier=lambda card, ctx: signer( + card + ), # Sign the extended card + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + + transport = RestTransport( + httpx_client=httpx_client, + url=agent_card.url, + agent_card=None, + ) + + # Get the card, this will trigger verification in get_card + signature_verifier = create_signature_verifier( + create_key_provider(public_key), ['HS384', 'ES256', 'RS256'] + ) + result = await transport.get_card(signature_verifier=signature_verifier) + assert result.name == extended_agent_card.name + assert result.signatures is not None + assert len(result.signatures) == 1 + assert transport.agent_card is not None + assert transport.agent_card.name == extended_agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + +@pytest.mark.asyncio +async def test_grpc_transport_get_signed_card( + mock_request_handler: AsyncMock, agent_card: AgentCard +) -> None: + """Tests fetching and verifying a signed AgentCard via gRPC.""" + # Setup signing on the server side + agent_card.supports_authenticated_extended_card = True + + private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1()) + public_key = private_key.public_key() + signer = create_agent_card_signer( + signing_key=private_key, + protected_header={ + 'alg': 'ES256', + 'kid': 'testkey', + 'jku': None, + 'typ': 'JOSE', + }, + ) + + server = grpc.aio.server() + port = server.add_insecure_port('[::]:0') + server_address = f'localhost:{port}' + agent_card.url = server_address + + servicer = GrpcHandler( + agent_card, + mock_request_handler, + card_modifier=signer, + ) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) + await server.start() + + transport = None # Initialize transport + try: + + def channel_factory(address: str) -> Channel: + return grpc.aio.insecure_channel(address) + + channel = channel_factory(server_address) + transport = GrpcTransport(channel=channel, agent_card=agent_card) + transport.agent_card = None + assert transport._needs_extended_card is True + + # Get the card, this will trigger verification in get_card + signature_verifier = create_signature_verifier( + create_key_provider(public_key), ['HS384', 'ES256', 'RS256'] + ) + result = await transport.get_card(signature_verifier=signature_verifier) + assert result.signatures is not None + assert len(result.signatures) == 1 + assert transport._needs_extended_card is False + finally: + if transport: + await transport.close() + await server.stop(0) # Gracefully stop the server diff --git a/tests/server/agent_execution/test_simple_request_context_builder.py b/tests/server/agent_execution/test_simple_request_context_builder.py index 5e1b8fd8..c1cbcf05 100644 --- a/tests/server/agent_execution/test_simple_request_context_builder.py +++ b/tests/server/agent_execution/test_simple_request_context_builder.py @@ -10,6 +10,7 @@ SimpleRequestContextBuilder, ) from a2a.server.context import ServerCallContext +from a2a.server.id_generator import IDGenerator from a2a.server.tasks.task_store import TaskStore from a2a.types import ( Message, @@ -275,6 +276,65 @@ async def test_build_populate_false_with_reference_task_ids(self) -> None: self.assertEqual(request_context.related_tasks, []) self.mock_task_store.get.assert_not_called() + async def test_build_with_custom_id_generators(self) -> None: + mock_task_id_generator = AsyncMock(spec=IDGenerator) + mock_context_id_generator = AsyncMock(spec=IDGenerator) + mock_task_id_generator.generate.return_value = 'custom_task_id' + mock_context_id_generator.generate.return_value = 'custom_context_id' + + builder = SimpleRequestContextBuilder( + should_populate_referred_tasks=False, + task_store=self.mock_task_store, + task_id_generator=mock_task_id_generator, + context_id_generator=mock_context_id_generator, + ) + params = MessageSendParams(message=create_sample_message()) + server_call_context = ServerCallContext(user=UnauthenticatedUser()) + + request_context = await builder.build( + params=params, + task_id=None, + context_id=None, + task=None, + context=server_call_context, + ) + + mock_task_id_generator.generate.assert_called_once() + mock_context_id_generator.generate.assert_called_once() + self.assertEqual(request_context.task_id, 'custom_task_id') + self.assertEqual(request_context.context_id, 'custom_context_id') + + async def test_build_with_provided_ids_and_custom_id_generators( + self, + ) -> None: + mock_task_id_generator = AsyncMock(spec=IDGenerator) + mock_context_id_generator = AsyncMock(spec=IDGenerator) + + builder = SimpleRequestContextBuilder( + should_populate_referred_tasks=False, + task_store=self.mock_task_store, + task_id_generator=mock_task_id_generator, + context_id_generator=mock_context_id_generator, + ) + params = MessageSendParams(message=create_sample_message()) + server_call_context = ServerCallContext(user=UnauthenticatedUser()) + + provided_task_id = 'provided_task_id' + provided_context_id = 'provided_context_id' + + request_context = await builder.build( + params=params, + task_id=provided_task_id, + context_id=provided_context_id, + task=None, + context=server_call_context, + ) + + mock_task_id_generator.generate.assert_not_called() + mock_context_id_generator.generate.assert_not_called() + self.assertEqual(request_context.task_id, provided_task_id) + self.assertEqual(request_context.context_id, provided_context_id) + if __name__ == '__main__': unittest.main() diff --git a/tests/server/events/test_event_queue.py b/tests/server/events/test_event_queue.py index 0ff966cc..96ded958 100644 --- a/tests/server/events/test_event_queue.py +++ b/tests/server/events/test_event_queue.py @@ -305,7 +305,7 @@ async def test_close_sets_flag_and_handles_internal_queue_new_python( async def test_close_graceful_py313_waits_for_join_and_children( event_queue: EventQueue, ) -> None: - """For Python >=3.13 and immediate=False, close should shutdown(False), then wait for join and children.""" + """For Python >=3.13 and immediate=False, close should shut down(False), then wait for join and children.""" with patch('sys.version_info', (3, 13, 0)): # Arrange from typing import cast diff --git a/tests/server/tasks/test_id_generator.py b/tests/server/tasks/test_id_generator.py new file mode 100644 index 00000000..11bfff2b --- /dev/null +++ b/tests/server/tasks/test_id_generator.py @@ -0,0 +1,131 @@ +import uuid + +import pytest + +from pydantic import ValidationError + +from a2a.server.id_generator import ( + IDGenerator, + IDGeneratorContext, + UUIDGenerator, +) + + +class TestIDGeneratorContext: + """Tests for IDGeneratorContext.""" + + def test_context_creation_with_all_fields(self): + """Test creating context with all fields populated.""" + context = IDGeneratorContext( + task_id='task_123', context_id='context_456' + ) + assert context.task_id == 'task_123' + assert context.context_id == 'context_456' + + def test_context_creation_with_defaults(self): + """Test creating context with default None values.""" + context = IDGeneratorContext() + assert context.task_id is None + assert context.context_id is None + + @pytest.mark.parametrize( + 'kwargs, expected_task_id, expected_context_id', + [ + ({'task_id': 'task_123'}, 'task_123', None), + ({'context_id': 'context_456'}, None, 'context_456'), + ], + ) + def test_context_creation_with_partial_fields( + self, kwargs, expected_task_id, expected_context_id + ): + """Test creating context with only some fields populated.""" + context = IDGeneratorContext(**kwargs) + assert context.task_id == expected_task_id + assert context.context_id == expected_context_id + + def test_context_mutability(self): + """Test that context fields can be updated (Pydantic models are mutable by default).""" + context = IDGeneratorContext(task_id='task_123') + context.task_id = 'task_456' + assert context.task_id == 'task_456' + + def test_context_validation(self): + """Test that context raises validation error for invalid types.""" + with pytest.raises(ValidationError): + IDGeneratorContext(task_id={'not': 'a string'}) + + +class TestIDGenerator: + """Tests for IDGenerator abstract base class.""" + + def test_cannot_instantiate_abstract_class(self): + """Test that IDGenerator cannot be instantiated directly.""" + with pytest.raises(TypeError): + IDGenerator() + + def test_subclass_must_implement_generate(self): + """Test that subclasses must implement the generate method.""" + + class IncompleteGenerator(IDGenerator): + pass + + with pytest.raises(TypeError): + IncompleteGenerator() + + def test_valid_subclass_implementation(self): + """Test that a valid subclass can be instantiated.""" + + class ValidGenerator(IDGenerator): # pylint: disable=C0115,R0903 + def generate(self, context: IDGeneratorContext) -> str: + return 'test_id' + + generator = ValidGenerator() + assert generator.generate(IDGeneratorContext()) == 'test_id' + + +@pytest.fixture +def generator(): + """Returns a UUIDGenerator instance.""" + return UUIDGenerator() + + +@pytest.fixture +def context(): + """Returns a IDGeneratorContext instance.""" + return IDGeneratorContext() + + +class TestUUIDGenerator: + """Tests for UUIDGenerator implementation.""" + + def test_generate_returns_string(self, generator, context): + """Test that generate returns a valid v4 UUID string.""" + result = generator.generate(context) + assert isinstance(result, str) + parsed_uuid = uuid.UUID(result) + assert parsed_uuid.version == 4 + + def test_generate_produces_unique_ids(self, generator, context): + """Test that multiple calls produce unique IDs.""" + ids = [generator.generate(context) for _ in range(100)] + # All IDs should be unique + assert len(ids) == len(set(ids)) + + @pytest.mark.parametrize( + 'context_arg', + [ + None, + IDGeneratorContext(), + ], + ids=[ + 'none_context', + 'empty_context', + ], + ) + def test_generate_works_with_various_contexts(self, context_arg): + """Test that generate works with various context inputs.""" + generator = UUIDGenerator() + result = generator.generate(context_arg) + assert isinstance(result, str) + parsed_uuid = uuid.UUID(result) + assert parsed_uuid.version == 4 diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 28acd27c..f3227d32 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -7,6 +7,10 @@ from a2a.types import ( Artifact, + AgentCard, + AgentCardSignature, + AgentCapabilities, + AgentSkill, Message, MessageSendParams, Part, @@ -23,6 +27,7 @@ build_text_artifact, create_task_obj, validate, + canonicalize_agent_card, ) @@ -45,6 +50,34 @@ 'type': 'task', } +SAMPLE_AGENT_CARD: dict[str, Any] = { + 'name': 'Test Agent', + 'description': 'A test agent', + 'url': 'http://localhost', + 'version': '1.0.0', + 'capabilities': AgentCapabilities( + streaming=None, + push_notifications=True, + ), + 'default_input_modes': ['text/plain'], + 'default_output_modes': ['text/plain'], + 'documentation_url': None, + 'icon_url': '', + 'skills': [ + AgentSkill( + id='skill1', + name='Test Skill', + description='A test skill', + tags=['test'], + ) + ], + 'signatures': [ + AgentCardSignature( + protected='protected_header', signature='test_signature' + ) + ], +} + # Test create_task_obj def test_create_task_obj(): @@ -328,3 +361,22 @@ def test_are_modalities_compatible_both_empty(): ) is True ) + + +def test_canonicalize_agent_card(): + """Test canonicalize_agent_card with defaults, optionals, and exceptions. + + - extensions is omitted as it's not set and optional. + - protocolVersion is included because it's always added by canonicalize_agent_card. + - signatures should be omitted. + """ + agent_card = AgentCard(**SAMPLE_AGENT_CARD) + expected_jcs = ( + '{"capabilities":{"pushNotifications":true},' + '"defaultInputModes":["text/plain"],"defaultOutputModes":["text/plain"],' + '"description":"A test agent","name":"Test Agent",' + '"skills":[{"description":"A test skill","id":"skill1","name":"Test Skill","tags":["test"]}],' + '"url":"http://localhost","version":"1.0.0"}' + ) + result = canonicalize_agent_card(agent_card) + assert result == expected_jcs diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index da54f833..f68d5c10 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -52,6 +52,7 @@ def sample_task(sample_message: types.Message) -> types.Task: ], ) ], + metadata={'source': 'test'}, ) @@ -107,6 +108,18 @@ def sample_agent_card() -> types.AgentCard: ) ), }, + signatures=[ + types.AgentCardSignature( + protected='protected_test', + signature='signature_test', + header={'alg': 'ES256'}, + ), + types.AgentCardSignature( + protected='protected_val', + signature='signature_val', + header={'alg': 'ES256', 'kid': 'unique-key-identifier-123'}, + ), + ], ) @@ -508,3 +521,169 @@ def test_large_integer_roundtrip_with_utilities(self): assert final_result['nested']['another_large'] == 12345678901234567890 assert isinstance(final_result['nested']['another_large'], int) assert final_result['nested']['normal'] == 'text' + + def test_task_conversion_roundtrip( + self, sample_task: types.Task, sample_message: types.Message + ): + """Test conversion of Task to proto and back.""" + proto_task = proto_utils.ToProto.task(sample_task) + assert isinstance(proto_task, a2a_pb2.Task) + + roundtrip_task = proto_utils.FromProto.task(proto_task) + assert roundtrip_task.id == 'task-1' + assert roundtrip_task.context_id == 'ctx-1' + assert roundtrip_task.status == types.TaskStatus( + state=types.TaskState.working, message=sample_message + ) + assert roundtrip_task.history == sample_task.history + assert roundtrip_task.artifacts == [ + types.Artifact( + artifact_id='art-1', + description='', + metadata={}, + name='', + parts=[ + types.Part(root=types.TextPart(text='Artifact content')) + ], + ) + ] + assert roundtrip_task.metadata == {'source': 'test'} + + def test_agent_card_conversion_roundtrip( + self, sample_agent_card: types.AgentCard + ): + """Test conversion of AgentCard to proto and back.""" + proto_card = proto_utils.ToProto.agent_card(sample_agent_card) + assert isinstance(proto_card, a2a_pb2.AgentCard) + + roundtrip_card = proto_utils.FromProto.agent_card(proto_card) + assert roundtrip_card.name == 'Test Agent' + assert roundtrip_card.description == 'A test agent' + assert roundtrip_card.url == 'http://localhost' + assert roundtrip_card.version == '1.0.0' + assert roundtrip_card.capabilities == types.AgentCapabilities( + extensions=[], streaming=True, push_notifications=True + ) + assert roundtrip_card.default_input_modes == ['text/plain'] + assert roundtrip_card.default_output_modes == ['text/plain'] + assert roundtrip_card.skills == [ + types.AgentSkill( + id='skill1', + name='Test Skill', + description='A test skill', + tags=['test'], + examples=[], + input_modes=[], + output_modes=[], + ) + ] + assert roundtrip_card.provider == types.AgentProvider( + organization='Test Org', url='http://test.org' + ) + assert roundtrip_card.security == [{'oauth_scheme': ['read', 'write']}] + + # Normalized version of security_schemes. None fields are filled with defaults. + expected_security_schemes = { + 'oauth_scheme': types.SecurityScheme( + root=types.OAuth2SecurityScheme( + description='', + flows=types.OAuthFlows( + client_credentials=types.ClientCredentialsOAuthFlow( + refresh_url='', + scopes={ + 'write': 'Write access', + 'read': 'Read access', + }, + token_url='http://token.url', + ), + ), + ) + ), + 'apiKey': types.SecurityScheme( + root=types.APIKeySecurityScheme( + description='', + in_=types.In.header, + name='X-API-KEY', + ) + ), + 'httpAuth': types.SecurityScheme( + root=types.HTTPAuthSecurityScheme( + bearer_format='', + description='', + scheme='bearer', + ) + ), + 'oidc': types.SecurityScheme( + root=types.OpenIdConnectSecurityScheme( + description='', + open_id_connect_url='http://oidc.url', + ) + ), + } + assert roundtrip_card.security_schemes == expected_security_schemes + assert roundtrip_card.signatures == [ + types.AgentCardSignature( + protected='protected_test', + signature='signature_test', + header={'alg': 'ES256'}, + ), + types.AgentCardSignature( + protected='protected_val', + signature='signature_val', + header={'alg': 'ES256', 'kid': 'unique-key-identifier-123'}, + ), + ] + + @pytest.mark.parametrize( + 'signature_data, expected_data', + [ + ( + types.AgentCardSignature( + protected='protected_val', + signature='signature_val', + header={'alg': 'ES256'}, + ), + types.AgentCardSignature( + protected='protected_val', + signature='signature_val', + header={'alg': 'ES256'}, + ), + ), + ( + types.AgentCardSignature( + protected='protected_val', + signature='signature_val', + header=None, + ), + types.AgentCardSignature( + protected='protected_val', + signature='signature_val', + header={}, + ), + ), + ( + types.AgentCardSignature( + protected='', + signature='', + header={}, + ), + types.AgentCardSignature( + protected='', + signature='', + header={}, + ), + ), + ], + ) + def test_agent_card_signature_conversion_roundtrip( + self, signature_data, expected_data + ): + """Test conversion of AgentCardSignature to proto and back.""" + proto_signature = proto_utils.ToProto.agent_card_signature( + signature_data + ) + assert isinstance(proto_signature, a2a_pb2.AgentCardSignature) + roundtrip_signature = proto_utils.FromProto.agent_card_signature( + proto_signature + ) + assert roundtrip_signature == expected_data diff --git a/tests/utils/test_signing.py b/tests/utils/test_signing.py new file mode 100644 index 00000000..9a843d34 --- /dev/null +++ b/tests/utils/test_signing.py @@ -0,0 +1,185 @@ +from a2a.types import ( + AgentCard, + AgentCapabilities, + AgentSkill, +) +from a2a.types import ( + AgentCard, + AgentCapabilities, + AgentSkill, + AgentCardSignature, +) +from a2a.utils import signing +from typing import Any +from jwt.utils import base64url_encode + +import pytest +from cryptography.hazmat.primitives import asymmetric + + +def create_key_provider(verification_key: str | bytes | dict[str, Any]): + """Creates a key provider function for testing.""" + + def key_provider(kid: str | None, jku: str | None): + return verification_key + + return key_provider + + +# Fixture for a complete sample AgentCard +@pytest.fixture +def sample_agent_card() -> AgentCard: + return AgentCard( + name='Test Agent', + description='A test agent', + url='http://localhost', + version='1.0.0', + capabilities=AgentCapabilities( + streaming=None, + push_notifications=True, + ), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + documentation_url=None, + icon_url='', + skills=[ + AgentSkill( + id='skill1', + name='Test Skill', + description='A test skill', + tags=['test'], + ) + ], + ) + + +def test_signer_and_verifier_symmetric(sample_agent_card: AgentCard): + """Test the agent card signing and verification process with symmetric key encryption.""" + key = 'key12345' # Using a simple symmetric key for HS256 + wrong_key = 'wrongkey' + + agent_card_signer = signing.create_agent_card_signer( + signing_key=key, + protected_header={ + 'alg': 'HS384', + 'kid': 'key1', + 'jku': None, + 'typ': 'JOSE', + }, + ) + signed_card = agent_card_signer(sample_agent_card) + + assert signed_card.signatures is not None + assert len(signed_card.signatures) == 1 + signature = signed_card.signatures[0] + assert signature.protected is not None + assert signature.signature is not None + + # Verify the signature + verifier = signing.create_signature_verifier( + create_key_provider(key), ['HS256', 'HS384', 'ES256', 'RS256'] + ) + try: + verifier(signed_card) + except signing.InvalidSignaturesError: + pytest.fail('Signature verification failed with correct key') + + # Verify with wrong key + verifier_wrong_key = signing.create_signature_verifier( + create_key_provider(wrong_key), ['HS256', 'HS384', 'ES256', 'RS256'] + ) + with pytest.raises(signing.InvalidSignaturesError): + verifier_wrong_key(signed_card) + + +def test_signer_and_verifier_symmetric_multiple_signatures( + sample_agent_card: AgentCard, +): + """Test the agent card signing and verification process with symmetric key encryption. + This test adds a signatures to the AgentCard before signing.""" + encoded_header = base64url_encode( + b'{"alg": "HS256", "kid": "old_key"}' + ).decode('utf-8') + sample_agent_card.signatures = [ + AgentCardSignature(protected=encoded_header, signature='old_signature') + ] + key = 'key12345' # Using a simple symmetric key for HS256 + wrong_key = 'wrongkey' + + agent_card_signer = signing.create_agent_card_signer( + signing_key=key, + protected_header={ + 'alg': 'HS384', + 'kid': 'key1', + 'jku': None, + 'typ': 'JOSE', + }, + ) + signed_card = agent_card_signer(sample_agent_card) + + assert signed_card.signatures is not None + assert len(signed_card.signatures) == 2 + signature = signed_card.signatures[1] + assert signature.protected is not None + assert signature.signature is not None + + # Verify the signature + verifier = signing.create_signature_verifier( + create_key_provider(key), ['HS256', 'HS384', 'ES256', 'RS256'] + ) + try: + verifier(signed_card) + except signing.InvalidSignaturesError: + pytest.fail('Signature verification failed with correct key') + + # Verify with wrong key + verifier_wrong_key = signing.create_signature_verifier( + create_key_provider(wrong_key), ['HS256', 'HS384', 'ES256', 'RS256'] + ) + with pytest.raises(signing.InvalidSignaturesError): + verifier_wrong_key(signed_card) + + +def test_signer_and_verifier_asymmetric(sample_agent_card: AgentCard): + """Test the agent card signing and verification process with an asymmetric key encryption.""" + # Generate a dummy EC private key for ES256 + private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1()) + public_key = private_key.public_key() + # Generate another key pair for negative test + private_key_error = asymmetric.ec.generate_private_key( + asymmetric.ec.SECP256R1() + ) + public_key_error = private_key_error.public_key() + + agent_card_signer = signing.create_agent_card_signer( + signing_key=private_key, + protected_header={ + 'alg': 'ES256', + 'kid': 'key2', + 'jku': None, + 'typ': 'JOSE', + }, + ) + signed_card = agent_card_signer(sample_agent_card) + + assert signed_card.signatures is not None + assert len(signed_card.signatures) == 1 + signature = signed_card.signatures[0] + assert signature.protected is not None + assert signature.signature is not None + + verifier = signing.create_signature_verifier( + create_key_provider(public_key), ['HS256', 'HS384', 'ES256', 'RS256'] + ) + try: + verifier(signed_card) + except signing.InvalidSignaturesError: + pytest.fail('Signature verification failed with correct key') + + # Verify with wrong key + verifier_wrong_key = signing.create_signature_verifier( + create_key_provider(public_key_error), + ['HS256', 'HS384', 'ES256', 'RS256'], + ) + with pytest.raises(signing.InvalidSignaturesError): + verifier_wrong_key(signed_card)