Skip to content

Commit 571edfe

Browse files
committed
test(auth): cover protocol fallback and discovery leakage
Add regression tests ensuring 401 handling falls back when the server default protocol isn't injected, and that discovery responses are never returned as the business request response.
1 parent e9338fc commit 571edfe

1 file changed

Lines changed: 220 additions & 0 deletions

File tree

tests/client/test_multi_protocol_provider.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from mcp.client.auth.multi_protocol import (
77
MultiProtocolAuthProvider,
8+
OAuthTokenStorageAdapter,
89
TokenStorage,
910
_credentials_to_storage,
1011
_oauth_token_to_credentials,
@@ -55,6 +56,32 @@ async def discover_metadata(
5556
return None
5657

5758

59+
class _MockApiKeyProtocol:
60+
protocol_id = "api_key"
61+
protocol_version = "1.0"
62+
63+
def __init__(self, api_key: str) -> None:
64+
self._api_key = api_key
65+
66+
async def authenticate(self, context: AuthContext) -> AuthCredentials:
67+
return APIKeyCredentials(protocol_id="api_key", api_key=self._api_key)
68+
69+
def prepare_request(self, request: httpx.Request, credentials: AuthCredentials) -> None:
70+
assert isinstance(credentials, APIKeyCredentials)
71+
request.headers["X-API-Key"] = credentials.api_key
72+
73+
def validate_credentials(self, credentials: AuthCredentials) -> bool:
74+
return isinstance(credentials, APIKeyCredentials) and bool(credentials.api_key)
75+
76+
async def discover_metadata(
77+
self,
78+
metadata_url: str | None = None,
79+
prm: ProtectedResourceMetadata | None = None,
80+
http_client: httpx.AsyncClient | None = None,
81+
) -> AuthProtocolMetadata | None:
82+
return None
83+
84+
5885
@pytest.fixture
5986
def mock_storage() -> _MockStorage:
6087
return _MockStorage()
@@ -197,3 +224,196 @@ def test_prepare_request_no_op_when_protocol_missing(
197224
creds = AuthCredentials(protocol_id="other")
198225
provider._prepare_request(request, creds)
199226
assert _MockProtocol._prepare_called is False
227+
228+
229+
@pytest.mark.anyio
230+
async def test_401_flow_falls_back_when_default_protocol_not_injected() -> None:
231+
"""When server suggests default oauth2 but only api_key instance is injected, fallback to api_key and retry."""
232+
requests: list[httpx.Request] = []
233+
api_key = "demo-api-key-12345"
234+
235+
def handler(request: httpx.Request) -> httpx.Response:
236+
requests.append(request)
237+
path = request.url.path
238+
url = str(request.url)
239+
240+
if request.method == "GET" and "oauth-protected-resource" in path:
241+
prm = {
242+
"resource": "https://rs.example/mcp",
243+
"authorization_servers": ["https://as.example/"],
244+
"mcp_auth_protocols": [
245+
{"protocol_id": "oauth2", "protocol_version": "2.0", "metadata_url": "https://as.example/.well-known/oauth-authorization-server"},
246+
{"protocol_id": "api_key", "protocol_version": "1.0"},
247+
{"protocol_id": "mutual_tls", "protocol_version": "1.0"},
248+
],
249+
}
250+
return httpx.Response(200, json=prm)
251+
252+
if request.method == "GET" and path.endswith("/mcp/.well-known/authorization_servers"):
253+
return httpx.Response(404, text="not found")
254+
255+
if request.method == "POST" and path == "/mcp":
256+
if request.headers.get("x-api-key") == api_key:
257+
return httpx.Response(
258+
200,
259+
json={"jsonrpc": "2.0", "id": 1, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "rs", "version": "1.0"}}},
260+
)
261+
# 401 with multi-protocol hints
262+
www = (
263+
'Bearer error="invalid_token", '
264+
'resource_metadata="https://rs.example/.well-known/oauth-protected-resource/mcp", '
265+
'auth_protocols="oauth2 api_key mutual_tls", '
266+
'default_protocol="oauth2"'
267+
)
268+
return httpx.Response(401, headers={"www-authenticate": www}, text="unauthorized")
269+
270+
return httpx.Response(500, text=f"unexpected {request.method} {url}")
271+
272+
transport = httpx.MockTransport(handler)
273+
storage = _MockStorage()
274+
proto = _MockApiKeyProtocol(api_key=api_key)
275+
276+
async with httpx.AsyncClient(transport=transport) as client:
277+
provider = MultiProtocolAuthProvider(
278+
server_url="https://rs.example",
279+
storage=storage,
280+
protocols=[proto],
281+
http_client=client,
282+
)
283+
client.auth = provider
284+
r = await client.post("https://rs.example/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "t", "version": "1.0"}}})
285+
286+
assert r.status_code == 200
287+
# Must have retried POST /mcp with X-API-Key
288+
post_mcp = [req for req in requests if req.method == "POST" and req.url.path == "/mcp"]
289+
assert len(post_mcp) >= 2
290+
assert any(req.headers.get("x-api-key") == api_key for req in post_mcp)
291+
292+
293+
@pytest.mark.anyio
294+
async def test_401_flow_does_not_leak_discovery_response_when_no_protocols_injected() -> None:
295+
"""If no protocol instance is available, final response should correspond to original request (401), not discovery 404."""
296+
seen: list[tuple[str, str]] = []
297+
298+
def handler(request: httpx.Request) -> httpx.Response:
299+
seen.append((request.method, request.url.path))
300+
if request.method == "GET" and "oauth-protected-resource" in request.url.path:
301+
prm = {
302+
"resource": "https://rs.example/mcp",
303+
"authorization_servers": ["https://as.example/"],
304+
"mcp_auth_protocols": [
305+
{"protocol_id": "oauth2", "protocol_version": "2.0", "metadata_url": "https://as.example/.well-known/oauth-authorization-server"},
306+
{"protocol_id": "api_key", "protocol_version": "1.0"},
307+
],
308+
}
309+
return httpx.Response(200, json=prm)
310+
if request.method == "GET" and request.url.path.endswith("/mcp/.well-known/authorization_servers"):
311+
return httpx.Response(404, text="not found")
312+
if request.method == "POST" and request.url.path == "/mcp":
313+
www = (
314+
'Bearer error="invalid_token", '
315+
'resource_metadata="https://rs.example/.well-known/oauth-protected-resource/mcp", '
316+
'auth_protocols="oauth2 api_key", '
317+
'default_protocol="oauth2"'
318+
)
319+
return httpx.Response(401, headers={"www-authenticate": www}, text="unauthorized")
320+
return httpx.Response(500)
321+
322+
transport = httpx.MockTransport(handler)
323+
storage = _MockStorage()
324+
325+
async with httpx.AsyncClient(transport=transport) as client:
326+
provider = MultiProtocolAuthProvider(
327+
server_url="https://rs.example",
328+
storage=storage,
329+
protocols=[],
330+
http_client=client,
331+
)
332+
client.auth = provider
333+
r = await client.post("https://rs.example/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "t", "version": "1.0"}}})
334+
335+
assert r.status_code == 401
336+
# We should have attempted discovery, but final response must not be the discovery 404.
337+
assert ("GET", "/mcp/.well-known/authorization_servers") in seen
338+
339+
340+
class _OAuthTokenOnlyMockStorage:
341+
"""Minimal storage that only supports OAuthToken (dual contract: oauth2 side)."""
342+
343+
def __init__(self) -> None:
344+
self._tokens: OAuthToken | None = None
345+
346+
async def get_tokens(self) -> OAuthToken | None:
347+
return self._tokens
348+
349+
async def set_tokens(self, tokens: OAuthToken) -> None:
350+
self._tokens = tokens
351+
352+
353+
@pytest.mark.anyio
354+
async def test_oauth_token_storage_adapter_get_tokens_returns_credentials_when_wrapped_has_token() -> None:
355+
"""OAuthTokenStorageAdapter.get_tokens converts OAuthToken to OAuthCredentials."""
356+
raw = OAuthToken(
357+
access_token="at",
358+
token_type="Bearer",
359+
expires_in=3600,
360+
scope="read",
361+
refresh_token="rt",
362+
)
363+
wrapped = _OAuthTokenOnlyMockStorage()
364+
wrapped._tokens = raw
365+
adapter = OAuthTokenStorageAdapter(wrapped)
366+
367+
result = await adapter.get_tokens()
368+
369+
assert result is not None
370+
assert isinstance(result, OAuthCredentials)
371+
assert result.protocol_id == "oauth2"
372+
assert result.access_token == "at"
373+
assert result.refresh_token == "rt"
374+
375+
376+
@pytest.mark.anyio
377+
async def test_oauth_token_storage_adapter_set_tokens_stores_oauth_token_when_given_credentials() -> None:
378+
"""OAuthTokenStorageAdapter.set_tokens converts OAuthCredentials to OAuthToken and stores."""
379+
wrapped = _OAuthTokenOnlyMockStorage()
380+
adapter = OAuthTokenStorageAdapter(wrapped)
381+
creds = OAuthCredentials(
382+
protocol_id="oauth2",
383+
access_token="at",
384+
token_type="Bearer",
385+
refresh_token="rt",
386+
scope="read",
387+
expires_at=None,
388+
)
389+
390+
await adapter.set_tokens(creds)
391+
392+
assert wrapped._tokens is not None
393+
assert wrapped._tokens.access_token == "at"
394+
assert wrapped._tokens.refresh_token == "rt"
395+
396+
397+
@pytest.mark.anyio
398+
async def test_get_credentials_returns_oauth_credentials_when_storage_returns_oauth_token() -> None:
399+
"""MultiProtocolAuthProvider._get_credentials converts OAuthToken from storage to OAuthCredentials (dual contract)."""
400+
raw = OAuthToken(
401+
access_token="stored_at",
402+
token_type="Bearer",
403+
expires_in=3600,
404+
scope="read",
405+
)
406+
storage = _MockStorage()
407+
storage._tokens = raw
408+
provider = MultiProtocolAuthProvider(
409+
server_url="https://example.com",
410+
storage=storage,
411+
protocols=[],
412+
)
413+
provider._initialize()
414+
415+
result = await provider._get_credentials()
416+
417+
assert result is not None
418+
assert isinstance(result, OAuthCredentials)
419+
assert result.access_token == "stored_at"

0 commit comments

Comments
 (0)