2323import logging
2424import math
2525import time
26- from collections .abc import AsyncGenerator , Mapping
26+ from collections .abc import AsyncGenerator
2727from typing import Any , Protocol , cast
2828from urllib .parse import urljoin
2929
3636from mcp .client .auth .protocol import AuthContext , AuthProtocol , DPoPEnabledProtocol
3737from mcp .client .streamable_http import MCP_PROTOCOL_VERSION
3838from mcp .client .auth .utils import (
39- build_authorization_servers_discovery_urls ,
4039 build_protected_resource_metadata_discovery_urls ,
4140 create_oauth_metadata_request ,
4241 extract_auth_protocols_from_www_auth ,
4544 extract_protocol_preferences_from_www_auth ,
4645 extract_resource_metadata_from_www_auth ,
4746 extract_scope_from_www_auth ,
48- format_json_for_logging ,
4947 handle_protected_resource_response ,
5048)
5149from mcp .shared .auth import (
6260UNSPECIFIED_PROTOCOL_PREFERENCE : float = math .inf
6361
6462
65- class _DiscoveryResult :
66- """Mutable holder for 401 discovery results (PRM, protocols, attempted URLs)."""
67-
68- def __init__ (self ) -> None :
69- self .prm : ProtectedResourceMetadata | None = None
70- self .protocols_metadata : list [AuthProtocolMetadata ] = []
71- self .discovery_attempted_urls : list [str ] = []
72-
73-
74- def _build_protocol_candidates (
75- available : list [str ],
76- default_protocol : str | None ,
77- protocol_preferences : Mapping [str , float | int ] | None ,
78- ) -> list [str ]:
79- """Build ordered protocol candidate list: default first, then by preference, then rest."""
80- candidates : list [str ] = []
81- seen : set [str ] = set ()
82-
83- def push (pid : str | None ) -> None :
84- if not pid or pid in seen :
85- return
86- seen .add (pid )
87- candidates .append (pid )
88-
89- push (default_protocol )
90- if protocol_preferences :
91- for pid in sorted (
92- available ,
93- key = lambda p : protocol_preferences .get (p , UNSPECIFIED_PROTOCOL_PREFERENCE ),
94- ):
95- push (pid )
96- for pid in available :
97- push (pid )
98- return candidates
99-
100-
10163class TokenStorage (Protocol ):
10264 """
10365 凭证存储协议(multi_protocol 契约)。
@@ -316,66 +278,6 @@ async def _handle_403_response(
316278 if error or scope :
317279 logger .debug ("403 WWW-Authenticate: error=%s scope=%s" , error , scope )
318280
319- < << << << HEAD
320- == == == =
321- async def _run_401_discovery_requests (
322- self ,
323- resource_metadata_url : str | None ,
324- server_url : str ,
325- result : _DiscoveryResult ,
326- ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
327- """Run PRM + protocol discovery + OAuth fallback; yield discovery requests, store outcome in result."""
328- prm_urls = build_protected_resource_metadata_discovery_urls (resource_metadata_url , server_url )
329- for url in prm_urls :
330- logger .debug ("[Auth discovery] Trying PRM endpoint: %s" , url )
331- prm_req = create_oauth_metadata_request (url )
332- prm_resp = yield prm_req
333- result .prm = await handle_protected_resource_response (prm_resp )
334- if result .prm is not None :
335- break
336-
337- prm = result .prm
338- if prm and prm .mcp_auth_protocols :
339- protocol_ids = [m .protocol_id for m in prm .mcp_auth_protocols ]
340- auth_servers = [str (u ) for u in (prm .authorization_servers or [])]
341- logger .debug (
342- "[Auth discovery] Using PRM mcp_auth_protocols (priority 1): "
343- "protocol_ids=%s, authorization_servers=%s" ,
344- protocol_ids ,
345- auth_servers ,
346- )
347- result .protocols_metadata = list (prm .mcp_auth_protocols )
348- else :
349- discovery_urls = build_authorization_servers_discovery_urls (server_url )
350- for discovery_url in discovery_urls :
351- result .discovery_attempted_urls .append (discovery_url )
352- logger .debug ("[Auth discovery] Trying unified discovery endpoint: %s" , discovery_url )
353- discovery_req = create_oauth_metadata_request (discovery_url )
354- discovery_resp = yield discovery_req
355- result .protocols_metadata = await self ._parse_protocols_from_discovery_response (
356- discovery_resp , None
357- )
358- if result .protocols_metadata :
359- logger .debug ("Unified discovery succeeded at %s" , discovery_url )
360- break
361-
362- if not result .protocols_metadata and result .prm and result .prm .authorization_servers :
363- logger .debug ("Unified discovery failed, falling back to OAuth protocol discovery" )
364- oauth_protocol = self ._get_protocol ("oauth2" )
365- if oauth_protocol and hasattr (oauth_protocol , "discover_metadata" ):
366- try :
367- oauth_metadata = await oauth_protocol .discover_metadata (
368- metadata_url = None ,
369- prm = result .prm ,
370- http_client = self ._http_client ,
371- )
372- if oauth_metadata :
373- result .protocols_metadata = [oauth_metadata ]
374- logger .debug ("OAuth protocol discovery succeeded" )
375- except Exception as e :
376- logger .debug ("OAuth protocol discovery failed: %s" , e )
377-
378- >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
379281 async def async_auth_flow (
380282 self , request : httpx .Request
381283 ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
@@ -406,7 +308,6 @@ async def async_auth_flow(
406308 attempted_any = False
407309 last_auth_error : Exception | None = None
408310
409- < << << << HEAD
410311 # Step 1: PRM discovery (yield)
411312 prm : ProtectedResourceMetadata | None = None
412313 prm_urls = build_protected_resource_metadata_discovery_urls (
@@ -430,63 +331,27 @@ async def async_auth_flow(
430331 discovery_resp , prm
431332 )
432333
433- == == == =
434- discovery_result = _DiscoveryResult ()
435- discovery_gen = self ._run_401_discovery_requests (
436- resource_metadata_url , server_url , discovery_result
437- )
438- try :
439- req = await discovery_gen .__anext__ ()
440- except StopAsyncIteration :
441- pass
442- else :
443- while True :
444- resp = yield req
445- try :
446- req = await discovery_gen .asend (resp )
447- except StopAsyncIteration :
448- break
449-
450- prm = discovery_result .prm
451- protocols_metadata = discovery_result .protocols_metadata
452- discovery_attempted_urls = discovery_result .discovery_attempted_urls
453- >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
454334 available = (
455335 [m .protocol_id for m in protocols_metadata ]
456336 if protocols_metadata
457337 else (auth_protocols_header or [])
458338 )
459339 if not available :
460- < << << << HEAD
461340 logger .debug ("No available protocols from discovery or WWW-Authenticate" )
462341 else :
463342 # Select protocol candidates based on server hints, but only
464343 # attempt protocols that are actually injected as instances.
465344 candidates : list [str ] = []
466345 seen : set [str ] = set ()
467- == == == =
468- error_msg = (
469- f"Failed to discover authentication protocols. "
470- f"Tried URLs: { discovery_attempted_urls } . "
471- f"PRM available: { prm is not None } , "
472- f"PRM has authorization_servers: { bool (prm .authorization_servers if prm else False )} , "
473- f"WWW-Authenticate protocols: { auth_protocols_header } "
474- )
475- logger .error (error_msg )
476- raise RuntimeError (error_msg )
477- >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
478-
479- candidates = _build_protocol_candidates (
480- available , default_protocol , protocol_preferences
481- )
482- for selected_id in candidates :
483- protocol = self ._get_protocol (selected_id )
484- if protocol is None :
485- logger .debug ("Protocol %s not injected as instance; skipping" , selected_id )
486- continue
487- attempted_any = True
488-
489- < << << << HEAD
346+
347+ def _push (pid : str | None ) -> None :
348+ if not pid :
349+ return
350+ if pid in seen :
351+ return
352+ seen .add (pid )
353+ candidates .append (pid )
354+
490355 # Default protocol first (server recommendation)
491356 _push (default_protocol )
492357 # Then order by preferences if provided
@@ -516,58 +381,8 @@ async def async_auth_flow(
516381 for m in protocols_metadata :
517382 if m .protocol_id == selected_id :
518383 protocol_metadata = m
519- == == == =
520- protocol_metadata = None
521- if protocols_metadata :
522- for m in protocols_metadata :
523- if m .protocol_id == selected_id :
524- protocol_metadata = m
525- break
526-
527- try :
528- if selected_id == "oauth2" :
529- oauth_protocol = protocol
530- provider = OAuthClientProvider (
531- server_url = server_url ,
532- client_metadata = getattr (oauth_protocol , "_client_metadata" ),
533- storage = cast (OAuth2TokenStorage , self .storage ),
534- redirect_handler = getattr (oauth_protocol , "_redirect_handler" , None ),
535- callback_handler = getattr (oauth_protocol , "_callback_handler" , None ),
536- timeout = getattr (oauth_protocol , "_timeout" , self .timeout ),
537- client_metadata_url = getattr (oauth_protocol , "_client_metadata_url" , None ),
538- fixed_client_info = getattr (oauth_protocol , "_fixed_client_info" , None ),
539- )
540- provider .context .protocol_version = request .headers .get (MCP_PROTOCOL_VERSION )
541- oauth_gen = oauth_401_flow_generator (
542- provider , original_request , original_401_response , initial_prm = prm
543- )
544- auth_req = await oauth_gen .__anext__ ()
545- while True :
546- auth_resp = yield auth_req
547- try :
548- auth_req = await oauth_gen .asend (auth_resp )
549- except StopAsyncIteration :
550- >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
551384 break
552- else :
553- context = AuthContext (
554- server_url = server_url ,
555- storage = self .storage ,
556- protocol_id = selected_id ,
557- protocol_metadata = protocol_metadata ,
558- current_credentials = None ,
559- dpop_storage = self .dpop_storage ,
560- dpop_enabled = self .dpop_enabled ,
561- http_client = self ._http_client ,
562- resource_metadata_url = resource_metadata_url ,
563- protected_resource_metadata = prm ,
564- scope_from_www_auth = extract_scope_from_www_auth (original_401_response ),
565- )
566- credentials = await protocol .authenticate (context )
567- to_store = _credentials_to_storage (credentials )
568- await self .storage .set_tokens (to_store )
569385
570- < << << << HEAD
571386 try :
572387 if selected_id == "oauth2" :
573388 # OAuth: drive shared generator (single client, yield)
@@ -624,13 +439,6 @@ async def async_auth_flow(
624439 "Protocol %s authentication failed: %s" , selected_id , e
625440 )
626441 continue
627- == == == =
628- break
629- except Exception as e :
630- last_auth_error = e
631- logger .debug ("Protocol %s authentication failed: %s" , selected_id , e )
632- continue
633- >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
634442
635443 credentials = await self ._get_credentials ()
636444 if credentials and self ._is_credentials_valid (credentials ):
0 commit comments