2323import logging
2424import math
2525import time
26- from collections .abc import AsyncGenerator
26+ from collections .abc import AsyncGenerator , Mapping
2727from typing import Any , Protocol , cast
2828from urllib .parse import urljoin
2929
3636from mcp .client .auth .protocol import AuthContext , AuthProtocol , DPoPEnabledProtocol
3737from mcp .client .streamable_http import MCP_PROTOCOL_VERSION
3838from mcp .client .auth .utils import (
39+ build_authorization_servers_discovery_urls ,
3940 build_protected_resource_metadata_discovery_urls ,
4041 create_oauth_metadata_request ,
4142 extract_auth_protocols_from_www_auth ,
4445 extract_protocol_preferences_from_www_auth ,
4546 extract_resource_metadata_from_www_auth ,
4647 extract_scope_from_www_auth ,
48+ format_json_for_logging ,
4749 handle_protected_resource_response ,
4850)
4951from mcp .shared .auth import (
6062UNSPECIFIED_PROTOCOL_PREFERENCE : float = math .inf
6163
6264
65+ class _DiscoveryResult :
66+ """Mutable holder for 401 discovery results (PRM, protocols, attempted URLs)."""
67+
68+ def __init__ (self ) -> None :
69+ self .prm : ProtectedResourceMetadata | None = None
70+ self .protocols_metadata : list [AuthProtocolMetadata ] = []
71+ self .discovery_attempted_urls : list [str ] = []
72+
73+
74+ def _build_protocol_candidates (
75+ available : list [str ],
76+ default_protocol : str | None ,
77+ protocol_preferences : Mapping [str , float | int ] | None ,
78+ ) -> list [str ]:
79+ """Build ordered protocol candidate list: default first, then by preference, then rest."""
80+ candidates : list [str ] = []
81+ seen : set [str ] = set ()
82+
83+ def push (pid : str | None ) -> None :
84+ if not pid or pid in seen :
85+ return
86+ seen .add (pid )
87+ candidates .append (pid )
88+
89+ push (default_protocol )
90+ if protocol_preferences :
91+ for pid in sorted (
92+ available ,
93+ key = lambda p : protocol_preferences .get (p , UNSPECIFIED_PROTOCOL_PREFERENCE ),
94+ ):
95+ push (pid )
96+ for pid in available :
97+ push (pid )
98+ return candidates
99+
100+
63101class TokenStorage (Protocol ):
64102 """
65103 凭证存储协议(multi_protocol 契约)。
@@ -278,6 +316,66 @@ async def _handle_403_response(
278316 if error or scope :
279317 logger .debug ("403 WWW-Authenticate: error=%s scope=%s" , error , scope )
280318
319+ < << << << HEAD
320+ == == == =
321+ async def _run_401_discovery_requests (
322+ self ,
323+ resource_metadata_url : str | None ,
324+ server_url : str ,
325+ result : _DiscoveryResult ,
326+ ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
327+ """Run PRM + protocol discovery + OAuth fallback; yield discovery requests, store outcome in result."""
328+ prm_urls = build_protected_resource_metadata_discovery_urls (resource_metadata_url , server_url )
329+ for url in prm_urls :
330+ logger .debug ("[Auth discovery] Trying PRM endpoint: %s" , url )
331+ prm_req = create_oauth_metadata_request (url )
332+ prm_resp = yield prm_req
333+ result .prm = await handle_protected_resource_response (prm_resp )
334+ if result .prm is not None :
335+ break
336+
337+ prm = result .prm
338+ if prm and prm .mcp_auth_protocols :
339+ protocol_ids = [m .protocol_id for m in prm .mcp_auth_protocols ]
340+ auth_servers = [str (u ) for u in (prm .authorization_servers or [])]
341+ logger .debug (
342+ "[Auth discovery] Using PRM mcp_auth_protocols (priority 1): "
343+ "protocol_ids=%s, authorization_servers=%s" ,
344+ protocol_ids ,
345+ auth_servers ,
346+ )
347+ result .protocols_metadata = list (prm .mcp_auth_protocols )
348+ else :
349+ discovery_urls = build_authorization_servers_discovery_urls (server_url )
350+ for discovery_url in discovery_urls :
351+ result .discovery_attempted_urls .append (discovery_url )
352+ logger .debug ("[Auth discovery] Trying unified discovery endpoint: %s" , discovery_url )
353+ discovery_req = create_oauth_metadata_request (discovery_url )
354+ discovery_resp = yield discovery_req
355+ result .protocols_metadata = await self ._parse_protocols_from_discovery_response (
356+ discovery_resp , None
357+ )
358+ if result .protocols_metadata :
359+ logger .debug ("Unified discovery succeeded at %s" , discovery_url )
360+ break
361+
362+ if not result .protocols_metadata and result .prm and result .prm .authorization_servers :
363+ logger .debug ("Unified discovery failed, falling back to OAuth protocol discovery" )
364+ oauth_protocol = self ._get_protocol ("oauth2" )
365+ if oauth_protocol and hasattr (oauth_protocol , "discover_metadata" ):
366+ try :
367+ oauth_metadata = await oauth_protocol .discover_metadata (
368+ metadata_url = None ,
369+ prm = result .prm ,
370+ http_client = self ._http_client ,
371+ )
372+ if oauth_metadata :
373+ result .protocols_metadata = [oauth_metadata ]
374+ logger .debug ("OAuth protocol discovery succeeded" )
375+ except Exception as e :
376+ logger .debug ("OAuth protocol discovery failed: %s" , e )
377+
378+ >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
281379 async def async_auth_flow (
282380 self , request : httpx .Request
283381 ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
@@ -308,6 +406,7 @@ async def async_auth_flow(
308406 attempted_any = False
309407 last_auth_error : Exception | None = None
310408
409+ < << << << HEAD
311410 # Step 1: PRM discovery (yield)
312411 prm : ProtectedResourceMetadata | None = None
313412 prm_urls = build_protected_resource_metadata_discovery_urls (
@@ -331,27 +430,63 @@ async def async_auth_flow(
331430 discovery_resp , prm
332431 )
333432
433+ == == == =
434+ discovery_result = _DiscoveryResult ()
435+ discovery_gen = self ._run_401_discovery_requests (
436+ resource_metadata_url , server_url , discovery_result
437+ )
438+ try :
439+ req = await discovery_gen .__anext__ ()
440+ except StopAsyncIteration :
441+ pass
442+ else :
443+ while True :
444+ resp = yield req
445+ try :
446+ req = await discovery_gen .asend (resp )
447+ except StopAsyncIteration :
448+ break
449+
450+ prm = discovery_result .prm
451+ protocols_metadata = discovery_result .protocols_metadata
452+ discovery_attempted_urls = discovery_result .discovery_attempted_urls
453+ >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
334454 available = (
335455 [m .protocol_id for m in protocols_metadata ]
336456 if protocols_metadata
337457 else (auth_protocols_header or [])
338458 )
339459 if not available :
460+ < << << << HEAD
340461 logger .debug ("No available protocols from discovery or WWW-Authenticate" )
341462 else :
342463 # Select protocol candidates based on server hints, but only
343464 # attempt protocols that are actually injected as instances.
344465 candidates : list [str ] = []
345466 seen : set [str ] = set ()
346-
347- def _push (pid : str | None ) -> None :
348- if not pid :
349- return
350- if pid in seen :
351- return
352- seen .add (pid )
353- candidates .append (pid )
354-
467+ == == == =
468+ error_msg = (
469+ f"Failed to discover authentication protocols. "
470+ f"Tried URLs: { discovery_attempted_urls } . "
471+ f"PRM available: { prm is not None } , "
472+ f"PRM has authorization_servers: { bool (prm .authorization_servers if prm else False )} , "
473+ f"WWW-Authenticate protocols: { auth_protocols_header } "
474+ )
475+ logger .error (error_msg )
476+ raise RuntimeError (error_msg )
477+ >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
478+
479+ candidates = _build_protocol_candidates (
480+ available , default_protocol , protocol_preferences
481+ )
482+ for selected_id in candidates :
483+ protocol = self ._get_protocol (selected_id )
484+ if protocol is None :
485+ logger .debug ("Protocol %s not injected as instance; skipping" , selected_id )
486+ continue
487+ attempted_any = True
488+
489+ < << << << HEAD
355490 # Default protocol first (server recommendation)
356491 _push (default_protocol )
357492 # Then order by preferences if provided
@@ -381,8 +516,58 @@ def _push(pid: str | None) -> None:
381516 for m in protocols_metadata :
382517 if m .protocol_id == selected_id :
383518 protocol_metadata = m
519+ == == == =
520+ protocol_metadata = None
521+ if protocols_metadata :
522+ for m in protocols_metadata :
523+ if m .protocol_id == selected_id :
524+ protocol_metadata = m
525+ break
526+
527+ try :
528+ if selected_id == "oauth2" :
529+ oauth_protocol = protocol
530+ provider = OAuthClientProvider (
531+ server_url = server_url ,
532+ client_metadata = getattr (oauth_protocol , "_client_metadata" ),
533+ storage = cast (OAuth2TokenStorage , self .storage ),
534+ redirect_handler = getattr (oauth_protocol , "_redirect_handler" , None ),
535+ callback_handler = getattr (oauth_protocol , "_callback_handler" , None ),
536+ timeout = getattr (oauth_protocol , "_timeout" , self .timeout ),
537+ client_metadata_url = getattr (oauth_protocol , "_client_metadata_url" , None ),
538+ fixed_client_info = getattr (oauth_protocol , "_fixed_client_info" , None ),
539+ )
540+ provider .context .protocol_version = request .headers .get (MCP_PROTOCOL_VERSION )
541+ oauth_gen = oauth_401_flow_generator (
542+ provider , original_request , original_401_response , initial_prm = prm
543+ )
544+ auth_req = await oauth_gen .__anext__ ()
545+ while True :
546+ auth_resp = yield auth_req
547+ try :
548+ auth_req = await oauth_gen .asend (auth_resp )
549+ except StopAsyncIteration :
550+ >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
384551 break
552+ else :
553+ context = AuthContext (
554+ server_url = server_url ,
555+ storage = self .storage ,
556+ protocol_id = selected_id ,
557+ protocol_metadata = protocol_metadata ,
558+ current_credentials = None ,
559+ dpop_storage = self .dpop_storage ,
560+ dpop_enabled = self .dpop_enabled ,
561+ http_client = self ._http_client ,
562+ resource_metadata_url = resource_metadata_url ,
563+ protected_resource_metadata = prm ,
564+ scope_from_www_auth = extract_scope_from_www_auth (original_401_response ),
565+ )
566+ credentials = await protocol .authenticate (context )
567+ to_store = _credentials_to_storage (credentials )
568+ await self .storage .set_tokens (to_store )
385569
570+ < << << << HEAD
386571 try :
387572 if selected_id == "oauth2" :
388573 # OAuth: drive shared generator (single client, yield)
@@ -439,6 +624,13 @@ def _push(pid: str | None) -> None:
439624 "Protocol %s authentication failed: %s" , selected_id , e
440625 )
441626 continue
627+ == == == =
628+ break
629+ except Exception as e :
630+ last_auth_error = e
631+ logger .debug ("Protocol %s authentication failed: %s" , selected_id , e )
632+ continue
633+ >> >> >> > 69 d2d1d (refactor (auth ): reduce async_auth_flow complexity , remove noqa )
442634
443635 credentials = await self ._get_credentials ()
444636 if credentials and self ._is_credentials_valid (credentials ):
0 commit comments