Skip to content

Commit de25956

Browse files
committed
fix(auth): restore multiprotocol provider after refactor
1 parent 8c1b060 commit de25956

1 file changed

Lines changed: 10 additions & 202 deletions

File tree

src/mcp/client/auth/multi_protocol.py

Lines changed: 10 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import logging
2424
import math
2525
import time
26-
from collections.abc import AsyncGenerator, Mapping
26+
from collections.abc import AsyncGenerator
2727
from typing import Any, Protocol, cast
2828
from urllib.parse import urljoin
2929

@@ -36,7 +36,6 @@
3636
from mcp.client.auth.protocol import AuthContext, AuthProtocol, DPoPEnabledProtocol
3737
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3838
from mcp.client.auth.utils import (
39-
build_authorization_servers_discovery_urls,
4039
build_protected_resource_metadata_discovery_urls,
4140
create_oauth_metadata_request,
4241
extract_auth_protocols_from_www_auth,
@@ -45,7 +44,6 @@
4544
extract_protocol_preferences_from_www_auth,
4645
extract_resource_metadata_from_www_auth,
4746
extract_scope_from_www_auth,
48-
format_json_for_logging,
4947
handle_protected_resource_response,
5048
)
5149
from mcp.shared.auth import (
@@ -62,42 +60,6 @@
6260
UNSPECIFIED_PROTOCOL_PREFERENCE: float = math.inf
6361

6462

65-
class _DiscoveryResult:
66-
"""Mutable holder for 401 discovery results (PRM, protocols, attempted URLs)."""
67-
68-
def __init__(self) -> None:
69-
self.prm: ProtectedResourceMetadata | None = None
70-
self.protocols_metadata: list[AuthProtocolMetadata] = []
71-
self.discovery_attempted_urls: list[str] = []
72-
73-
74-
def _build_protocol_candidates(
75-
available: list[str],
76-
default_protocol: str | None,
77-
protocol_preferences: Mapping[str, float | int] | None,
78-
) -> list[str]:
79-
"""Build ordered protocol candidate list: default first, then by preference, then rest."""
80-
candidates: list[str] = []
81-
seen: set[str] = set()
82-
83-
def push(pid: str | None) -> None:
84-
if not pid or pid in seen:
85-
return
86-
seen.add(pid)
87-
candidates.append(pid)
88-
89-
push(default_protocol)
90-
if protocol_preferences:
91-
for pid in sorted(
92-
available,
93-
key=lambda p: protocol_preferences.get(p, UNSPECIFIED_PROTOCOL_PREFERENCE),
94-
):
95-
push(pid)
96-
for pid in available:
97-
push(pid)
98-
return candidates
99-
100-
10163
class TokenStorage(Protocol):
10264
"""
10365
凭证存储协议(multi_protocol 契约)。
@@ -316,66 +278,6 @@ async def _handle_403_response(
316278
if error or scope:
317279
logger.debug("403 WWW-Authenticate: error=%s scope=%s", error, scope)
318280

319-
<<<<<<< HEAD
320-
=======
321-
async def _run_401_discovery_requests(
322-
self,
323-
resource_metadata_url: str | None,
324-
server_url: str,
325-
result: _DiscoveryResult,
326-
) -> AsyncGenerator[httpx.Request, httpx.Response]:
327-
"""Run PRM + protocol discovery + OAuth fallback; yield discovery requests, store outcome in result."""
328-
prm_urls = build_protected_resource_metadata_discovery_urls(resource_metadata_url, server_url)
329-
for url in prm_urls:
330-
logger.debug("[Auth discovery] Trying PRM endpoint: %s", url)
331-
prm_req = create_oauth_metadata_request(url)
332-
prm_resp = yield prm_req
333-
result.prm = await handle_protected_resource_response(prm_resp)
334-
if result.prm is not None:
335-
break
336-
337-
prm = result.prm
338-
if prm and prm.mcp_auth_protocols:
339-
protocol_ids = [m.protocol_id for m in prm.mcp_auth_protocols]
340-
auth_servers = [str(u) for u in (prm.authorization_servers or [])]
341-
logger.debug(
342-
"[Auth discovery] Using PRM mcp_auth_protocols (priority 1): "
343-
"protocol_ids=%s, authorization_servers=%s",
344-
protocol_ids,
345-
auth_servers,
346-
)
347-
result.protocols_metadata = list(prm.mcp_auth_protocols)
348-
else:
349-
discovery_urls = build_authorization_servers_discovery_urls(server_url)
350-
for discovery_url in discovery_urls:
351-
result.discovery_attempted_urls.append(discovery_url)
352-
logger.debug("[Auth discovery] Trying unified discovery endpoint: %s", discovery_url)
353-
discovery_req = create_oauth_metadata_request(discovery_url)
354-
discovery_resp = yield discovery_req
355-
result.protocols_metadata = await self._parse_protocols_from_discovery_response(
356-
discovery_resp, None
357-
)
358-
if result.protocols_metadata:
359-
logger.debug("Unified discovery succeeded at %s", discovery_url)
360-
break
361-
362-
if not result.protocols_metadata and result.prm and result.prm.authorization_servers:
363-
logger.debug("Unified discovery failed, falling back to OAuth protocol discovery")
364-
oauth_protocol = self._get_protocol("oauth2")
365-
if oauth_protocol and hasattr(oauth_protocol, "discover_metadata"):
366-
try:
367-
oauth_metadata = await oauth_protocol.discover_metadata(
368-
metadata_url=None,
369-
prm=result.prm,
370-
http_client=self._http_client,
371-
)
372-
if oauth_metadata:
373-
result.protocols_metadata = [oauth_metadata]
374-
logger.debug("OAuth protocol discovery succeeded")
375-
except Exception as e:
376-
logger.debug("OAuth protocol discovery failed: %s", e)
377-
378-
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
379281
async def async_auth_flow(
380282
self, request: httpx.Request
381283
) -> AsyncGenerator[httpx.Request, httpx.Response]:
@@ -406,7 +308,6 @@ async def async_auth_flow(
406308
attempted_any = False
407309
last_auth_error: Exception | None = None
408310

409-
<<<<<<< HEAD
410311
# Step 1: PRM discovery (yield)
411312
prm: ProtectedResourceMetadata | None = None
412313
prm_urls = build_protected_resource_metadata_discovery_urls(
@@ -430,63 +331,27 @@ async def async_auth_flow(
430331
discovery_resp, prm
431332
)
432333

433-
=======
434-
discovery_result = _DiscoveryResult()
435-
discovery_gen = self._run_401_discovery_requests(
436-
resource_metadata_url, server_url, discovery_result
437-
)
438-
try:
439-
req = await discovery_gen.__anext__()
440-
except StopAsyncIteration:
441-
pass
442-
else:
443-
while True:
444-
resp = yield req
445-
try:
446-
req = await discovery_gen.asend(resp)
447-
except StopAsyncIteration:
448-
break
449-
450-
prm = discovery_result.prm
451-
protocols_metadata = discovery_result.protocols_metadata
452-
discovery_attempted_urls = discovery_result.discovery_attempted_urls
453-
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
454334
available = (
455335
[m.protocol_id for m in protocols_metadata]
456336
if protocols_metadata
457337
else (auth_protocols_header or [])
458338
)
459339
if not available:
460-
<<<<<<< HEAD
461340
logger.debug("No available protocols from discovery or WWW-Authenticate")
462341
else:
463342
# Select protocol candidates based on server hints, but only
464343
# attempt protocols that are actually injected as instances.
465344
candidates: list[str] = []
466345
seen: set[str] = set()
467-
=======
468-
error_msg = (
469-
f"Failed to discover authentication protocols. "
470-
f"Tried URLs: {discovery_attempted_urls}. "
471-
f"PRM available: {prm is not None}, "
472-
f"PRM has authorization_servers: {bool(prm.authorization_servers if prm else False)}, "
473-
f"WWW-Authenticate protocols: {auth_protocols_header}"
474-
)
475-
logger.error(error_msg)
476-
raise RuntimeError(error_msg)
477-
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
478-
479-
candidates = _build_protocol_candidates(
480-
available, default_protocol, protocol_preferences
481-
)
482-
for selected_id in candidates:
483-
protocol = self._get_protocol(selected_id)
484-
if protocol is None:
485-
logger.debug("Protocol %s not injected as instance; skipping", selected_id)
486-
continue
487-
attempted_any = True
488-
489-
<<<<<<< HEAD
346+
347+
def _push(pid: str | None) -> None:
348+
if not pid:
349+
return
350+
if pid in seen:
351+
return
352+
seen.add(pid)
353+
candidates.append(pid)
354+
490355
# Default protocol first (server recommendation)
491356
_push(default_protocol)
492357
# Then order by preferences if provided
@@ -516,58 +381,8 @@ async def async_auth_flow(
516381
for m in protocols_metadata:
517382
if m.protocol_id == selected_id:
518383
protocol_metadata = m
519-
=======
520-
protocol_metadata = None
521-
if protocols_metadata:
522-
for m in protocols_metadata:
523-
if m.protocol_id == selected_id:
524-
protocol_metadata = m
525-
break
526-
527-
try:
528-
if selected_id == "oauth2":
529-
oauth_protocol = protocol
530-
provider = OAuthClientProvider(
531-
server_url=server_url,
532-
client_metadata=getattr(oauth_protocol, "_client_metadata"),
533-
storage=cast(OAuth2TokenStorage, self.storage),
534-
redirect_handler=getattr(oauth_protocol, "_redirect_handler", None),
535-
callback_handler=getattr(oauth_protocol, "_callback_handler", None),
536-
timeout=getattr(oauth_protocol, "_timeout", self.timeout),
537-
client_metadata_url=getattr(oauth_protocol, "_client_metadata_url", None),
538-
fixed_client_info=getattr(oauth_protocol, "_fixed_client_info", None),
539-
)
540-
provider.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
541-
oauth_gen = oauth_401_flow_generator(
542-
provider, original_request, original_401_response, initial_prm=prm
543-
)
544-
auth_req = await oauth_gen.__anext__()
545-
while True:
546-
auth_resp = yield auth_req
547-
try:
548-
auth_req = await oauth_gen.asend(auth_resp)
549-
except StopAsyncIteration:
550-
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
551384
break
552-
else:
553-
context = AuthContext(
554-
server_url=server_url,
555-
storage=self.storage,
556-
protocol_id=selected_id,
557-
protocol_metadata=protocol_metadata,
558-
current_credentials=None,
559-
dpop_storage=self.dpop_storage,
560-
dpop_enabled=self.dpop_enabled,
561-
http_client=self._http_client,
562-
resource_metadata_url=resource_metadata_url,
563-
protected_resource_metadata=prm,
564-
scope_from_www_auth=extract_scope_from_www_auth(original_401_response),
565-
)
566-
credentials = await protocol.authenticate(context)
567-
to_store = _credentials_to_storage(credentials)
568-
await self.storage.set_tokens(to_store)
569385

570-
<<<<<<< HEAD
571386
try:
572387
if selected_id == "oauth2":
573388
# OAuth: drive shared generator (single client, yield)
@@ -624,13 +439,6 @@ async def async_auth_flow(
624439
"Protocol %s authentication failed: %s", selected_id, e
625440
)
626441
continue
627-
=======
628-
break
629-
except Exception as e:
630-
last_auth_error = e
631-
logger.debug("Protocol %s authentication failed: %s", selected_id, e)
632-
continue
633-
>>>>>>> 69d2d1d (refactor(auth): reduce async_auth_flow complexity, remove noqa)
634442

635443
credentials = await self._get_credentials()
636444
if credentials and self._is_credentials_valid(credentials):

0 commit comments

Comments
 (0)