|
1 | 1 | #!/usr/bin/env python3 |
2 | | -"""Multi-protocol MCP client: API Key + Mutual TLS (placeholder).""" |
| 2 | +"""Multi-protocol MCP client: OAuth (with optional DPoP), API Key, Mutual TLS (placeholder).""" |
3 | 3 |
|
4 | 4 | import asyncio |
5 | 5 | import os |
| 6 | +import threading |
| 7 | +import time |
| 8 | +import webbrowser |
| 9 | +from http.server import BaseHTTPRequestHandler, HTTPServer |
6 | 10 | from typing import Any |
| 11 | +from urllib.parse import parse_qs, urlparse |
7 | 12 |
|
8 | 13 | import httpx |
9 | 14 | from mcp.client.auth.multi_protocol import MultiProtocolAuthProvider, TokenStorage |
10 | 15 | from mcp.client.auth.protocol import AuthContext, AuthProtocol |
| 16 | +from mcp.client.auth.protocols.oauth2 import OAuth2Protocol |
11 | 17 | from mcp.client.auth.registry import AuthProtocolRegistry |
12 | 18 | from mcp.client.session import ClientSession |
13 | 19 | from mcp.client.streamable_http import streamable_http_client |
14 | 20 | from mcp.shared.auth import ( |
15 | 21 | APIKeyCredentials, |
16 | 22 | AuthCredentials, |
17 | 23 | AuthProtocolMetadata, |
| 24 | + OAuthClientMetadata, |
18 | 25 | OAuthToken, |
19 | 26 | ProtectedResourceMetadata, |
20 | 27 | ) |
| 28 | +from pydantic import AnyHttpUrl |
21 | 29 |
|
22 | 30 |
|
23 | 31 | class InMemoryStorage(TokenStorage): |
24 | | - """In-memory credential storage.""" |
| 32 | + """In-memory credential storage supporting both AuthCredentials and OAuthToken. |
| 33 | + |
| 34 | + Also implements get_client_info/set_client_info for OAuth client registration storage. |
| 35 | + """ |
25 | 36 |
|
26 | 37 | def __init__(self) -> None: |
27 | | - self._creds: AuthCredentials | None = None |
| 38 | + self._creds: AuthCredentials | OAuthToken | None = None |
| 39 | + self._client_info: Any = None |
28 | 40 |
|
29 | 41 | async def get_tokens(self) -> AuthCredentials | OAuthToken | None: |
30 | 42 | return self._creds |
31 | 43 |
|
32 | 44 | async def set_tokens(self, tokens: AuthCredentials | OAuthToken) -> None: |
33 | | - self._creds = tokens if isinstance(tokens, AuthCredentials) else None |
| 45 | + self._creds = tokens |
| 46 | + |
| 47 | + async def get_client_info(self) -> Any: |
| 48 | + """Get stored OAuth client information.""" |
| 49 | + return self._client_info |
| 50 | + |
| 51 | + async def set_client_info(self, client_info: Any) -> None: |
| 52 | + """Store OAuth client information.""" |
| 53 | + self._client_info = client_info |
| 54 | + |
| 55 | + |
| 56 | +class CallbackHandler(BaseHTTPRequestHandler): |
| 57 | + """HTTP handler to capture OAuth callback.""" |
| 58 | + |
| 59 | + def __init__(self, request: Any, client_address: Any, server: Any, callback_data: dict[str, Any]): |
| 60 | + self.callback_data = callback_data |
| 61 | + super().__init__(request, client_address, server) |
| 62 | + |
| 63 | + def do_GET(self) -> None: |
| 64 | + parsed = urlparse(self.path) |
| 65 | + query_params = parse_qs(parsed.query) |
| 66 | + if "code" in query_params: |
| 67 | + self.callback_data["authorization_code"] = query_params["code"][0] |
| 68 | + self.callback_data["state"] = query_params.get("state", [None])[0] |
| 69 | + self.send_response(200) |
| 70 | + self.send_header("Content-type", "text/html") |
| 71 | + self.end_headers() |
| 72 | + self.wfile.write(b"<h1>Authorization Successful!</h1><p>You can close this window.</p>") |
| 73 | + elif "error" in query_params: |
| 74 | + self.callback_data["error"] = query_params["error"][0] |
| 75 | + self.send_response(400) |
| 76 | + self.send_header("Content-type", "text/html") |
| 77 | + self.end_headers() |
| 78 | + self.wfile.write(f"<h1>Error</h1><p>{query_params['error'][0]}</p>".encode()) |
| 79 | + else: |
| 80 | + self.send_response(404) |
| 81 | + self.end_headers() |
| 82 | + |
| 83 | + def log_message(self, format: str, *args: Any) -> None: |
| 84 | + pass # Suppress logging |
| 85 | + |
| 86 | + |
| 87 | +class CallbackServer: |
| 88 | + """Server to handle OAuth callbacks.""" |
| 89 | + |
| 90 | + def __init__(self, port: int = 3031): |
| 91 | + self.port = port |
| 92 | + self.server: HTTPServer | None = None |
| 93 | + self.thread: threading.Thread | None = None |
| 94 | + self.callback_data: dict[str, Any] = {"authorization_code": None, "state": None, "error": None} |
| 95 | + |
| 96 | + def start(self) -> None: |
| 97 | + callback_data = self.callback_data |
| 98 | + |
| 99 | + class DataHandler(CallbackHandler): |
| 100 | + def __init__(self, request: Any, client_address: Any, server: Any): |
| 101 | + super().__init__(request, client_address, server, callback_data) |
| 102 | + |
| 103 | + self.server = HTTPServer(("localhost", self.port), DataHandler) |
| 104 | + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) |
| 105 | + self.thread.start() |
| 106 | + print(f"Callback server started on http://localhost:{self.port}") |
| 107 | + |
| 108 | + def stop(self) -> None: |
| 109 | + if self.server: |
| 110 | + self.server.shutdown() |
| 111 | + self.server.server_close() |
| 112 | + if self.thread: |
| 113 | + self.thread.join(timeout=1) |
| 114 | + |
| 115 | + def wait_for_callback(self, timeout: int = 300) -> str: |
| 116 | + start = time.time() |
| 117 | + while time.time() - start < timeout: |
| 118 | + if self.callback_data["authorization_code"]: |
| 119 | + return self.callback_data["authorization_code"] |
| 120 | + if self.callback_data["error"]: |
| 121 | + raise RuntimeError(f"OAuth error: {self.callback_data['error']}") |
| 122 | + time.sleep(0.1) |
| 123 | + raise RuntimeError("Timeout waiting for OAuth callback") |
| 124 | + |
| 125 | + def get_state(self) -> str | None: |
| 126 | + return self.callback_data["state"] |
34 | 127 |
|
35 | 128 |
|
36 | 129 | class ApiKeyProtocol: |
@@ -86,36 +179,87 @@ async def discover_metadata( |
86 | 179 |
|
87 | 180 |
|
88 | 181 | def _register_protocols() -> None: |
| 182 | + AuthProtocolRegistry.register("oauth2", OAuth2Protocol) |
89 | 183 | AuthProtocolRegistry.register("api_key", ApiKeyProtocol) |
90 | 184 | AuthProtocolRegistry.register("mutual_tls", MutualTlsPlaceholderProtocol) |
91 | 185 |
|
92 | 186 |
|
93 | 187 | class SimpleAuthMultiprotocolClient: |
94 | | - """MCP client with multi-protocol auth (API Key + mTLS placeholder).""" |
| 188 | + """MCP client with multi-protocol auth (OAuth + DPoP, API Key, mTLS placeholder).""" |
95 | 189 |
|
96 | | - def __init__(self, server_url: str) -> None: |
| 190 | + def __init__(self, server_url: str, use_oauth: bool = False, dpop_enabled: bool = False) -> None: |
97 | 191 | self.server_url = server_url |
| 192 | + self.use_oauth = use_oauth |
| 193 | + self.dpop_enabled = dpop_enabled |
98 | 194 | self.session: ClientSession | None = None |
99 | 195 |
|
100 | 196 | async def connect(self) -> None: |
101 | 197 | _register_protocols() |
102 | | - api_key = os.getenv("MCP_API_KEY", "demo-api-key-12345") |
103 | 198 | storage = InMemoryStorage() |
104 | | - protocols: list[AuthProtocol] = [ |
105 | | - ApiKeyProtocol(api_key=api_key), |
106 | | - MutualTlsPlaceholderProtocol(), |
107 | | - ] |
108 | | - auth = MultiProtocolAuthProvider( |
109 | | - server_url=self.server_url.rstrip("/").replace("/mcp", ""), |
110 | | - storage=storage, |
111 | | - protocols=protocols, |
112 | | - ) |
113 | | - async with httpx.AsyncClient(auth=auth, follow_redirects=True) as http_client: |
114 | | - async with streamable_http_client( |
115 | | - url=self.server_url, |
116 | | - http_client=http_client, |
117 | | - ) as (read_stream, write_stream, get_session_id): |
118 | | - await self._run_session(read_stream, write_stream, get_session_id) |
| 199 | + protocols: list[AuthProtocol] = [] |
| 200 | + |
| 201 | + callback_server: CallbackServer | None = None |
| 202 | + |
| 203 | + if self.use_oauth: |
| 204 | + # Setup OAuth with optional DPoP |
| 205 | + callback_server = CallbackServer(port=3031) |
| 206 | + callback_server.start() |
| 207 | + |
| 208 | + async def callback_handler() -> tuple[str, str | None]: |
| 209 | + print("Waiting for OAuth authorization...") |
| 210 | + try: |
| 211 | + code = callback_server.wait_for_callback(timeout=300) |
| 212 | + return code, callback_server.get_state() |
| 213 | + finally: |
| 214 | + callback_server.stop() |
| 215 | + |
| 216 | + async def redirect_handler(url: str) -> None: |
| 217 | + print(f"Opening browser for authorization: {url}") |
| 218 | + webbrowser.open(url) |
| 219 | + |
| 220 | + client_metadata = OAuthClientMetadata( |
| 221 | + client_name="Multi-protocol Auth Client", |
| 222 | + redirect_uris=[AnyHttpUrl("http://localhost:3031/callback")], |
| 223 | + grant_types=["authorization_code", "refresh_token"], |
| 224 | + response_types=["code"], |
| 225 | + ) |
| 226 | + |
| 227 | + oauth_protocol = OAuth2Protocol( |
| 228 | + client_metadata=client_metadata, |
| 229 | + redirect_handler=redirect_handler, |
| 230 | + callback_handler=callback_handler, |
| 231 | + dpop_enabled=self.dpop_enabled, |
| 232 | + ) |
| 233 | + protocols.append(oauth_protocol) |
| 234 | + print(f"OAuth protocol enabled (DPoP: {self.dpop_enabled})") |
| 235 | + |
| 236 | + # Always add API Key and mTLS as fallback |
| 237 | + api_key = os.getenv("MCP_API_KEY", "demo-api-key-12345") |
| 238 | + protocols.append(ApiKeyProtocol(api_key=api_key)) |
| 239 | + protocols.append(MutualTlsPlaceholderProtocol()) |
| 240 | + |
| 241 | + try: |
| 242 | + # Create http_client first, then pass it to auth provider |
| 243 | + # This allows OAuth discovery to work (requires http_client for PRM fetch) |
| 244 | + async with httpx.AsyncClient(follow_redirects=True) as http_client: |
| 245 | + auth = MultiProtocolAuthProvider( |
| 246 | + server_url=self.server_url.rstrip("/").replace("/mcp", ""), |
| 247 | + storage=storage, |
| 248 | + protocols=protocols, |
| 249 | + http_client=http_client, |
| 250 | + dpop_enabled=self.dpop_enabled, |
| 251 | + ) |
| 252 | + # Set auth on client after creation |
| 253 | + http_client.auth = auth |
| 254 | + |
| 255 | + async with streamable_http_client( |
| 256 | + url=self.server_url, |
| 257 | + http_client=http_client, |
| 258 | + ) as (read_stream, write_stream, get_session_id): |
| 259 | + await self._run_session(read_stream, write_stream, get_session_id) |
| 260 | + finally: |
| 261 | + if callback_server: |
| 262 | + callback_server.stop() |
119 | 263 |
|
120 | 264 | async def _run_session(self, read_stream: Any, write_stream: Any, get_session_id: Any) -> None: |
121 | 265 | print("Initializing MCP session...") |
@@ -196,8 +340,17 @@ async def _interactive_loop(self) -> None: |
196 | 340 |
|
197 | 341 | async def main() -> None: |
198 | 342 | server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8002/mcp") |
| 343 | + use_oauth = os.getenv("MCP_USE_OAUTH", "").lower() in ("1", "true", "yes") |
| 344 | + dpop_enabled = os.getenv("MCP_DPOP_ENABLED", "").lower() in ("1", "true", "yes") |
| 345 | + |
199 | 346 | print(f"Connecting to {server_url}...") |
200 | | - client = SimpleAuthMultiprotocolClient(server_url) |
| 347 | + print(f" OAuth: {'enabled' if use_oauth else 'disabled'}") |
| 348 | + print(f" DPoP: {'enabled' if dpop_enabled else 'disabled'}") |
| 349 | + |
| 350 | + if dpop_enabled and not use_oauth: |
| 351 | + print(" Warning: DPoP requires OAuth enabled (MCP_USE_OAUTH=1) to take effect") |
| 352 | + |
| 353 | + client = SimpleAuthMultiprotocolClient(server_url, use_oauth=use_oauth, dpop_enabled=dpop_enabled) |
201 | 354 | try: |
202 | 355 | await client.connect() |
203 | 356 | except Exception as e: |
|
0 commit comments