Skip to content

Commit 53b11f9

Browse files
committed
feat(auth): OAuth2 client_credentials grant and fixed_client_info
- OAuthClientProvider: add fixed_client_info, _exchange_token_client_credentials - OAuth2Protocol: accept fixed_client_info for M2M flows - MultiProtocolAuthProvider: pass fixed_client_info into OAuthClientProvider - _oauth_401_flow: require client_info for client_credentials, skip dynamic registration - Server token handler: ClientCredentialsRequest, exchange_client_credentials - Server routes: advertise client_credentials in grant_types_supported
1 parent 5129578 commit 53b11f9

5 files changed

Lines changed: 88 additions & 19 deletions

File tree

src/mcp/client/auth/_oauth_401_flow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import httpx
1313

14+
from mcp.client.auth.exceptions import OAuthFlowError
1415
from mcp.client.auth.utils import (
1516
build_oauth_authorization_server_metadata_discovery_urls,
1617
build_protected_resource_metadata_discovery_urls,
@@ -117,6 +118,10 @@ async def oauth_401_flow_generator(
117118
)
118119

119120
# Step 4: Register client or use URL-based client ID (CIMD)
121+
# For client_credentials, a fixed client_id/client_secret must be provided; do not attempt DCR/CIMD.
122+
if "client_credentials" in (ctx.client_metadata.grant_types or []) and not ctx.client_info:
123+
raise OAuthFlowError("Missing client_info for client_credentials flow")
124+
120125
if not ctx.client_info:
121126
if should_use_client_metadata_url(ctx.oauth_metadata, ctx.client_metadata_url):
122127
logger.debug("Using URL-based client ID (CIMD): %s", ctx.client_metadata_url)

src/mcp/client/auth/multi_protocol.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -393,22 +393,13 @@ def _push(pid: str | None) -> None:
393393
oauth_protocol, "_client_metadata"
394394
),
395395
storage=cast(OAuth2TokenStorage, self.storage),
396-
redirect_handler=getattr(
397-
oauth_protocol, "_redirect_handler", None
398-
),
399-
callback_handler=getattr(
400-
oauth_protocol, "_callback_handler", None
401-
),
402-
timeout=getattr(
403-
oauth_protocol, "_timeout", self.timeout
404-
),
405-
client_metadata_url=getattr(
406-
oauth_protocol, "_client_metadata_url", None
407-
),
408-
)
409-
provider.context.protocol_version = request.headers.get(
410-
MCP_PROTOCOL_VERSION
396+
redirect_handler=getattr(oauth_protocol, "_redirect_handler", None),
397+
callback_handler=getattr(oauth_protocol, "_callback_handler", None),
398+
timeout=getattr(oauth_protocol, "_timeout", self.timeout),
399+
client_metadata_url=getattr(oauth_protocol, "_client_metadata_url", None),
400+
fixed_client_info=getattr(oauth_protocol, "_fixed_client_info", None),
411401
)
402+
provider.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
412403
gen = oauth_401_flow_generator(
413404
provider, original_request, original_401_response, initial_prm=prm
414405
)

src/mcp/client/auth/oauth2.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def __init__(
228228
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
229229
timeout: float = 300.0,
230230
client_metadata_url: str | None = None,
231+
fixed_client_info: OAuthClientInformationFull | None = None,
231232
):
232233
"""Initialize OAuth2 authentication.
233234
@@ -262,6 +263,11 @@ def __init__(
262263
timeout=timeout,
263264
client_metadata_url=client_metadata_url,
264265
)
266+
self._fixed_client_info = fixed_client_info
267+
if fixed_client_info is not None:
268+
# In multi-protocol OAuth flow, we may drive oauth_401_flow_generator directly
269+
# without calling _initialize(); ensure client_info is available upfront.
270+
self.context.client_info = fixed_client_info
265271
self._initialized = False
266272

267273
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
@@ -297,6 +303,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
297303

298304
async def _perform_authorization(self) -> httpx.Request:
299305
"""Perform the authorization flow."""
306+
grant_types = set(self.context.client_metadata.grant_types or [])
307+
if "client_credentials" in grant_types:
308+
token_request = await self._exchange_token_client_credentials()
309+
return token_request
300310
auth_code, code_verifier = await self._perform_authorization_code_grant()
301311
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
302312
return token_request
@@ -362,6 +372,31 @@ def _get_token_endpoint(self) -> str:
362372
token_url = urljoin(auth_base_url, "/token")
363373
return token_url
364374

375+
async def _exchange_token_client_credentials(self) -> httpx.Request:
376+
"""Build token exchange request for client_credentials flow."""
377+
if not self.context.client_info:
378+
raise OAuthFlowError("Missing client info for client_credentials flow")
379+
380+
token_url = self._get_token_endpoint()
381+
token_data: dict[str, str] = {
382+
"grant_type": "client_credentials",
383+
}
384+
385+
# Some servers require explicit client_id in the form body (especially for client_secret_post).
386+
if self.context.client_info.client_id:
387+
token_data["client_id"] = self.context.client_info.client_id
388+
389+
# Only include resource param if conditions are met
390+
if self.context.should_include_resource_param(self.context.protocol_version):
391+
token_data["resource"] = self.context.get_resource_url() # RFC 8707
392+
393+
if self.context.client_metadata.scope:
394+
token_data["scope"] = self.context.client_metadata.scope
395+
396+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
397+
token_data, headers = self.context.prepare_token_auth(token_data, headers)
398+
return httpx.Request("POST", token_url, data=token_data, headers=headers)
399+
365400
async def _exchange_token_authorization_code(
366401
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
367402
) -> httpx.Request:
@@ -543,15 +578,14 @@ async def run_authentication(
543578
self.context.client_info = client_information
544579
await self.context.storage.set_client_info(client_information)
545580

546-
auth_code, code_verifier = await self._perform_authorization_code_grant()
547-
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
581+
token_request = await self._perform_authorization()
548582
token_response = await http_client.send(token_request)
549583
await self._handle_token_response(token_response)
550584

551585
async def _initialize(self) -> None: # pragma: no cover
552586
"""Load stored tokens and client info."""
553587
self.context.current_tokens = await self.context.storage.get_tokens()
554-
self.context.client_info = await self.context.storage.get_client_info()
588+
self.context.client_info = self._fixed_client_info or await self.context.storage.get_client_info()
555589
self._initialized = True
556590

557591
def _add_auth_header(self, request: httpx.Request) -> None:

src/mcp/client/auth/protocols/oauth2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from mcp.shared.auth import (
3232
AuthCredentials,
3333
AuthProtocolMetadata,
34+
OAuthClientInformationFull,
3435
OAuthClientMetadata,
3536
OAuthCredentials,
3637
OAuthMetadata,
@@ -104,6 +105,7 @@ def __init__(
104105
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
105106
timeout: float = 300.0,
106107
client_metadata_url: str | None = None,
108+
fixed_client_info: OAuthClientInformationFull | None = None,
107109
dpop_enabled: bool = False,
108110
dpop_algorithm: DPoPAlgorithm = "ES256",
109111
dpop_rsa_key_size: int = RSA_KEY_SIZE_DEFAULT,
@@ -113,6 +115,7 @@ def __init__(
113115
self._callback_handler = callback_handler
114116
self._timeout = timeout
115117
self._client_metadata_url = client_metadata_url
118+
self._fixed_client_info = fixed_client_info
116119
self._dpop_enabled = dpop_enabled
117120
self._dpop_algorithm: DPoPAlgorithm = dpop_algorithm
118121
self._dpop_rsa_key_size = dpop_rsa_key_size
@@ -133,6 +136,7 @@ async def authenticate(self, context: AuthContext) -> AuthCredentials:
133136
callback_handler=self._callback_handler,
134137
timeout=self._timeout,
135138
client_metadata_url=self._client_metadata_url,
139+
fixed_client_info=self._fixed_client_info,
136140
)
137141
protocol_version: str | None = None
138142
if context.protocol_metadata is not None:

src/mcp/server/auth/handlers/token.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,20 @@ class RefreshTokenRequest(BaseModel):
4040
resource: str | None = Field(None, description="Resource indicator for the token")
4141

4242

43-
TokenRequest = Annotated[AuthorizationCodeRequest | RefreshTokenRequest, Field(discriminator="grant_type")]
43+
class ClientCredentialsRequest(BaseModel):
44+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2
45+
grant_type: Literal["client_credentials"]
46+
scope: str | None = Field(None, description="Optional scope parameter")
47+
client_id: str
48+
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
49+
client_secret: str | None = None
50+
# RFC 8707 resource indicator
51+
resource: str | None = Field(None, description="Resource indicator for the token")
52+
53+
TokenRequest = Annotated[
54+
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest,
55+
Field(discriminator="grant_type"),
56+
]
4457
token_request_adapter = TypeAdapter[TokenRequest](TokenRequest)
4558

4659

@@ -216,4 +229,26 @@ async def handle(self, request: Request):
216229
except TokenError as e:
217230
return self.response(TokenErrorResponse(error=e.error, error_description=e.error_description))
218231

232+
case ClientCredentialsRequest():
233+
# Exchange client credentials for access token
234+
scope_str = token_request.scope or getattr(client_info, "scope", None) or ""
235+
scopes = scope_str.split(" ") if scope_str else []
236+
exchange = getattr(self.provider, "exchange_client_credentials", None)
237+
if exchange is None:
238+
return self.response(
239+
TokenErrorResponse(
240+
error="unsupported_grant_type",
241+
error_description="client_credentials is not supported by this authorization server",
242+
)
243+
)
244+
try:
245+
tokens = await exchange(client_info, scopes=scopes, resource=token_request.resource)
246+
except TokenError as e:
247+
return self.response(
248+
TokenErrorResponse(
249+
error=e.error,
250+
error_description=e.error_description,
251+
)
252+
)
253+
219254
return self.response(tokens)

0 commit comments

Comments
 (0)