|
5 | 5 |
|
6 | 6 | from mcp.client.auth.multi_protocol import ( |
7 | 7 | MultiProtocolAuthProvider, |
| 8 | + OAuthTokenStorageAdapter, |
8 | 9 | TokenStorage, |
9 | 10 | _credentials_to_storage, |
10 | 11 | _oauth_token_to_credentials, |
@@ -55,6 +56,32 @@ async def discover_metadata( |
55 | 56 | return None |
56 | 57 |
|
57 | 58 |
|
| 59 | +class _MockApiKeyProtocol: |
| 60 | + protocol_id = "api_key" |
| 61 | + protocol_version = "1.0" |
| 62 | + |
| 63 | + def __init__(self, api_key: str) -> None: |
| 64 | + self._api_key = api_key |
| 65 | + |
| 66 | + async def authenticate(self, context: AuthContext) -> AuthCredentials: |
| 67 | + return APIKeyCredentials(protocol_id="api_key", api_key=self._api_key) |
| 68 | + |
| 69 | + def prepare_request(self, request: httpx.Request, credentials: AuthCredentials) -> None: |
| 70 | + assert isinstance(credentials, APIKeyCredentials) |
| 71 | + request.headers["X-API-Key"] = credentials.api_key |
| 72 | + |
| 73 | + def validate_credentials(self, credentials: AuthCredentials) -> bool: |
| 74 | + return isinstance(credentials, APIKeyCredentials) and bool(credentials.api_key) |
| 75 | + |
| 76 | + async def discover_metadata( |
| 77 | + self, |
| 78 | + metadata_url: str | None = None, |
| 79 | + prm: ProtectedResourceMetadata | None = None, |
| 80 | + http_client: httpx.AsyncClient | None = None, |
| 81 | + ) -> AuthProtocolMetadata | None: |
| 82 | + return None |
| 83 | + |
| 84 | + |
58 | 85 | @pytest.fixture |
59 | 86 | def mock_storage() -> _MockStorage: |
60 | 87 | return _MockStorage() |
@@ -197,3 +224,196 @@ def test_prepare_request_no_op_when_protocol_missing( |
197 | 224 | creds = AuthCredentials(protocol_id="other") |
198 | 225 | provider._prepare_request(request, creds) |
199 | 226 | assert _MockProtocol._prepare_called is False |
| 227 | + |
| 228 | + |
| 229 | +@pytest.mark.anyio |
| 230 | +async def test_401_flow_falls_back_when_default_protocol_not_injected() -> None: |
| 231 | + """When server suggests default oauth2 but only api_key instance is injected, fallback to api_key and retry.""" |
| 232 | + requests: list[httpx.Request] = [] |
| 233 | + api_key = "demo-api-key-12345" |
| 234 | + |
| 235 | + def handler(request: httpx.Request) -> httpx.Response: |
| 236 | + requests.append(request) |
| 237 | + path = request.url.path |
| 238 | + url = str(request.url) |
| 239 | + |
| 240 | + if request.method == "GET" and "oauth-protected-resource" in path: |
| 241 | + prm = { |
| 242 | + "resource": "https://rs.example/mcp", |
| 243 | + "authorization_servers": ["https://as.example/"], |
| 244 | + "mcp_auth_protocols": [ |
| 245 | + {"protocol_id": "oauth2", "protocol_version": "2.0", "metadata_url": "https://as.example/.well-known/oauth-authorization-server"}, |
| 246 | + {"protocol_id": "api_key", "protocol_version": "1.0"}, |
| 247 | + {"protocol_id": "mutual_tls", "protocol_version": "1.0"}, |
| 248 | + ], |
| 249 | + } |
| 250 | + return httpx.Response(200, json=prm) |
| 251 | + |
| 252 | + if request.method == "GET" and path.endswith("/mcp/.well-known/authorization_servers"): |
| 253 | + return httpx.Response(404, text="not found") |
| 254 | + |
| 255 | + if request.method == "POST" and path == "/mcp": |
| 256 | + if request.headers.get("x-api-key") == api_key: |
| 257 | + return httpx.Response( |
| 258 | + 200, |
| 259 | + json={"jsonrpc": "2.0", "id": 1, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "rs", "version": "1.0"}}}, |
| 260 | + ) |
| 261 | + # 401 with multi-protocol hints |
| 262 | + www = ( |
| 263 | + 'Bearer error="invalid_token", ' |
| 264 | + 'resource_metadata="https://rs.example/.well-known/oauth-protected-resource/mcp", ' |
| 265 | + 'auth_protocols="oauth2 api_key mutual_tls", ' |
| 266 | + 'default_protocol="oauth2"' |
| 267 | + ) |
| 268 | + return httpx.Response(401, headers={"www-authenticate": www}, text="unauthorized") |
| 269 | + |
| 270 | + return httpx.Response(500, text=f"unexpected {request.method} {url}") |
| 271 | + |
| 272 | + transport = httpx.MockTransport(handler) |
| 273 | + storage = _MockStorage() |
| 274 | + proto = _MockApiKeyProtocol(api_key=api_key) |
| 275 | + |
| 276 | + async with httpx.AsyncClient(transport=transport) as client: |
| 277 | + provider = MultiProtocolAuthProvider( |
| 278 | + server_url="https://rs.example", |
| 279 | + storage=storage, |
| 280 | + protocols=[proto], |
| 281 | + http_client=client, |
| 282 | + ) |
| 283 | + client.auth = provider |
| 284 | + r = await client.post("https://rs.example/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "t", "version": "1.0"}}}) |
| 285 | + |
| 286 | + assert r.status_code == 200 |
| 287 | + # Must have retried POST /mcp with X-API-Key |
| 288 | + post_mcp = [req for req in requests if req.method == "POST" and req.url.path == "/mcp"] |
| 289 | + assert len(post_mcp) >= 2 |
| 290 | + assert any(req.headers.get("x-api-key") == api_key for req in post_mcp) |
| 291 | + |
| 292 | + |
| 293 | +@pytest.mark.anyio |
| 294 | +async def test_401_flow_does_not_leak_discovery_response_when_no_protocols_injected() -> None: |
| 295 | + """If no protocol instance is available, final response should correspond to original request (401), not discovery 404.""" |
| 296 | + seen: list[tuple[str, str]] = [] |
| 297 | + |
| 298 | + def handler(request: httpx.Request) -> httpx.Response: |
| 299 | + seen.append((request.method, request.url.path)) |
| 300 | + if request.method == "GET" and "oauth-protected-resource" in request.url.path: |
| 301 | + prm = { |
| 302 | + "resource": "https://rs.example/mcp", |
| 303 | + "authorization_servers": ["https://as.example/"], |
| 304 | + "mcp_auth_protocols": [ |
| 305 | + {"protocol_id": "oauth2", "protocol_version": "2.0", "metadata_url": "https://as.example/.well-known/oauth-authorization-server"}, |
| 306 | + {"protocol_id": "api_key", "protocol_version": "1.0"}, |
| 307 | + ], |
| 308 | + } |
| 309 | + return httpx.Response(200, json=prm) |
| 310 | + if request.method == "GET" and request.url.path.endswith("/mcp/.well-known/authorization_servers"): |
| 311 | + return httpx.Response(404, text="not found") |
| 312 | + if request.method == "POST" and request.url.path == "/mcp": |
| 313 | + www = ( |
| 314 | + 'Bearer error="invalid_token", ' |
| 315 | + 'resource_metadata="https://rs.example/.well-known/oauth-protected-resource/mcp", ' |
| 316 | + 'auth_protocols="oauth2 api_key", ' |
| 317 | + 'default_protocol="oauth2"' |
| 318 | + ) |
| 319 | + return httpx.Response(401, headers={"www-authenticate": www}, text="unauthorized") |
| 320 | + return httpx.Response(500) |
| 321 | + |
| 322 | + transport = httpx.MockTransport(handler) |
| 323 | + storage = _MockStorage() |
| 324 | + |
| 325 | + async with httpx.AsyncClient(transport=transport) as client: |
| 326 | + provider = MultiProtocolAuthProvider( |
| 327 | + server_url="https://rs.example", |
| 328 | + storage=storage, |
| 329 | + protocols=[], |
| 330 | + http_client=client, |
| 331 | + ) |
| 332 | + client.auth = provider |
| 333 | + r = await client.post("https://rs.example/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "t", "version": "1.0"}}}) |
| 334 | + |
| 335 | + assert r.status_code == 401 |
| 336 | + # We should have attempted discovery, but final response must not be the discovery 404. |
| 337 | + assert ("GET", "/mcp/.well-known/authorization_servers") in seen |
| 338 | + |
| 339 | + |
| 340 | +class _OAuthTokenOnlyMockStorage: |
| 341 | + """Minimal storage that only supports OAuthToken (dual contract: oauth2 side).""" |
| 342 | + |
| 343 | + def __init__(self) -> None: |
| 344 | + self._tokens: OAuthToken | None = None |
| 345 | + |
| 346 | + async def get_tokens(self) -> OAuthToken | None: |
| 347 | + return self._tokens |
| 348 | + |
| 349 | + async def set_tokens(self, tokens: OAuthToken) -> None: |
| 350 | + self._tokens = tokens |
| 351 | + |
| 352 | + |
| 353 | +@pytest.mark.anyio |
| 354 | +async def test_oauth_token_storage_adapter_get_tokens_returns_credentials_when_wrapped_has_token() -> None: |
| 355 | + """OAuthTokenStorageAdapter.get_tokens converts OAuthToken to OAuthCredentials.""" |
| 356 | + raw = OAuthToken( |
| 357 | + access_token="at", |
| 358 | + token_type="Bearer", |
| 359 | + expires_in=3600, |
| 360 | + scope="read", |
| 361 | + refresh_token="rt", |
| 362 | + ) |
| 363 | + wrapped = _OAuthTokenOnlyMockStorage() |
| 364 | + wrapped._tokens = raw |
| 365 | + adapter = OAuthTokenStorageAdapter(wrapped) |
| 366 | + |
| 367 | + result = await adapter.get_tokens() |
| 368 | + |
| 369 | + assert result is not None |
| 370 | + assert isinstance(result, OAuthCredentials) |
| 371 | + assert result.protocol_id == "oauth2" |
| 372 | + assert result.access_token == "at" |
| 373 | + assert result.refresh_token == "rt" |
| 374 | + |
| 375 | + |
| 376 | +@pytest.mark.anyio |
| 377 | +async def test_oauth_token_storage_adapter_set_tokens_stores_oauth_token_when_given_credentials() -> None: |
| 378 | + """OAuthTokenStorageAdapter.set_tokens converts OAuthCredentials to OAuthToken and stores.""" |
| 379 | + wrapped = _OAuthTokenOnlyMockStorage() |
| 380 | + adapter = OAuthTokenStorageAdapter(wrapped) |
| 381 | + creds = OAuthCredentials( |
| 382 | + protocol_id="oauth2", |
| 383 | + access_token="at", |
| 384 | + token_type="Bearer", |
| 385 | + refresh_token="rt", |
| 386 | + scope="read", |
| 387 | + expires_at=None, |
| 388 | + ) |
| 389 | + |
| 390 | + await adapter.set_tokens(creds) |
| 391 | + |
| 392 | + assert wrapped._tokens is not None |
| 393 | + assert wrapped._tokens.access_token == "at" |
| 394 | + assert wrapped._tokens.refresh_token == "rt" |
| 395 | + |
| 396 | + |
| 397 | +@pytest.mark.anyio |
| 398 | +async def test_get_credentials_returns_oauth_credentials_when_storage_returns_oauth_token() -> None: |
| 399 | + """MultiProtocolAuthProvider._get_credentials converts OAuthToken from storage to OAuthCredentials (dual contract).""" |
| 400 | + raw = OAuthToken( |
| 401 | + access_token="stored_at", |
| 402 | + token_type="Bearer", |
| 403 | + expires_in=3600, |
| 404 | + scope="read", |
| 405 | + ) |
| 406 | + storage = _MockStorage() |
| 407 | + storage._tokens = raw |
| 408 | + provider = MultiProtocolAuthProvider( |
| 409 | + server_url="https://example.com", |
| 410 | + storage=storage, |
| 411 | + protocols=[], |
| 412 | + ) |
| 413 | + provider._initialize() |
| 414 | + |
| 415 | + result = await provider._get_credentials() |
| 416 | + |
| 417 | + assert result is not None |
| 418 | + assert isinstance(result, OAuthCredentials) |
| 419 | + assert result.access_token == "stored_at" |
0 commit comments