Skip to content

Commit 8c1b060

Browse files
committed
refactor(auth): reduce async_auth_flow complexity, remove noqa
1 parent 1dda3cd commit 8c1b060

2 files changed

Lines changed: 246 additions & 25 deletions

File tree

src/mcp/client/auth/multi_protocol.py

Lines changed: 202 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import logging
2424
import math
2525
import time
26-
from collections.abc import AsyncGenerator
26+
from collections.abc import AsyncGenerator, Mapping
2727
from typing import Any, Protocol, cast
2828
from urllib.parse import urljoin
2929

@@ -36,6 +36,7 @@
3636
from mcp.client.auth.protocol import AuthContext, AuthProtocol, DPoPEnabledProtocol
3737
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3838
from mcp.client.auth.utils import (
39+
build_authorization_servers_discovery_urls,
3940
build_protected_resource_metadata_discovery_urls,
4041
create_oauth_metadata_request,
4142
extract_auth_protocols_from_www_auth,
@@ -44,6 +45,7 @@
4445
extract_protocol_preferences_from_www_auth,
4546
extract_resource_metadata_from_www_auth,
4647
extract_scope_from_www_auth,
48+
format_json_for_logging,
4749
handle_protected_resource_response,
4850
)
4951
from mcp.shared.auth import (
@@ -60,6 +62,42 @@
6062
UNSPECIFIED_PROTOCOL_PREFERENCE: float = math.inf
6163

6264

65+
class _DiscoveryResult:
66+
"""Mutable holder for 401 discovery results (PRM, protocols, attempted URLs)."""
67+
68+
def __init__(self) -> None:
69+
self.prm: ProtectedResourceMetadata | None = None
70+
self.protocols_metadata: list[AuthProtocolMetadata] = []
71+
self.discovery_attempted_urls: list[str] = []
72+
73+
74+
def _build_protocol_candidates(
75+
available: list[str],
76+
default_protocol: str | None,
77+
protocol_preferences: Mapping[str, float | int] | None,
78+
) -> list[str]:
79+
"""Build ordered protocol candidate list: default first, then by preference, then rest."""
80+
candidates: list[str] = []
81+
seen: set[str] = set()
82+
83+
def push(pid: str | None) -> None:
84+
if not pid or pid in seen:
85+
return
86+
seen.add(pid)
87+
candidates.append(pid)
88+
89+
push(default_protocol)
90+
if protocol_preferences:
91+
for pid in sorted(
92+
available,
93+
key=lambda p: protocol_preferences.get(p, UNSPECIFIED_PROTOCOL_PREFERENCE),
94+
):
95+
push(pid)
96+
for pid in available:
97+
push(pid)
98+
return candidates
99+
100+
63101
class TokenStorage(Protocol):
64102
"""
65103
凭证存储协议(multi_protocol 契约)。
@@ -278,6 +316,66 @@ async def _handle_403_response(
278316
if error or scope:
279317
logger.debug("403 WWW-Authenticate: error=%s scope=%s", error, scope)
280318

319+
<<<<<<< HEAD
320+
=======
321+
async def _run_401_discovery_requests(
322+
self,
323+
resource_metadata_url: str | None,
324+
server_url: str,
325+
result: _DiscoveryResult,
326+
) -> AsyncGenerator[httpx.Request, httpx.Response]:
327+
"""Run PRM + protocol discovery + OAuth fallback; yield discovery requests, store outcome in result."""
328+
prm_urls = build_protected_resource_metadata_discovery_urls(resource_metadata_url, server_url)
329+
for url in prm_urls:
330+
logger.debug("[Auth discovery] Trying PRM endpoint: %s", url)
331+
prm_req = create_oauth_metadata_request(url)
332+
prm_resp = yield prm_req
333+
result.prm = await handle_protected_resource_response(prm_resp)
334+
if result.prm is not None:
335+
break
336+
337+
prm = result.prm
338+
if prm and prm.mcp_auth_protocols:
339+
protocol_ids = [m.protocol_id for m in prm.mcp_auth_protocols]
340+
auth_servers = [str(u) for u in (prm.authorization_servers or [])]
341+
logger.debug(
342+
"[Auth discovery] Using PRM mcp_auth_protocols (priority 1): "
343+
"protocol_ids=%s, authorization_servers=%s",
344+
protocol_ids,
345+
auth_servers,
346+
)
347+
result.protocols_metadata = list(prm.mcp_auth_protocols)
348+
else:
349+
discovery_urls = build_authorization_servers_discovery_urls(server_url)
350+
for discovery_url in discovery_urls:
351+
result.discovery_attempted_urls.append(discovery_url)
352+
logger.debug("[Auth discovery] Trying unified discovery endpoint: %s", discovery_url)
353+
discovery_req = create_oauth_metadata_request(discovery_url)
354+
discovery_resp = yield discovery_req
355+
result.protocols_metadata = await self._parse_protocols_from_discovery_response(
356+
discovery_resp, None
357+
)
358+
if result.protocols_metadata:
359+
logger.debug("Unified discovery succeeded at %s", discovery_url)
360+
break
361+
362+
if not result.protocols_metadata and result.prm and result.prm.authorization_servers:
363+
logger.debug("Unified discovery failed, falling back to OAuth protocol discovery")
364+
oauth_protocol = self._get_protocol("oauth2")
365+
if oauth_protocol and hasattr(oauth_protocol, "discover_metadata"):
366+
try:
367+
oauth_metadata = await oauth_protocol.discover_metadata(
368+
metadata_url=None,
369+
prm=result.prm,
370+
http_client=self._http_client,
371+
)
372+
if oauth_metadata:
373+
result.protocols_metadata = [oauth_metadata]
374+
logger.debug("OAuth protocol discovery succeeded")
375+
except Exception as e:
376+
logger.debug("OAuth protocol discovery failed: %s", e)
377+
378+
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
281379
async def async_auth_flow(
282380
self, request: httpx.Request
283381
) -> AsyncGenerator[httpx.Request, httpx.Response]:
@@ -308,6 +406,7 @@ async def async_auth_flow(
308406
attempted_any = False
309407
last_auth_error: Exception | None = None
310408

409+
<<<<<<< HEAD
311410
# Step 1: PRM discovery (yield)
312411
prm: ProtectedResourceMetadata | None = None
313412
prm_urls = build_protected_resource_metadata_discovery_urls(
@@ -331,27 +430,63 @@ async def async_auth_flow(
331430
discovery_resp, prm
332431
)
333432

433+
=======
434+
discovery_result = _DiscoveryResult()
435+
discovery_gen = self._run_401_discovery_requests(
436+
resource_metadata_url, server_url, discovery_result
437+
)
438+
try:
439+
req = await discovery_gen.__anext__()
440+
except StopAsyncIteration:
441+
pass
442+
else:
443+
while True:
444+
resp = yield req
445+
try:
446+
req = await discovery_gen.asend(resp)
447+
except StopAsyncIteration:
448+
break
449+
450+
prm = discovery_result.prm
451+
protocols_metadata = discovery_result.protocols_metadata
452+
discovery_attempted_urls = discovery_result.discovery_attempted_urls
453+
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
334454
available = (
335455
[m.protocol_id for m in protocols_metadata]
336456
if protocols_metadata
337457
else (auth_protocols_header or [])
338458
)
339459
if not available:
460+
<<<<<<< HEAD
340461
logger.debug("No available protocols from discovery or WWW-Authenticate")
341462
else:
342463
# Select protocol candidates based on server hints, but only
343464
# attempt protocols that are actually injected as instances.
344465
candidates: list[str] = []
345466
seen: set[str] = set()
346-
347-
def _push(pid: str | None) -> None:
348-
if not pid:
349-
return
350-
if pid in seen:
351-
return
352-
seen.add(pid)
353-
candidates.append(pid)
354-
467+
=======
468+
error_msg = (
469+
f"Failed to discover authentication protocols. "
470+
f"Tried URLs: {discovery_attempted_urls}. "
471+
f"PRM available: {prm is not None}, "
472+
f"PRM has authorization_servers: {bool(prm.authorization_servers if prm else False)}, "
473+
f"WWW-Authenticate protocols: {auth_protocols_header}"
474+
)
475+
logger.error(error_msg)
476+
raise RuntimeError(error_msg)
477+
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
478+
479+
candidates = _build_protocol_candidates(
480+
available, default_protocol, protocol_preferences
481+
)
482+
for selected_id in candidates:
483+
protocol = self._get_protocol(selected_id)
484+
if protocol is None:
485+
logger.debug("Protocol %s not injected as instance; skipping", selected_id)
486+
continue
487+
attempted_any = True
488+
489+
<<<<<<< HEAD
355490
# Default protocol first (server recommendation)
356491
_push(default_protocol)
357492
# Then order by preferences if provided
@@ -381,8 +516,58 @@ def _push(pid: str | None) -> None:
381516
for m in protocols_metadata:
382517
if m.protocol_id == selected_id:
383518
protocol_metadata = m
519+
=======
520+
protocol_metadata = None
521+
if protocols_metadata:
522+
for m in protocols_metadata:
523+
if m.protocol_id == selected_id:
524+
protocol_metadata = m
525+
break
526+
527+
try:
528+
if selected_id == "oauth2":
529+
oauth_protocol = protocol
530+
provider = OAuthClientProvider(
531+
server_url=server_url,
532+
client_metadata=getattr(oauth_protocol, "_client_metadata"),
533+
storage=cast(OAuth2TokenStorage, self.storage),
534+
redirect_handler=getattr(oauth_protocol, "_redirect_handler", None),
535+
callback_handler=getattr(oauth_protocol, "_callback_handler", None),
536+
timeout=getattr(oauth_protocol, "_timeout", self.timeout),
537+
client_metadata_url=getattr(oauth_protocol, "_client_metadata_url", None),
538+
fixed_client_info=getattr(oauth_protocol, "_fixed_client_info", None),
539+
)
540+
provider.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
541+
oauth_gen = oauth_401_flow_generator(
542+
provider, original_request, original_401_response, initial_prm=prm
543+
)
544+
auth_req = await oauth_gen.__anext__()
545+
while True:
546+
auth_resp = yield auth_req
547+
try:
548+
auth_req = await oauth_gen.asend(auth_resp)
549+
except StopAsyncIteration:
550+
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
384551
break
552+
else:
553+
context = AuthContext(
554+
server_url=server_url,
555+
storage=self.storage,
556+
protocol_id=selected_id,
557+
protocol_metadata=protocol_metadata,
558+
current_credentials=None,
559+
dpop_storage=self.dpop_storage,
560+
dpop_enabled=self.dpop_enabled,
561+
http_client=self._http_client,
562+
resource_metadata_url=resource_metadata_url,
563+
protected_resource_metadata=prm,
564+
scope_from_www_auth=extract_scope_from_www_auth(original_401_response),
565+
)
566+
credentials = await protocol.authenticate(context)
567+
to_store = _credentials_to_storage(credentials)
568+
await self.storage.set_tokens(to_store)
385569

570+
<<<<<<< HEAD
386571
try:
387572
if selected_id == "oauth2":
388573
# OAuth: drive shared generator (single client, yield)
@@ -439,6 +624,13 @@ def _push(pid: str | None) -> None:
439624
"Protocol %s authentication failed: %s", selected_id, e
440625
)
441626
continue
627+
=======
628+
break
629+
except Exception as e:
630+
last_auth_error = e
631+
logger.debug("Protocol %s authentication failed: %s", selected_id, e)
632+
continue
633+
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
442634

443635
credentials = await self._get_credentials()
444636
if credentials and self._is_credentials_valid(credentials):

src/mcp/client/auth/utils.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,35 @@ def extract_protocol_preferences_from_www_auth(response: Response) -> dict[str,
134134
return preferences if preferences else None
135135

136136

137+
def build_authorization_servers_discovery_urls(resource_url: str) -> list[str]:
138+
"""Build ordered list of unified discovery URLs.
139+
140+
Tries a path-relative discovery URL first (if resource_url contains a path),
141+
then falls back to the host-root discovery URL.
142+
"""
143+
parsed = urlparse(resource_url)
144+
base_url = f"{parsed.scheme}://{parsed.netloc}"
145+
146+
urls: list[str] = []
147+
148+
# Path-relative: https://host/<path>/.well-known/authorization_servers
149+
if parsed.path and parsed.path != "/":
150+
path = parsed.path.rstrip("/")
151+
urls.append(urljoin(base_url, f"{path}/.well-known/authorization_servers"))
152+
153+
# Root: https://host/.well-known/authorization_servers
154+
urls.append(urljoin(base_url, "/.well-known/authorization_servers"))
155+
156+
# De-duplicate while preserving order.
157+
seen: set[str] = set()
158+
unique: list[str] = []
159+
for url in urls:
160+
if url not in seen:
161+
seen.add(url)
162+
unique.append(url)
163+
return unique
164+
165+
137166
async def discover_authorization_servers(
138167
resource_url: str,
139168
http_client: AsyncClient,
@@ -155,21 +184,21 @@ async def discover_authorization_servers(
155184
Returns:
156185
List of protocol metadata; empty if discovery fails and no PRM fallback.
157186
"""
158-
# 1. Unified discovery endpoint (path-relative to resource_url)
159-
discovery_url = urljoin(resource_url.rstrip("/") + "/", ".well-known/authorization_servers")
160-
try:
161-
response = await http_client.get(discovery_url)
162-
if response.status_code == 200:
163-
content = await response.aread()
164-
data = json.loads(content)
165-
raw = data.get("protocols")
166-
protocols_data: list[dict[str, Any]] = cast(list[dict[str, Any]], raw) if isinstance(raw, list) else []
167-
if protocols_data:
168-
return [AuthProtocolMetadata.model_validate(p) for p in protocols_data]
169-
except (ValidationError, ValueError, KeyError, TypeError) as e:
170-
logger.debug("Unified authorization_servers discovery failed: %s", e)
171-
except Exception as e:
172-
logger.debug("Unified authorization_servers request failed: %s", e)
187+
# 1. Unified discovery endpoint (path-relative first, then root)
188+
for discovery_url in build_authorization_servers_discovery_urls(resource_url):
189+
try:
190+
response = await http_client.get(discovery_url)
191+
if response.status_code == 200:
192+
content = await response.aread()
193+
data = json.loads(content)
194+
raw = data.get("protocols")
195+
protocols_data: list[dict[str, Any]] = cast(list[dict[str, Any]], raw) if isinstance(raw, list) else []
196+
if protocols_data:
197+
return [AuthProtocolMetadata.model_validate(p) for p in protocols_data]
198+
except (ValidationError, ValueError, KeyError, TypeError) as e:
199+
logger.debug("Unified authorization_servers discovery failed (%s): %s", discovery_url, e)
200+
except Exception as e:
201+
logger.debug("Unified authorization_servers request failed (%s): %s", discovery_url, e)
173202

174203
# 2. Fallback: use protocol list from PRM
175204
if prm is not None and prm.mcp_auth_protocols:

0 commit comments

Comments
 (0)