Skip to content

Commit e9338fc

Browse files
committed
fix(auth): fallback to injected protocols on 401
Prefer only locally injected protocol instances when selecting the auth protocol, and ensure the final response corresponds to the original request (avoid leaking discovery responses).
1 parent 4589cad commit e9338fc

1 file changed

Lines changed: 72 additions & 15 deletions

File tree

src/mcp/client/auth/multi_protocol.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import json
2222
import logging
23+
import math
2324
import time
2425
from collections.abc import AsyncGenerator
2526
from typing import Any, Protocol, cast
@@ -32,7 +33,6 @@
3233
from mcp.client.auth._oauth_401_flow import oauth_401_flow_generator
3334
from mcp.client.auth.oauth2 import OAuthClientProvider, TokenStorage as OAuth2TokenStorage
3435
from mcp.client.auth.protocol import AuthContext, AuthProtocol, DPoPEnabledProtocol
35-
from mcp.client.auth.registry import AuthProtocolRegistry
3636
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3737
from mcp.client.auth.utils import (
3838
build_protected_resource_metadata_discovery_urls,
@@ -55,6 +55,9 @@
5555

5656
logger = logging.getLogger(__name__)
5757

58+
# Protocol preferences: any protocol without an explicit preference should sort last.
59+
UNSPECIFIED_PROTOCOL_PREFERENCE: float = math.inf
60+
5861

5962
class TokenStorage(Protocol):
6063
"""凭证存储协议(兼容 OAuthToken 与 AuthCredentials)。"""
@@ -248,12 +251,16 @@ async def async_auth_flow(
248251
response = yield request
249252

250253
if response.status_code == 401:
254+
original_request = request
255+
original_401_response = response
251256
async with self._lock:
252257
resource_metadata_url = extract_resource_metadata_from_www_auth(response)
253258
auth_protocols_header = extract_auth_protocols_from_www_auth(response)
254259
default_protocol = extract_default_protocol_from_www_auth(response)
255260
protocol_preferences = extract_protocol_preferences_from_www_auth(response)
256261
server_url = str(request.url)
262+
attempted_any = False
263+
last_auth_error: Exception | None = None
257264

258265
# Step 1: PRM discovery (yield)
259266
prm: ProtectedResourceMetadata | None = None
@@ -286,19 +293,51 @@ async def async_auth_flow(
286293
if not available:
287294
logger.debug("No available protocols from discovery or WWW-Authenticate")
288295
else:
289-
selected_id = AuthProtocolRegistry.select_protocol(
290-
available, default_protocol, protocol_preferences
291-
)
292-
if selected_id:
296+
# Select protocol candidates based on server hints, but only
297+
# attempt protocols that are actually injected as instances.
298+
candidates: list[str] = []
299+
seen: set[str] = set()
300+
301+
def _push(pid: str | None) -> None:
302+
if not pid:
303+
return
304+
if pid in seen:
305+
return
306+
seen.add(pid)
307+
candidates.append(pid)
308+
309+
# Default protocol first (server recommendation)
310+
_push(default_protocol)
311+
# Then order by preferences if provided
312+
if protocol_preferences:
313+
for pid in sorted(
314+
available,
315+
key=lambda p: protocol_preferences.get(
316+
p, UNSPECIFIED_PROTOCOL_PREFERENCE
317+
),
318+
):
319+
_push(pid)
320+
# Then remaining in server-provided order
321+
for pid in available:
322+
_push(pid)
323+
324+
for selected_id in candidates:
293325
protocol = self._get_protocol(selected_id)
294-
if protocol:
295-
protocol_metadata = None
296-
if protocols_metadata:
297-
for m in protocols_metadata:
298-
if m.protocol_id == selected_id:
299-
protocol_metadata = m
300-
break
301-
326+
if protocol is None:
327+
logger.debug(
328+
"Protocol %s not injected as instance; skipping", selected_id
329+
)
330+
continue
331+
attempted_any = True
332+
333+
protocol_metadata = None
334+
if protocols_metadata:
335+
for m in protocols_metadata:
336+
if m.protocol_id == selected_id:
337+
protocol_metadata = m
338+
break
339+
340+
try:
302341
if selected_id == "oauth2":
303342
# OAuth: drive shared generator (single client, yield)
304343
oauth_protocol = protocol
@@ -325,7 +364,7 @@ async def async_auth_flow(
325364
MCP_PROTOCOL_VERSION
326365
)
327366
gen = oauth_401_flow_generator(
328-
provider, request, response, initial_prm=prm
367+
provider, original_request, original_401_response, initial_prm=prm
329368
)
330369
auth_req = await gen.__anext__()
331370
while True:
@@ -348,17 +387,35 @@ async def async_auth_flow(
348387
resource_metadata_url=resource_metadata_url,
349388
protected_resource_metadata=prm,
350389
scope_from_www_auth=extract_scope_from_www_auth(
351-
response
390+
original_401_response
352391
),
353392
)
354393
credentials = await protocol.authenticate(context)
355394
to_store = _credentials_to_storage(credentials)
356395
await self.storage.set_tokens(to_store)
357396

397+
# Stop after first successful protocol path that stores credentials
398+
break
399+
except Exception as e:
400+
last_auth_error = e
401+
logger.debug(
402+
"Protocol %s authentication failed: %s", selected_id, e
403+
)
404+
continue
405+
358406
credentials = await self._get_credentials()
359407
if credentials and self._is_credentials_valid(credentials):
360408
await self._ensure_dpop_initialized(credentials)
361409
self._prepare_request(request, credentials)
362410
response = yield request
411+
else:
412+
if attempted_any and last_auth_error is not None:
413+
# If we did attempt an injected protocol and it failed, surface the error
414+
# instead of returning a potentially confusing 401.
415+
raise last_auth_error
416+
# Ensure we do not leak discovery responses as the final response:
417+
# retry the original request once without new credentials so the
418+
# caller receives a response corresponding to the original request.
419+
response = yield original_request
363420
elif response.status_code == 403:
364421
await self._handle_403_response(response, request)

0 commit comments

Comments
 (0)