2020
2121import json
2222import logging
23+ import math
2324import time
2425from collections .abc import AsyncGenerator
2526from typing import Any , Protocol , cast
3233from mcp .client .auth ._oauth_401_flow import oauth_401_flow_generator
3334from mcp .client .auth .oauth2 import OAuthClientProvider , TokenStorage as OAuth2TokenStorage
3435from mcp .client .auth .protocol import AuthContext , AuthProtocol , DPoPEnabledProtocol
35- from mcp .client .auth .registry import AuthProtocolRegistry
3636from mcp .client .streamable_http import MCP_PROTOCOL_VERSION
3737from mcp .client .auth .utils import (
3838 build_protected_resource_metadata_discovery_urls ,
5555
5656logger = logging .getLogger (__name__ )
5757
58+ # Protocol preferences: any protocol without an explicit preference should sort last.
59+ UNSPECIFIED_PROTOCOL_PREFERENCE : float = math .inf
60+
5861
5962class TokenStorage (Protocol ):
6063 """凭证存储协议(兼容 OAuthToken 与 AuthCredentials)。"""
@@ -248,12 +251,16 @@ async def async_auth_flow(
248251 response = yield request
249252
250253 if response .status_code == 401 :
254+ original_request = request
255+ original_401_response = response
251256 async with self ._lock :
252257 resource_metadata_url = extract_resource_metadata_from_www_auth (response )
253258 auth_protocols_header = extract_auth_protocols_from_www_auth (response )
254259 default_protocol = extract_default_protocol_from_www_auth (response )
255260 protocol_preferences = extract_protocol_preferences_from_www_auth (response )
256261 server_url = str (request .url )
262+ attempted_any = False
263+ last_auth_error : Exception | None = None
257264
258265 # Step 1: PRM discovery (yield)
259266 prm : ProtectedResourceMetadata | None = None
@@ -286,19 +293,51 @@ async def async_auth_flow(
286293 if not available :
287294 logger .debug ("No available protocols from discovery or WWW-Authenticate" )
288295 else :
289- selected_id = AuthProtocolRegistry .select_protocol (
290- available , default_protocol , protocol_preferences
291- )
292- if selected_id :
296+ # Select protocol candidates based on server hints, but only
297+ # attempt protocols that are actually injected as instances.
298+ candidates : list [str ] = []
299+ seen : set [str ] = set ()
300+
301+ def _push (pid : str | None ) -> None :
302+ if not pid :
303+ return
304+ if pid in seen :
305+ return
306+ seen .add (pid )
307+ candidates .append (pid )
308+
309+ # Default protocol first (server recommendation)
310+ _push (default_protocol )
311+ # Then order by preferences if provided
312+ if protocol_preferences :
313+ for pid in sorted (
314+ available ,
315+ key = lambda p : protocol_preferences .get (
316+ p , UNSPECIFIED_PROTOCOL_PREFERENCE
317+ ),
318+ ):
319+ _push (pid )
320+ # Then remaining in server-provided order
321+ for pid in available :
322+ _push (pid )
323+
324+ for selected_id in candidates :
293325 protocol = self ._get_protocol (selected_id )
294- if protocol :
295- protocol_metadata = None
296- if protocols_metadata :
297- for m in protocols_metadata :
298- if m .protocol_id == selected_id :
299- protocol_metadata = m
300- break
301-
326+ if protocol is None :
327+ logger .debug (
328+ "Protocol %s not injected as instance; skipping" , selected_id
329+ )
330+ continue
331+ attempted_any = True
332+
333+ protocol_metadata = None
334+ if protocols_metadata :
335+ for m in protocols_metadata :
336+ if m .protocol_id == selected_id :
337+ protocol_metadata = m
338+ break
339+
340+ try :
302341 if selected_id == "oauth2" :
303342 # OAuth: drive shared generator (single client, yield)
304343 oauth_protocol = protocol
@@ -325,7 +364,7 @@ async def async_auth_flow(
325364 MCP_PROTOCOL_VERSION
326365 )
327366 gen = oauth_401_flow_generator (
328- provider , request , response , initial_prm = prm
367+ provider , original_request , original_401_response , initial_prm = prm
329368 )
330369 auth_req = await gen .__anext__ ()
331370 while True :
@@ -348,17 +387,35 @@ async def async_auth_flow(
348387 resource_metadata_url = resource_metadata_url ,
349388 protected_resource_metadata = prm ,
350389 scope_from_www_auth = extract_scope_from_www_auth (
351- response
390+ original_401_response
352391 ),
353392 )
354393 credentials = await protocol .authenticate (context )
355394 to_store = _credentials_to_storage (credentials )
356395 await self .storage .set_tokens (to_store )
357396
397+ # Stop after first successful protocol path that stores credentials
398+ break
399+ except Exception as e :
400+ last_auth_error = e
401+ logger .debug (
402+ "Protocol %s authentication failed: %s" , selected_id , e
403+ )
404+ continue
405+
358406 credentials = await self ._get_credentials ()
359407 if credentials and self ._is_credentials_valid (credentials ):
360408 await self ._ensure_dpop_initialized (credentials )
361409 self ._prepare_request (request , credentials )
362410 response = yield request
411+ else :
412+ if attempted_any and last_auth_error is not None :
413+ # If we did attempt an injected protocol and it failed, surface the error
414+ # instead of returning a potentially confusing 401.
415+ raise last_auth_error
416+ # Ensure we do not leak discovery responses as the final response:
417+ # retry the original request once without new credentials so the
418+ # caller receives a response corresponding to the original request.
419+ response = yield original_request
363420 elif response .status_code == 403 :
364421 await self ._handle_403_response (response , request )
0 commit comments