Skip to content

Commit cb5f638

Browse files
committed
Refactor OAuth 401/403 flow with shared generator to fix deadlock
- Add _oauth_401_flow.py: oauth_401_flow_generator, oauth_403_flow_generator - Refactor OAuthClientProvider.async_auth_flow to drive shared generators - Refactor MultiProtocolAuthProvider 401 handling to use shared OAuth flow - Fix OAuth2Protocol.authenticate to use fresh client when called outside flow - Fix OAuthTokenVerifier: use scope for HTTP method (HTTPConnection compat)
1 parent 3a3d914 commit cb5f638

5 files changed

Lines changed: 342 additions & 202 deletions

File tree

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""
2+
共享的 OAuth 401/403 流程 generator。
3+
4+
供 OAuthClientProvider 与 MultiProtocolAuthProvider 复用,通过 yield 发送请求,
5+
实现单 client、无死锁的 OAuth 发现与认证流程。
6+
"""
7+
8+
import logging
9+
from collections.abc import AsyncGenerator
10+
from typing import TYPE_CHECKING, Any, Protocol
11+
12+
import httpx
13+
14+
from mcp.client.auth.utils import (
15+
build_oauth_authorization_server_metadata_discovery_urls,
16+
build_protected_resource_metadata_discovery_urls,
17+
create_client_info_from_metadata_url,
18+
create_client_registration_request,
19+
create_oauth_metadata_request,
20+
extract_field_from_www_auth,
21+
extract_resource_metadata_from_www_auth,
22+
extract_scope_from_www_auth,
23+
get_client_metadata_scopes,
24+
handle_auth_metadata_response,
25+
handle_protected_resource_response,
26+
handle_registration_response,
27+
should_use_client_metadata_url,
28+
)
29+
30+
if TYPE_CHECKING:
31+
from mcp.shared.auth import ProtectedResourceMetadata
32+
33+
34+
class _OAuth401FlowProvider(Protocol):
35+
"""Provider interface for oauth_401_flow_generator (OAuthClientProvider duck type)."""
36+
37+
@property
38+
def context(self) -> Any:
39+
...
40+
41+
async def _perform_authorization(self) -> httpx.Request:
42+
...
43+
44+
async def _handle_token_response(self, response: httpx.Response) -> None:
45+
...
46+
47+
48+
logger = logging.getLogger(__name__)
49+
50+
51+
async def oauth_401_flow_generator(
52+
provider: _OAuth401FlowProvider,
53+
request: httpx.Request,
54+
response_401: httpx.Response,
55+
*,
56+
initial_prm: "ProtectedResourceMetadata | None" = None,
57+
) -> AsyncGenerator[httpx.Request, httpx.Response]:
58+
"""
59+
OAuth 401 流程:PRM 发现(可跳过)→ AS 发现 → scope → 注册/CIMD → 授权码 → Token 交换。
60+
61+
通过 yield 发出请求,由调用方负责发送并传回响应。供 OAuthClientProvider 与
62+
MultiProtocolAuthProvider 复用,实现单 client、yield 模式的 OAuth 流程。
63+
64+
Args:
65+
provider: OAuthClientProvider 实例,需有 context、_perform_authorization、_handle_token_response
66+
request: 触发 401 的原始请求
67+
response_401: 401 响应
68+
initial_prm: 若提供则跳过 PRM 发现(MultiProtocolAuthProvider 已事先完成)
69+
"""
70+
ctx = provider.context
71+
72+
if initial_prm is not None:
73+
ctx.protected_resource_metadata = initial_prm
74+
if initial_prm.authorization_servers:
75+
ctx.auth_server_url = str(initial_prm.authorization_servers[0])
76+
else:
77+
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
78+
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response_401)
79+
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
80+
www_auth_resource_metadata_url, ctx.server_url
81+
)
82+
83+
for url in prm_discovery_urls:
84+
discovery_request = create_oauth_metadata_request(url)
85+
discovery_response = yield discovery_request
86+
87+
prm = await handle_protected_resource_response(discovery_response)
88+
if prm:
89+
ctx.protected_resource_metadata = prm
90+
assert len(prm.authorization_servers) > 0
91+
ctx.auth_server_url = str(prm.authorization_servers[0])
92+
break
93+
logger.debug("Protected resource metadata discovery failed: %s", url)
94+
95+
# Step 2: Discover OAuth Authorization Server Metadata (OASM)
96+
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
97+
ctx.auth_server_url, ctx.server_url
98+
)
99+
100+
for url in asm_discovery_urls:
101+
oauth_metadata_request = create_oauth_metadata_request(url)
102+
oauth_metadata_response = yield oauth_metadata_request
103+
104+
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
105+
if not ok:
106+
break
107+
if asm:
108+
ctx.oauth_metadata = asm
109+
break
110+
logger.debug("OAuth metadata discovery failed: %s", url)
111+
112+
# Step 3: Apply scope selection strategy
113+
ctx.client_metadata.scope = get_client_metadata_scopes(
114+
extract_scope_from_www_auth(response_401),
115+
ctx.protected_resource_metadata,
116+
ctx.oauth_metadata,
117+
)
118+
119+
# Step 4: Register client or use URL-based client ID (CIMD)
120+
if not ctx.client_info:
121+
if should_use_client_metadata_url(ctx.oauth_metadata, ctx.client_metadata_url):
122+
logger.debug("Using URL-based client ID (CIMD): %s", ctx.client_metadata_url)
123+
client_information = create_client_info_from_metadata_url(
124+
ctx.client_metadata_url, # type: ignore[arg-type]
125+
redirect_uris=ctx.client_metadata.redirect_uris,
126+
)
127+
ctx.client_info = client_information
128+
await ctx.storage.set_client_info(client_information)
129+
else:
130+
registration_request = create_client_registration_request(
131+
ctx.oauth_metadata,
132+
ctx.client_metadata,
133+
ctx.get_authorization_base_url(ctx.server_url),
134+
)
135+
registration_response = yield registration_request
136+
client_information = await handle_registration_response(registration_response)
137+
ctx.client_info = client_information
138+
await ctx.storage.set_client_info(client_information)
139+
140+
# Step 5: Perform authorization and complete token exchange
141+
token_request = await provider._perform_authorization() # type: ignore[reportPrivateUsage]
142+
token_response = yield token_request
143+
await provider._handle_token_response(token_response) # type: ignore[reportPrivateUsage]
144+
145+
146+
async def oauth_403_flow_generator(
147+
provider: _OAuth401FlowProvider,
148+
request: httpx.Request,
149+
response_403: httpx.Response,
150+
) -> AsyncGenerator[httpx.Request, httpx.Response]:
151+
"""
152+
OAuth 403 insufficient_scope 流程:更新 scope → 重新授权 → Token 交换。
153+
"""
154+
ctx = provider.context
155+
error = extract_field_from_www_auth(response_403, "error")
156+
157+
if error == "insufficient_scope":
158+
ctx.client_metadata.scope = get_client_metadata_scopes(
159+
extract_scope_from_www_auth(response_403), ctx.protected_resource_metadata
160+
)
161+
token_request = await provider._perform_authorization() # type: ignore[reportPrivateUsage]
162+
token_response = yield token_request
163+
await provider._handle_token_response(token_response) # type: ignore[reportPrivateUsage]

0 commit comments

Comments
 (0)