Skip to content

Commit 3a3d914

Browse files
committed
Add DPoP integration test script and multiprotocol example updates
- Add scripts/run_phase4_dpop_integration_test.sh for automated DPoP tests - Update test_oauth2_protocol: expect OAuthFlowError when http_client is None - Update simple-auth-multiprotocol-client: OAuth+DPoP support, InMemoryStorage - Update simple-auth-multiprotocol: DPoP logging, oauth2 protocol_version 2.0
1 parent 07ac1f8 commit 3a3d914

5 files changed

Lines changed: 528 additions & 27 deletions

File tree

examples/clients/simple-auth-multiprotocol-client/mcp_simple_auth_multiprotocol_client/main.py

Lines changed: 176 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,129 @@
11
#!/usr/bin/env python3
2-
"""Multi-protocol MCP client: API Key + Mutual TLS (placeholder)."""
2+
"""Multi-protocol MCP client: OAuth (with optional DPoP), API Key, Mutual TLS (placeholder)."""
33

44
import asyncio
55
import os
6+
import threading
7+
import time
8+
import webbrowser
9+
from http.server import BaseHTTPRequestHandler, HTTPServer
610
from typing import Any
11+
from urllib.parse import parse_qs, urlparse
712

813
import httpx
914
from mcp.client.auth.multi_protocol import MultiProtocolAuthProvider, TokenStorage
1015
from mcp.client.auth.protocol import AuthContext, AuthProtocol
16+
from mcp.client.auth.protocols.oauth2 import OAuth2Protocol
1117
from mcp.client.auth.registry import AuthProtocolRegistry
1218
from mcp.client.session import ClientSession
1319
from mcp.client.streamable_http import streamable_http_client
1420
from mcp.shared.auth import (
1521
APIKeyCredentials,
1622
AuthCredentials,
1723
AuthProtocolMetadata,
24+
OAuthClientMetadata,
1825
OAuthToken,
1926
ProtectedResourceMetadata,
2027
)
28+
from pydantic import AnyHttpUrl
2129

2230

2331
class InMemoryStorage(TokenStorage):
24-
"""In-memory credential storage."""
32+
"""In-memory credential storage supporting both AuthCredentials and OAuthToken.
33+
34+
Also implements get_client_info/set_client_info for OAuth client registration storage.
35+
"""
2536

2637
def __init__(self) -> None:
27-
self._creds: AuthCredentials | None = None
38+
self._creds: AuthCredentials | OAuthToken | None = None
39+
self._client_info: Any = None
2840

2941
async def get_tokens(self) -> AuthCredentials | OAuthToken | None:
3042
return self._creds
3143

3244
async def set_tokens(self, tokens: AuthCredentials | OAuthToken) -> None:
33-
self._creds = tokens if isinstance(tokens, AuthCredentials) else None
45+
self._creds = tokens
46+
47+
async def get_client_info(self) -> Any:
48+
"""Get stored OAuth client information."""
49+
return self._client_info
50+
51+
async def set_client_info(self, client_info: Any) -> None:
52+
"""Store OAuth client information."""
53+
self._client_info = client_info
54+
55+
56+
class CallbackHandler(BaseHTTPRequestHandler):
57+
"""HTTP handler to capture OAuth callback."""
58+
59+
def __init__(self, request: Any, client_address: Any, server: Any, callback_data: dict[str, Any]):
60+
self.callback_data = callback_data
61+
super().__init__(request, client_address, server)
62+
63+
def do_GET(self) -> None:
64+
parsed = urlparse(self.path)
65+
query_params = parse_qs(parsed.query)
66+
if "code" in query_params:
67+
self.callback_data["authorization_code"] = query_params["code"][0]
68+
self.callback_data["state"] = query_params.get("state", [None])[0]
69+
self.send_response(200)
70+
self.send_header("Content-type", "text/html")
71+
self.end_headers()
72+
self.wfile.write(b"<h1>Authorization Successful!</h1><p>You can close this window.</p>")
73+
elif "error" in query_params:
74+
self.callback_data["error"] = query_params["error"][0]
75+
self.send_response(400)
76+
self.send_header("Content-type", "text/html")
77+
self.end_headers()
78+
self.wfile.write(f"<h1>Error</h1><p>{query_params['error'][0]}</p>".encode())
79+
else:
80+
self.send_response(404)
81+
self.end_headers()
82+
83+
def log_message(self, format: str, *args: Any) -> None:
84+
pass # Suppress logging
85+
86+
87+
class CallbackServer:
88+
"""Server to handle OAuth callbacks."""
89+
90+
def __init__(self, port: int = 3031):
91+
self.port = port
92+
self.server: HTTPServer | None = None
93+
self.thread: threading.Thread | None = None
94+
self.callback_data: dict[str, Any] = {"authorization_code": None, "state": None, "error": None}
95+
96+
def start(self) -> None:
97+
callback_data = self.callback_data
98+
99+
class DataHandler(CallbackHandler):
100+
def __init__(self, request: Any, client_address: Any, server: Any):
101+
super().__init__(request, client_address, server, callback_data)
102+
103+
self.server = HTTPServer(("localhost", self.port), DataHandler)
104+
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
105+
self.thread.start()
106+
print(f"Callback server started on http://localhost:{self.port}")
107+
108+
def stop(self) -> None:
109+
if self.server:
110+
self.server.shutdown()
111+
self.server.server_close()
112+
if self.thread:
113+
self.thread.join(timeout=1)
114+
115+
def wait_for_callback(self, timeout: int = 300) -> str:
116+
start = time.time()
117+
while time.time() - start < timeout:
118+
if self.callback_data["authorization_code"]:
119+
return self.callback_data["authorization_code"]
120+
if self.callback_data["error"]:
121+
raise RuntimeError(f"OAuth error: {self.callback_data['error']}")
122+
time.sleep(0.1)
123+
raise RuntimeError("Timeout waiting for OAuth callback")
124+
125+
def get_state(self) -> str | None:
126+
return self.callback_data["state"]
34127

35128

36129
class ApiKeyProtocol:
@@ -86,36 +179,87 @@ async def discover_metadata(
86179

87180

88181
def _register_protocols() -> None:
182+
AuthProtocolRegistry.register("oauth2", OAuth2Protocol)
89183
AuthProtocolRegistry.register("api_key", ApiKeyProtocol)
90184
AuthProtocolRegistry.register("mutual_tls", MutualTlsPlaceholderProtocol)
91185

92186

93187
class SimpleAuthMultiprotocolClient:
94-
"""MCP client with multi-protocol auth (API Key + mTLS placeholder)."""
188+
"""MCP client with multi-protocol auth (OAuth + DPoP, API Key, mTLS placeholder)."""
95189

96-
def __init__(self, server_url: str) -> None:
190+
def __init__(self, server_url: str, use_oauth: bool = False, dpop_enabled: bool = False) -> None:
97191
self.server_url = server_url
192+
self.use_oauth = use_oauth
193+
self.dpop_enabled = dpop_enabled
98194
self.session: ClientSession | None = None
99195

100196
async def connect(self) -> None:
101197
_register_protocols()
102-
api_key = os.getenv("MCP_API_KEY", "demo-api-key-12345")
103198
storage = InMemoryStorage()
104-
protocols: list[AuthProtocol] = [
105-
ApiKeyProtocol(api_key=api_key),
106-
MutualTlsPlaceholderProtocol(),
107-
]
108-
auth = MultiProtocolAuthProvider(
109-
server_url=self.server_url.rstrip("/").replace("/mcp", ""),
110-
storage=storage,
111-
protocols=protocols,
112-
)
113-
async with httpx.AsyncClient(auth=auth, follow_redirects=True) as http_client:
114-
async with streamable_http_client(
115-
url=self.server_url,
116-
http_client=http_client,
117-
) as (read_stream, write_stream, get_session_id):
118-
await self._run_session(read_stream, write_stream, get_session_id)
199+
protocols: list[AuthProtocol] = []
200+
201+
callback_server: CallbackServer | None = None
202+
203+
if self.use_oauth:
204+
# Setup OAuth with optional DPoP
205+
callback_server = CallbackServer(port=3031)
206+
callback_server.start()
207+
208+
async def callback_handler() -> tuple[str, str | None]:
209+
print("Waiting for OAuth authorization...")
210+
try:
211+
code = callback_server.wait_for_callback(timeout=300)
212+
return code, callback_server.get_state()
213+
finally:
214+
callback_server.stop()
215+
216+
async def redirect_handler(url: str) -> None:
217+
print(f"Opening browser for authorization: {url}")
218+
webbrowser.open(url)
219+
220+
client_metadata = OAuthClientMetadata(
221+
client_name="Multi-protocol Auth Client",
222+
redirect_uris=[AnyHttpUrl("http://localhost:3031/callback")],
223+
grant_types=["authorization_code", "refresh_token"],
224+
response_types=["code"],
225+
)
226+
227+
oauth_protocol = OAuth2Protocol(
228+
client_metadata=client_metadata,
229+
redirect_handler=redirect_handler,
230+
callback_handler=callback_handler,
231+
dpop_enabled=self.dpop_enabled,
232+
)
233+
protocols.append(oauth_protocol)
234+
print(f"OAuth protocol enabled (DPoP: {self.dpop_enabled})")
235+
236+
# Always add API Key and mTLS as fallback
237+
api_key = os.getenv("MCP_API_KEY", "demo-api-key-12345")
238+
protocols.append(ApiKeyProtocol(api_key=api_key))
239+
protocols.append(MutualTlsPlaceholderProtocol())
240+
241+
try:
242+
# Create http_client first, then pass it to auth provider
243+
# This allows OAuth discovery to work (requires http_client for PRM fetch)
244+
async with httpx.AsyncClient(follow_redirects=True) as http_client:
245+
auth = MultiProtocolAuthProvider(
246+
server_url=self.server_url.rstrip("/").replace("/mcp", ""),
247+
storage=storage,
248+
protocols=protocols,
249+
http_client=http_client,
250+
dpop_enabled=self.dpop_enabled,
251+
)
252+
# Set auth on client after creation
253+
http_client.auth = auth
254+
255+
async with streamable_http_client(
256+
url=self.server_url,
257+
http_client=http_client,
258+
) as (read_stream, write_stream, get_session_id):
259+
await self._run_session(read_stream, write_stream, get_session_id)
260+
finally:
261+
if callback_server:
262+
callback_server.stop()
119263

120264
async def _run_session(self, read_stream: Any, write_stream: Any, get_session_id: Any) -> None:
121265
print("Initializing MCP session...")
@@ -196,8 +340,17 @@ async def _interactive_loop(self) -> None:
196340

197341
async def main() -> None:
198342
server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8002/mcp")
343+
use_oauth = os.getenv("MCP_USE_OAUTH", "").lower() in ("1", "true", "yes")
344+
dpop_enabled = os.getenv("MCP_DPOP_ENABLED", "").lower() in ("1", "true", "yes")
345+
199346
print(f"Connecting to {server_url}...")
200-
client = SimpleAuthMultiprotocolClient(server_url)
347+
print(f" OAuth: {'enabled' if use_oauth else 'disabled'}")
348+
print(f" DPoP: {'enabled' if dpop_enabled else 'disabled'}")
349+
350+
if dpop_enabled and not use_oauth:
351+
print(" Warning: DPoP requires OAuth enabled (MCP_USE_OAUTH=1) to take effect")
352+
353+
client = SimpleAuthMultiprotocolClient(server_url, use_oauth=use_oauth, dpop_enabled=dpop_enabled)
201354
try:
202355
await client.connect()
203356
except Exception as e:

examples/servers/simple-auth-multiprotocol/mcp_simple_auth_multiprotocol/multiprotocol.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Multi-protocol auth: adapter for Starlette and Mutual TLS placeholder verifier."""
22

3+
import logging
34
import time
45
from typing import Any, cast
56

@@ -16,6 +17,8 @@
1617
OAuthTokenVerifier,
1718
)
1819

20+
logger = logging.getLogger(__name__)
21+
1922

2023
class MutualTLSVerifier:
2124
"""
@@ -84,11 +87,37 @@ def __init__(
8487
self._dpop_verifier = dpop_verifier
8588

8689
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, AuthenticatedUser] | None:
87-
result = await self._backend.verify(cast(Request, conn), dpop_verifier=self._dpop_verifier)
90+
request = cast(Request, conn)
91+
92+
# Log DPoP status
93+
dpop_header = request.headers.get("dpop")
94+
if self._dpop_verifier is not None:
95+
if dpop_header:
96+
logger.info("DPoP proof present, verification enabled")
97+
else:
98+
logger.debug("DPoP verification enabled but no DPoP header in request")
99+
elif dpop_header:
100+
logger.debug("DPoP header present but verification not enabled (ignoring)")
101+
102+
result = await self._backend.verify(request, dpop_verifier=self._dpop_verifier)
103+
88104
if result is None:
105+
if dpop_header and self._dpop_verifier is not None:
106+
logger.warning("Authentication failed (DPoP proof may be invalid)")
107+
else:
108+
logger.debug("Authentication failed (no valid credentials)")
89109
return None
110+
90111
if result.expires_at is not None and result.expires_at < int(time.time()):
112+
logger.warning("Token expired for client_id=%s", result.client_id)
91113
return None
114+
115+
# Log successful authentication
116+
if dpop_header and self._dpop_verifier is not None:
117+
logger.info("Authentication successful with DPoP (client_id=%s)", result.client_id)
118+
else:
119+
logger.info("Authentication successful (client_id=%s)", result.client_id)
120+
92121
return (
93122
AuthCredentials(result.scopes or []),
94123
AuthenticatedUser(result),

examples/servers/simple-auth-multiprotocol/mcp_simple_auth_multiprotocol/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _protocol_metadata_list(settings: ResourceServerSettings) -> list[AuthProtoc
6262
return [
6363
AuthProtocolMetadata(
6464
protocol_id="oauth2",
65-
protocol_version="1.0",
65+
protocol_version="2.0",
6666
metadata_url=oauth_metadata_url,
6767
scopes_supported=[settings.mcp_scope],
6868
),

0 commit comments

Comments
 (0)