diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 4aa91793..ca0d7323 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,11 +38,11 @@ jobs: - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Initialize CodeQL - uses: github/codeql-action/init@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/init@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/analyze@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index bb8b659d..c8035cb0 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -73,6 +73,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/upload-sarif@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: sarif_file: results.sarif diff --git a/auth/auth.go b/auth/auth.go index 40fa259f..f2058700 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -47,6 +47,21 @@ type RequireBearerTokenOptions struct { ResourceMetadataURL string // The required scopes. Scopes []string + // AllowMissingExpiration opts the middleware out of the + // `tokenInfo.Expiration.IsZero()` reject. Default false preserves the + // existing strict behaviour (every TokenInfo must carry an Expiration). + // + // Some IdPs emit session-bound bearer tokens that do not carry a standalone + // `exp` claim — the token's lifetime is bounded by an external session and + // is not advertised in-band. Resource servers integrating with such IdPs + // need to opt in to validating the rest of the token (scopes, signature + // via the verifier callback, etc.) without requiring the expiration field + // to be present. + // + // When enabled, the verifier is still responsible for any session-level + // validity check it can perform; this option only relaxes the middleware's + // own expiration enforcement. + AllowMissingExpiration bool } type tokenInfoKey struct{} @@ -131,9 +146,10 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO // Check expiration. if tokenInfo.Expiration.IsZero() { - return nil, "token missing expiration", http.StatusUnauthorized - } - if tokenInfo.Expiration.Before(time.Now()) { + if opts == nil || !opts.AllowMissingExpiration { + return nil, "token missing expiration", http.StatusUnauthorized + } + } else if tokenInfo.Expiration.Before(time.Now()) { return nil, "token expired", http.StatusUnauthorized } return tokenInfo, "", 0 diff --git a/auth/auth_test.go b/auth/auth_test.go index 4028c907..fe523a14 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -62,6 +62,11 @@ func TestVerify(t *testing.T) { "no expiration", nil, "Bearer noexp", "token missing expiration", 401, }, + { + "no expiration with AllowMissingExpiration accepts", + &RequireBearerTokenOptions{AllowMissingExpiration: true}, "Bearer noexp", + "", 0, + }, { "expired", nil, "Bearer expired", "token expired", 401, diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 47541f99..a3daeecb 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -50,6 +50,12 @@ type AuthorizationResult struct { Code string // State string returned by the authorization server. State string + // Iss is the issuer identifier returned by the authorization server in the + // authorization response per [RFC 9207]. The AuthorizationCodeFetcher should + // populate this from the "iss" query parameter in the redirect URI if present. + // + // [RFC 9207]: https://www.rfc-editor.org/rfc/rfc9207 + Iss string } // AuthorizationArgs is the input to [AuthorizationCodeFetcher]. @@ -318,6 +324,9 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ // Purposefully leaving the error unwrappable so it can be handled by the caller. return err } + if err := validateIssuerResponse(authRes.Iss, asm.Issuer, asm.AuthorizationResponseIssParameterSupported); err != nil { + return err + } err = h.exchangeAuthorizationCode(ctx, cfg, authRes, prm.Resource) if err != nil { @@ -498,6 +507,9 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * // 2. Attempt to use pre-registered client configuration. preCfg := h.config.PreregisteredClient if preCfg != nil { + if preCfg.Issuer != "" && !authutil.IssuersEqual(preCfg.Issuer, asm.Issuer) { + return nil, fmt.Errorf("authorization server issuer %q does not match pre-registered credentials issuer %q", asm.Issuer, preCfg.Issuer) + } authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported) clientSecret := "" if preCfg.ClientSecretAuth != nil { @@ -560,6 +572,27 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg }, nil } +// validateIssuerResponse validates the "iss" parameter in an authorization response +// per [RFC 9207]. +// +// [RFC 9207]: https://www.rfc-editor.org/rfc/rfc9207 +func validateIssuerResponse(iss, expectedIssuer string, issParameterSupported bool) error { + if issParameterSupported { + if iss == "" { + return fmt.Errorf("authorization server advertises RFC 9207 iss parameter support but none was received in the authorization response") + } + if iss != expectedIssuer { + return fmt.Errorf("authorization response issuer %q does not match expected issuer %q", iss, expectedIssuer) + } + } else { + if iss != "" { + return fmt.Errorf("authorization server does not advertise RFC 9207 iss parameter support but iss was received in the authorization response") + } + } + + return nil +} + // exchangeAuthorizationCode exchanges the authorization code for a token // and stores it in a token source. func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index 12520054..c84b0032 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -78,6 +78,7 @@ func TestAuthorize(t *testing.T) { return &AuthorizationResult{ Code: location.Query().Get("code"), State: location.Query().Get("state"), + Iss: location.Query().Get("iss"), }, nil }, }) @@ -176,6 +177,7 @@ func TestAuthorize_ScopeAccumulation(t *testing.T) { return &AuthorizationResult{ Code: loc.Query().Get("code"), State: loc.Query().Get("state"), + Iss: loc.Query().Get("iss"), }, nil }, }) @@ -607,6 +609,8 @@ func TestHandleRegistration(t *testing.T) { asm *oauthex.AuthServerMeta want *resolvedClientConfig wantError bool + issuerMatch bool + issuerSuffix string }{ { name: "ClientIDMetadataDocument", @@ -645,6 +649,79 @@ func TestHandleRegistration(t *testing.T) { authStyle: oauth2.AuthStyleInParams, }, }, + { + name: "Preregistered_IssuerMatch", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "pre_client_secret", + }, + Issuer: "", // set dynamically in the test + }, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: "pre_client_id", + clientSecret: "pre_client_secret", + authStyle: oauth2.AuthStyleInParams, + }, + issuerMatch: true, + }, + { + name: "Preregistered_IssuerMismatch", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "pre_client_secret", + }, + Issuer: "https://other-issuer.example.com", + }, + }, + wantError: true, + }, + { + name: "Preregistered_IssuerMatchTrailingSlash", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "pre_client_secret", + }, + Issuer: "", // set dynamically in the test (with trailing slash) + }, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: "pre_client_id", + clientSecret: "pre_client_secret", + authStyle: oauth2.AuthStyleInParams, + }, + issuerMatch: true, + issuerSuffix: "/", + }, { name: "NoneSupported", handlerConfig: &AuthorizationCodeHandlerConfig{ @@ -658,6 +735,10 @@ func TestHandleRegistration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{RegistrationConfig: tt.serverConfig}) s.Start(t) + // Set the Issuer dynamically if requested by the test case. + if tt.issuerMatch { + tt.handlerConfig.PreregisteredClient.Issuer = s.URL() + tt.issuerSuffix + } tt.handlerConfig.AuthorizationCodeFetcher = func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { return nil, nil } @@ -677,6 +758,9 @@ func TestHandleRegistration(t *testing.T) { } return } + if tt.wantError { + t.Fatal("handleRegistration() expected error, got nil") + } if got.registrationType != tt.want.registrationType { t.Errorf("handleRegistration() registrationType = %v, want %v", got.registrationType, tt.want.registrationType) } @@ -736,6 +820,59 @@ func TestDynamicRegistration(t *testing.T) { } } +func TestValidateIssuerResponse(t *testing.T) { + const expectedIssuer = "https://auth.example.com" + + tests := []struct { + name string + iss string + issSupported bool + wantErr bool + wantErrContains string + }{ + { + name: "ValidIss", + iss: expectedIssuer, + issSupported: true, + }, + { + name: "WrongIss", + iss: "https://attacker.example.com", + issSupported: true, + wantErr: true, + wantErrContains: "does not match expected issuer", + }, + { + name: "MissingIssWhenRequired", + iss: "", + issSupported: true, + wantErr: true, + wantErrContains: "RFC 9207", + }, + { + name: "MissingIssWhenNotRequired", + iss: "", + issSupported: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateIssuerResponse(tt.iss, expectedIssuer, tt.issSupported) + if tt.wantErr { + if err == nil { + t.Fatalf("validateIssuerResponse() = nil, want error containing %q", tt.wantErrContains) + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("validateIssuerResponse() error = %q, want it to contain %q", err.Error(), tt.wantErrContains) + } + } else if err != nil { + t.Fatalf("validateIssuerResponse() unexpected error = %v", err) + } + }) + } +} + func TestInferApplicationType(t *testing.T) { tests := []struct { name string diff --git a/auth/extauth/client_credentials.go b/auth/extauth/client_credentials.go index b95fcaee..c65d8b06 100644 --- a/auth/extauth/client_credentials.go +++ b/auth/extauth/client_credentials.go @@ -128,6 +128,11 @@ func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Requ } } + creds := h.config.Credentials + if creds.Issuer != "" && !authutil.IssuersEqual(creds.Issuer, asm.Issuer) { + return fmt.Errorf("authorization server issuer %q does not match pre-registered credentials issuer %q", asm.Issuer, creds.Issuer) + } + // Determine requestedScopes: use PRM's scopes_supported if available. requestedScopes := scopesFromChallenges(wwwChallenges) if len(requestedScopes) == 0 && len(prm.ScopesSupported) > 0 { @@ -140,7 +145,6 @@ func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Requ requestedScopes = authutil.UnionScopes(h.grantedScopes[asm.Issuer], requestedScopes) // Step 3: Exchange client credentials for an access token. - creds := h.config.Credentials cfg := &clientcredentials.Config{ ClientID: creds.ClientID, ClientSecret: creds.ClientSecretAuth.ClientSecret, diff --git a/auth/extauth/client_credentials_test.go b/auth/extauth/client_credentials_test.go index 550e56af..b2973de7 100644 --- a/auth/extauth/client_credentials_test.go +++ b/auth/extauth/client_credentials_test.go @@ -221,6 +221,51 @@ func TestClientCredentialsHandler_Authorize(t *testing.T) { } }) + t.Run("issuer mismatch", func(t *testing.T) { + config := validClientCredentialsConfig() + config.Credentials.Issuer = "https://other-issuer.example.com" + handler, err := NewClientCredentialsHandler(config) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{}, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + err = handler.Authorize(t.Context(), req, resp) + if err == nil { + t.Fatal("expected Authorize to fail with issuer mismatch") + } + if !strings.Contains(err.Error(), "does not match") { + t.Errorf("error %q does not mention issuer mismatch", err.Error()) + } + }) + + t.Run("issuer match ignoring trailing slash", func(t *testing.T) { + config := validClientCredentialsConfig() + // authServer.URL() has no trailing slash; configure with one to + // verify the comparison tolerates the difference (per RFC 8414 §3.3 + // normalization applied in oauthex.IssuersEqual). + config.Credentials.Issuer = authServer.URL() + "/" + handler, err := NewClientCredentialsHandler(config) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{}, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err != nil { + t.Fatalf("Authorize() unexpected error = %v", err) + } + }) + t.Run("PRM via resource_metadata in challenge", func(t *testing.T) { prmMux := http.NewServeMux() prmMux.Handle("/custom-prm", auth.ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index 947b095e..55a55ade 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -66,6 +66,11 @@ func init() { "auth/token-endpoint-auth-basic", "auth/token-endpoint-auth-post", "auth/token-endpoint-auth-none", + "auth/iss-supported", + "auth/iss-not-advertised", + "auth/iss-supported-missing", + "auth/iss-wrong-issuer", + "auth/iss-unexpected", } for _, scenario := range authScenarios { registerScenario(scenario, runAuthClient) @@ -232,6 +237,7 @@ func fetchAuthorizationCodeAndState(ctx context.Context, args *auth.Authorizatio return &auth.AuthorizationResult{ Code: locURL.Query().Get("code"), State: locURL.Query().Get("state"), + Iss: locURL.Query().Get("iss"), }, nil } diff --git a/design/mrtr.md b/design/mrtr.md new file mode 100644 index 00000000..16321a46 --- /dev/null +++ b/design/mrtr.md @@ -0,0 +1,264 @@ +## Context + +A proposal for implementing Multi Round-Trip Requests +(MRTR) as defined in [SEP-2322](https://github.com/CaitieM20/modelcontextprotocol/blob/de6d76fba3078eda957dadb3cec51ca8ab851b5c/seps/2322-MRTR.md). + +In the new protocol version servers can't initiate requests to clients, but when a server requires additional input for completing `tools/call`, `prompts/get`, or `resources/read` it can return an incomplete result along with a set of `inputRequests`. The client fulfills them locally and retries the same call with `inputResponses` attached. + +## Goals + +**Must have:** + +* Backward compatibility. +* Correct representation on the wire. + +**Nice to have:** + +* Minimal changes to the exported API surface. +* Hard for server implementers to construct an invalid payload. +* Simple input request handling for clients. +* Protocol-version-independent code. +* Consistency with the rest of the SDK. + +## Proposal + +`ServerSession` methods return an error for new-version protocol connections. + +`InputRequest`/`InputResponse` is introduced as a sealed-interface: +```go +// Implemented by *ElicitParams, *CreateMessageParams, *ListRootsParams +type InputRequest interface{ isInputRequest() } + +type InputRequestMap map[string]InputRequest +// MarshalJSON encodes as map[string]struct{ Method string; Params InputRequest } +func (m InputRequestMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Params InputRequest } +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { ... } + +// Implemented by *ElicitResult, *CreateMessageResult, *ListRootsResult. +type InputResponse interface{ isInputResult() } + +type InputResponseMap map[string]InputResponse +// MarshalJSON encodes as map[string]struct{ Method string; Result InputResponse } +func (m InputResponseMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Result InputResponse } +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { ... } +``` + +All affected methods' `*Params` are extended with `InputResponseMap` and `RequestState` fields: +```go +type CallToolParams struct { + ... + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + RequestState string `json:"requestState,omitempty"` +} +// Same for GetPromptParams, ReadResourceParams +``` + +`InputRequests` and `RequestState` fields are added directly to `CallToolResult`, `GetPromptResult`, and `ReadResourceResult` as exported. +Result type discriminator (completed, input_required) is unexported so that SDK users don't need to set it to the correct constant in addition to setting either `Content` or `InputRequests`. Handler execution result is validated and augmented before marshaling: +```go +type CallToolResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` + resultType string // set by the SDK and used in MarshalJSON() +} +// Same for GetPromptResult, ReadResourceResult. +``` +Alternatively, the field could only exist on `wire struct`, but this would make us return `complete` to older clients or empty string to newer clients, because there's no access to negotiated protocol version in `MarshalJSON`. + +Servers request additional input by constructing a correct struct literal: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return &mcp.CallToolResult{ + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` +The SDK validates at runtime that a handler does not return both content and `InputRequests` — doing so logs a warning and returns a `CodeInternalError` JSON-RPC error. + +An unexported receiving middleware is installed on the server for backward compatibility with older clients. When a handler returns `InputRequests` and the connected client uses a protocol version that does not support MRTR, the middleware fulfills the requests by calling `ServerSession.Elicit`/`CreateMessage`/`ListRoots` on the client directly and reinvokes the handler once with the collected `InputResponses`. If any of these calls fail, the entire request fails. Input requests are fulfilled concurrently. This lets server developers write protocol-version-independent code. + +An unexported sending middleware is installed on the client, which similarly to `urlElicitationMiddleware` will automatically invoke handlers for the corresponding methods on incomplete results and retry the original request. Clients have an option to disable it and write a retry loop manually using `NeedsInput()`: +```go +type MultiRoundTripOptions struct { + Disabled bool +} + +client := mcp.NewClient(impl, &mcp.ClientOptions{MultiRoundTrip: &mcp.MultiRoundTripOptions{Disabled: true}}) +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if result.NeedsInput() { ... } +``` + +`NeedsInput()` checks the unexported `resultType` field rather than `InputRequests`, correctly handling the load-shedding case where the server returns `input_required` with an empty map. + +**Pros** + +This is arguably the simplest and the most transparent approach which is also closest to the spec. +What gets explicitly set on the server can be observed on the wire and on the client. +The opt-out client middleware follows the principle of the least surprise for app developers. If client method handlers were provided they will continue to be invoked regardless of the protocol version in use. The `Disabled` option lets "power-users" build any custom handling logic. +The server middleware makes handler code protocol-version-independent — the same handler works for both old and new clients. + +**Cons** + +The biggest downside of the proposal is that server developers can construct incorrect responses (both content and input requests) and this will only be validated at runtime. + +## Alternatives considered + +### Unexported fields + +MRTR fields can be unexported, accessible only through getters, constructible only through constructor functions, and handled explicitly in custom `(Unm|M)arshalJSON`. This will make it impossible for developers to construct incorrect responses and for clients to perform an erroneous `len(result.InputRequests) > 0` check in the load-shedding case. +```go +type CallToolResult struct { + ... + inputRequests InputRequestMap + requestState string + resultType string +} + +func (r *CallToolResult) InputRequests() (InputRequestMap, bool) { ... } + +// InputRequiredResult struct exists for backward-compatibility in case of new fields being needed for input request results. +type InputRequiredResult struct { + InputRequests InputRequestMap + RequestState string +} + +// RequireInput constructs a tool call, prompt or resource result with input requests set. +// mrtrResult provides methods for setting private fields on these types. +func RequireInput[T any, TP interface { *T; mrtrResult }](r InputRequiredResult) TP { ... } +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return mcp.RequireInput[mcp.CallToolResult](mcp.InputRequiredResult{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }), nil, nil + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if requests, ok := result.InputRequests(); ok { ... } +``` + +The biggest downside of this approach is the obscure data model with hidden fields. An incomplete `mcp.CallToolResult` looks like an uninitialized struct until `InputRequests` method result is examined. +In addition to this, the verbose `RequireInput` syntax (no auto type inference from assignment target) does not look idiomatic and fits poorly into the existing SDK APIs. + +--- + +### `InputRequiredError` type + +We could explore a different data channel - `error` return value. This would give us the natural "happy path is when all inputs are provided" flow on the server side, and good result interpretability on the client side (impossible to confuse with a successful response). +The new error could be converted to the correct wire representation at the marshaling stage. +```go +type InputRequiredError struct { + InputRequests InputRequestMap + RequestState string +} + +func (e *InputRequiredError) Error() string { + return fmt.Sprintf("input required: %d request(s)", len(e.InputRequests)) +} +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return nil, zero, &mcp.InputRequiredError{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + } + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +var inputReqErr *mcp.InputRequiredError +if errors.As(err, &inputReqErr) { ... } +``` + +The downsides of this approach are: +* The drift from the protocol, where MRTR is not an error flow. +* Obscure "customError -> non-error protocol type on wirte -> customError" data lifecycle. +* Things get confusing for error-processing middleware. + +--- + +### New functions + +We could introduce new functions with a different handler signature where the return type is a sealed interface. This would give us compiler-enforced correctness for values constructed by tool handlers and clients would be forced to unpack `mcp.RoundTripCallToolResult` and make a concious decision for how to handle it. +```go +type RoundTripToolHandler func(context.Context, *CallToolRequest) (RoundTripCallToolResult, error) +type RoundTripToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (RoundTripCallToolResult, Out, error) + +// RoundTripCallToolResult is implemented by CallToolResult and IncompleteResult +type RoundTripCallToolResult interface { isMRTRResult() } + +type IncompleteResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` +} + +func (s *Server) AddRoundTripTool(t *Tool, h RoundTripToolHandler) +func AddRoundTripTool[In, Out any](s *Server, t *Tool, h RoundTripToolHandlerFor[In, Out]) +``` + +`Server.AddTool` wraps the old `ToolHandler` into a `RoundTripToolHandler` to update its function signature: +```go +mcp.AddRoundTripTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (mcp.RoundTripCallToolResult, MyOut, error) { + if needsInput(in) { + return &mcp.IncompleteResult{ + ResultType: mcp.ResultTypeInputRequired, + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +The downsides of this approach are: +* SEP suggests `ResultType` will potentially be extended with new values, `RoundTrip` in new function names will not allow us to cleanly extend the sealed interface with new types. But an overly generic name for new functions will make the API use-case less clear. +* Different code +* SDK takes the same action (puts it on the wire) regardless of the returned type, it exists only for enforcing correctness of the user code. +* Exported API surface bloat: +7 exported functions. + +--- + +### Exported Middleware + +We could flip "unexported MRTR middleware with opt-out option" to "exported middleware with opt-in requirement". +```go +func AutoMRTR(opts *MultiRoundTripOptions) Middleware { ... } +type MultiRoundTripOptions struct { + MaxRetries int +} +client := mcp.NewClient(impl, nil) +client.AddSendingMiddleware(mcp.AutoMRTR(&mcp.MultiRoundTripOptions{ + MaxRetries: 5, +})) +``` +This would change semantics of `*Handler` fields - depending on the protocol version in use, an extra initialization step will be required for them to "take effect". + +--- + +### Server API protocol version bridging + +Converting `ServerSession.Elicit`/`CreateMessage`/`ListRoots` calls into MRTR wire format transparently (suspend the handler, return `input_required`, resume on retry). Rejected because of a significant implementation effort and the fact that it contradicts the design goal of MRTR where servers shouldn't hold resources between round trips, and it should be possible for a retry to arrive on any server instance in a multi-server deployment. + diff --git a/docs/mcpgodebug.md b/docs/mcpgodebug.md index 36eddb52..f1e62373 100644 --- a/docs/mcpgodebug.md +++ b/docs/mcpgodebug.md @@ -27,6 +27,12 @@ Options listed below were added and will be removed in the 1.9.0 version of the Params), restoring the previous behavior. The default behavior was changed to align with SEP-2164 and the JSON-RPC specification. +- `hintomitempty` added. If set to `1`, `ToolAnnotations` JSON marshaling + will omit `ReadOnlyHint` and `IdempotentHint` when their value is `false`, + restoring the previous behavior. The default behavior was changed to always + serialize these fields, since their Go types are bare `bool` (not `*bool`) + and omitting `false` made it indistinguishable from unset. + - `allowsessionsinstateless` added. If set to `1`, stateless streamable HTTP servers will read the `Mcp-Session-Id` request header (or generate one via `GetSessionID`), set it on response headers, and accept `DELETE` requests, diff --git a/docs/protocol.md b/docs/protocol.md index cfcdb855..f45d1dfc 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -15,6 +15,7 @@ 1. [Token Passthrough](#token-passthrough) 1. [Server-Side Request Forgery](#server-side-request-forgery) 1. [Session Hijacking](#session-hijacking) + 1. [Issuer Mix-Up](#issuer-mix-up) 1. [Utilities](#utilities) 1. [Cancellation](#cancellation) 1. [Ping](#ping) @@ -322,6 +323,7 @@ This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) - [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) - [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) +- [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207) Authorization Server Issuer Identification To use it, configure the handler and assign it to the transport: @@ -333,11 +335,12 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - // Open the args.URL in a browser and return the resulting code and state. + // Open the args.URL in a browser and return the resulting code, state, and iss. // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{Code: code, State: state}, nil + iss := ... // "iss" query parameter from the redirect URI (RFC 9207) + return &auth.AuthorizationResult{Code: code, State: state, Iss: iss}, nil }, }) @@ -490,6 +493,22 @@ sets `UserID` on the returned `TokenInfo`, the streamable transport will: `TokenInfo.UserID` to enable this protection. This prevents an attacker with a valid token from hijacking another user's session by guessing or obtaining their session ID. +### Issuer Mix-Up + +The [mitigation](https://www.rfc-editor.org/rfc/rfc9207) against issuer mix-up attacks is +implemented per [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207). The SDK client validates +the `iss` parameter in authorization responses to ensure they originated from the expected +authorization server: + +- If `iss` is present in the redirect URI, the SDK verifies it matches the issuer from the + authorization server's metadata. A mismatch results in an error. +- If `iss` is absent but the authorization server advertises + `authorization_response_iss_parameter_supported: true` in its [RFC 8414](https://www.rfc-editor.org/rfc/rfc8414) + metadata, the SDK rejects the response with an error. + +The `AuthorizationCodeFetcher` is responsible for extracting the `iss` query parameter from +the redirect URI and returning it in [`AuthorizationResult.Iss`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationResult). + ## Utilities ### Cancellation diff --git a/docs/rough_edges.md b/docs/rough_edges.md index 7ee18b1a..c1a618df 100644 --- a/docs/rough_edges.md +++ b/docs/rough_edges.md @@ -63,3 +63,8 @@ v2. - `StreamableHTTPOptions.CrossOriginProtection` should not have been part of the SDK API. Cross-origin protection is a general HTTP concern, not specific to MCP, and can be applied as standard HTTP middleware. + +- `ToolAnnotations` (`mcp/protocol.go`) should have all fields typed as `*bool` + for full control to define what is being sent over the wire. Different + MCP clients have different requirements, and some of them require all fields + to be explicitly set to either `true` or `false`. diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index f514ebae..7e5d4645 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -36,6 +36,7 @@ func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { r.authChan <- &auth.AuthorizationResult{ Code: req.URL.Query().Get("code"), State: req.URL.Query().Get("state"), + Iss: req.URL.Query().Get("iss"), } fmt.Fprint(w, "Authentication successful. You can close this window.") }) diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod index 8690256a..f1bfc378 100644 --- a/examples/server/auth-middleware/go.mod +++ b/examples/server/auth-middleware/go.mod @@ -1,16 +1,16 @@ module auth-middleware-example -go 1.23.0 +go 1.25.0 require ( - github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/modelcontextprotocol/go-sdk v0.3.0 ) require ( github.com/google/jsonschema-go v0.4.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect ) replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum index d257e104..a14b1344 100644 --- a/examples/server/auth-middleware/go.sum +++ b/examples/server/auth-middleware/go.sum @@ -1,5 +1,6 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= @@ -9,5 +10,6 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod index 61c8788c..dc375543 100644 --- a/examples/server/rate-limiting/go.mod +++ b/examples/server/rate-limiting/go.mod @@ -1,6 +1,6 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting -go 1.23.0 +go 1.25.0 require ( github.com/modelcontextprotocol/go-sdk v0.3.0 diff --git a/go.mod b/go.mod index 860d6fa0..3287a957 100644 --- a/go.mod +++ b/go.mod @@ -15,5 +15,6 @@ require ( require ( github.com/segmentio/asm v1.1.3 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect ) diff --git a/go.sum b/go.sum index 377a7b11..c13454aa 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= diff --git a/internal/authutil/util.go b/internal/authutil/util.go new file mode 100644 index 00000000..880d1bfa --- /dev/null +++ b/internal/authutil/util.go @@ -0,0 +1,13 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package authutil + +import "strings" + +// IssuersEqual reports whether two OAuth 2.0 authorization server issuer +// identifiers refer to the same server comparing them without the final trailing slash. +func IssuersEqual(a, b string) bool { + return strings.TrimSuffix(a, "/") == strings.TrimSuffix(b, "/") +} diff --git a/internal/authutil/util_test.go b/internal/authutil/util_test.go new file mode 100644 index 00000000..9d3ff7a4 --- /dev/null +++ b/internal/authutil/util_test.go @@ -0,0 +1,29 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package authutil + +import "testing" + +func TestIssuersEqual(t *testing.T) { + tests := []struct { + a, b string + want bool + }{ + {"https://issuer.example.com", "https://issuer.example.com", true}, + {"https://issuer.example.com/", "https://issuer.example.com", true}, + {"https://issuer.example.com", "https://issuer.example.com/", true}, + {"https://issuer.example.com/", "https://issuer.example.com/", true}, + {"https://issuer.example.com/tenant", "https://issuer.example.com/tenant", true}, + {"https://issuer.example.com/tenant/", "https://issuer.example.com/tenant", true}, + {"https://issuer.example.com", "https://other.example.com", false}, + {"https://issuer.example.com/a", "https://issuer.example.com/b", false}, + {"", "", true}, + } + for _, tt := range tests { + if got := IssuersEqual(tt.a, tt.b); got != tt.want { + t.Errorf("IssuersEqual(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want) + } + } +} diff --git a/internal/docs/mcpgodebug.src.md b/internal/docs/mcpgodebug.src.md index b5f5a781..88639a26 100644 --- a/internal/docs/mcpgodebug.src.md +++ b/internal/docs/mcpgodebug.src.md @@ -26,6 +26,12 @@ Options listed below were added and will be removed in the 1.9.0 version of the Params), restoring the previous behavior. The default behavior was changed to align with SEP-2164 and the JSON-RPC specification. +- `hintomitempty` added. If set to `1`, `ToolAnnotations` JSON marshaling + will omit `ReadOnlyHint` and `IdempotentHint` when their value is `false`, + restoring the previous behavior. The default behavior was changed to always + serialize these fields, since their Go types are bare `bool` (not `*bool`) + and omitting `false` made it indistinguishable from unset. + - `allowsessionsinstateless` added. If set to `1`, stateless streamable HTTP servers will read the `Mcp-Session-Id` request header (or generate one via `GetSessionID`), set it on response headers, and accept `DELETE` requests, diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 593742c7..4044a6fd 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -247,6 +247,7 @@ This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) - [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) - [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) +- [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207) Authorization Server Issuer Identification To use it, configure the handler and assign it to the transport: @@ -258,11 +259,12 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - // Open the args.URL in a browser and return the resulting code and state. + // Open the args.URL in a browser and return the resulting code, state, and iss. // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{Code: code, State: state}, nil + iss := ... // "iss" query parameter from the redirect URI (RFC 9207) + return &auth.AuthorizationResult{Code: code, State: state, Iss: iss}, nil }, }) @@ -415,6 +417,22 @@ sets `UserID` on the returned `TokenInfo`, the streamable transport will: `TokenInfo.UserID` to enable this protection. This prevents an attacker with a valid token from hijacking another user's session by guessing or obtaining their session ID. +### Issuer Mix-Up + +The [mitigation](https://www.rfc-editor.org/rfc/rfc9207) against issuer mix-up attacks is +implemented per [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207). The SDK client validates +the `iss` parameter in authorization responses to ensure they originated from the expected +authorization server: + +- If `iss` is present in the redirect URI, the SDK verifies it matches the issuer from the + authorization server's metadata. A mismatch results in an error. +- If `iss` is absent but the authorization server advertises + `authorization_response_iss_parameter_supported: true` in its [RFC 8414](https://www.rfc-editor.org/rfc/rfc8414) + metadata, the SDK rejects the response with an error. + +The `AuthorizationCodeFetcher` is responsible for extracting the `iss` query parameter from +the redirect URI and returning it in [`AuthorizationResult.Iss`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationResult). + ## Utilities ### Cancellation diff --git a/internal/docs/rough_edges.src.md b/internal/docs/rough_edges.src.md index d573e141..90751f74 100644 --- a/internal/docs/rough_edges.src.md +++ b/internal/docs/rough_edges.src.md @@ -62,3 +62,8 @@ v2. - `StreamableHTTPOptions.CrossOriginProtection` should not have been part of the SDK API. Cross-origin protection is a general HTTP concern, not specific to MCP, and can be applied as standard HTTP middleware. + +- `ToolAnnotations` (`mcp/protocol.go`) should have all fields typed as `*bool` + for full control to define what is being sent over the wire. Different + MCP clients have different requirements, and some of them require all fields + to be explicitly set to either `true` or `false`. \ No newline at end of file diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go index b27e00ed..c180f862 100644 --- a/internal/oauthtest/fake_authorization_server.go +++ b/internal/oauthtest/fake_authorization_server.go @@ -15,6 +15,7 @@ import ( "maps" "net/http" "net/http/httptest" + "net/url" "slices" "testing" @@ -178,6 +179,8 @@ func (s *FakeAuthorizationServer) handleMetadata(w http.ResponseWriter, r *http. CodeChallengeMethodsSupported: []string{"S256"}, ClientIDMetadataDocumentSupported: cimdSupported, TokenEndpointAuthMethodsSupported: []string{"client_secret_post", "client_secret_basic"}, + // Advertise RFC 9207 support: the authorize endpoint includes "iss" in responses. + AuthorizationResponseIssParameterSupported: true, } // Set CORS headers for cross-origin client discovery. w.Header().Set("Access-Control-Allow-Origin", "*") @@ -267,8 +270,9 @@ func (s *FakeAuthorizationServer) handleAuthorize(w http.ResponseWriter, r *http } state := r.URL.Query().Get("state") + issuer := s.URL() + s.config.IssuerPath - redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, state) + redirectURL := fmt.Sprintf("%s?code=%s&state=%s&iss=%s", redirectURI, code, state, url.QueryEscape(issuer)) http.Redirect(w, r, redirectURL, http.StatusFound) } diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..979172ba 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -58,13 +58,17 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { opts.Logger = ensureLogger(nil) } - return &Client{ + c := &Client{ impl: impl, opts: opts, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } + if opts.MultiRoundTrip == nil || !opts.MultiRoundTrip.Disabled { + c.AddSendingMiddleware(clientMultiRoundTripMiddleware()) + } + return c } // ClientOptions configures the behavior of the client. @@ -154,10 +158,21 @@ type ClientOptions struct { ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) LoggingMessageHandler func(context.Context, *LoggingMessageRequest) ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) + // MultiRoundTrip configures the automatic MultiRoundTrip (Multi Round-Trip Requests) middleware. + // By default (nil), the middleware is enabled with default settings. + // Set Disabled to true to opt out of automatic MultiRoundTrip handling. + MultiRoundTrip *MultiRoundTripOptions // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // KeepAliveFailureThreshold is the number of consecutive keepalive ping + // failures tolerated before the session is closed. A value of 0 or 1 + // closes the session on the first failure (the default). Higher values + // align with the spec's "multiple failed pings MAY trigger a connection + // reset" guidance, letting a transient miss pass without tearing down an + // otherwise live session. Has no effect unless KeepAlive is non-zero. + KeepAliveFailureThreshold int } // toolContextKeyType is the context key type for passing tool definitions @@ -433,7 +448,7 @@ func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await // startKeepalive starts the keepalive mechanism for this client session. func (cs *ClientSession) startKeepalive(interval time.Duration) { - startKeepalive(cs, interval, &cs.keepaliveCancel, cs.client.opts.Logger) + startKeepalive(cs, interval, cs.client.opts.KeepAliveFailureThreshold, &cs.keepaliveCancel, cs.client.opts.Logger) } // AddRoots adds the given roots to the client, diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index ba263e6f..86bb86ac 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -223,4 +224,4 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(mcp.CallToolResult{}, mcp.GetPromptResult{}, mcp.ReadResourceResult{})} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..5c8e7d12 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -206,7 +207,7 @@ func TestEndToEnd(t *testing.T) { Role: "user", }}, } - if diff := cmp.Diff(wantReview, gotReview); diff != "" { + if diff := cmp.Diff(wantReview, gotReview, ctrCmpOpts...); diff != "" { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } @@ -1919,6 +1920,82 @@ func TestKeepAliveFailure_Logged(t *testing.T) { }) } +// scriptedKeepaliveSession is a keepaliveSession test double whose Ping +// returns errors from a script (one entry consumed per call; the last entry +// repeats once exhausted), and records how many times Close was called. Ping +// returns immediately so the keepalive loop's pace is driven purely by the +// ticker, making the test deterministic under synctest. +type scriptedKeepaliveSession struct { + pingErrs []error + pingCalls atomic.Int64 + closeCalls atomic.Int64 +} + +func (s *scriptedKeepaliveSession) Ping(context.Context, *PingParams) error { + n := int(s.pingCalls.Add(1)) - 1 + if n >= len(s.pingErrs) { + n = len(s.pingErrs) - 1 + } + return s.pingErrs[n] +} + +func (s *scriptedKeepaliveSession) Close() error { + s.closeCalls.Add(1) + return nil +} + +// TestStartKeepalive_FailureThreshold verifies that the session is kept alive +// across consecutive ping failures below the threshold and only closed once the +// threshold is reached. +func TestStartKeepalive_FailureThreshold(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const interval = 100 * time.Millisecond + sess := &scriptedKeepaliveSession{pingErrs: []error{errors.New("boom")}} + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + var cancel context.CancelFunc + startKeepalive(sess, interval, 3, &cancel, logger) + defer cancel() + + // After two ticks → two failures, still below threshold 3: not closed. + time.Sleep(2*interval + interval/2) + synctest.Wait() + if got := sess.closeCalls.Load(); got != 0 { + t.Fatalf("session closed below threshold: closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load()) + } + + // Third tick → third failure reaches threshold: session closed. + time.Sleep(interval) + synctest.Wait() + if got := sess.closeCalls.Load(); got != 1 { + t.Fatalf("expected one Close at threshold, got closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load()) + } + }) +} + +// TestStartKeepalive_SuccessResetsFailures verifies that a successful ping +// resets the consecutive-failure counter, so an isolated failure between +// successes never accumulates toward the threshold. +func TestStartKeepalive_SuccessResetsFailures(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const interval = 100 * time.Millisecond + // fail, success, fail, fail, then success (the tail repeats): the run + // never has 3 consecutive failures, so the session is never closed. + sess := &scriptedKeepaliveSession{pingErrs: []error{ + errors.New("boom"), nil, errors.New("boom"), errors.New("boom"), nil, + }} + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + var cancel context.CancelFunc + startKeepalive(sess, interval, 3, &cancel, logger) + defer cancel() + + time.Sleep(6 * interval) + synctest.Wait() + if got := sess.closeCalls.Load(); got != 0 { + t.Fatalf("session closed despite a success resetting the counter: closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load()) + } + }) +} + func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Adding the same tool pointer twice should not panic and should not // produce duplicates in the server's tool list. @@ -2371,4 +2448,4 @@ func TestSetErrorPreservesContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} diff --git a/mcp/mrtr.go b/mcp/mrtr.go new file mode 100644 index 00000000..8ad5f1ae --- /dev/null +++ b/mcp/mrtr.go @@ -0,0 +1,268 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/sync/errgroup" +) + +const maxMultiRoundTripRetries = 10 +const maxLoadSheddingMultiRoundTripRetries = 3 + +// MultiRoundTripOptions configures the client-side multi round-trip request (SEP-2322) +// middleware. The middleware is enabled by default and automatically fulfills input +// requests from the server by invoking the appropriate client handlers and +// retrying the original call. +type MultiRoundTripOptions struct { + // Disabled prevents the automatic multi-round-tirp middleware from being installed. + // When true, the client returns input-required results directly and callers must + // handle the retry loop themselves using [CallToolResult.NeedsInput], + // [GetPromptResult.NeedsInput], or [ReadResourceResult.NeedsInput]. + Disabled bool +} + +type multiRoundTripResponse interface { + setResultType(resultType) + inputRequests() map[string]InputRequest + requestState() string + hasContent() bool +} + +func handleMultiRoundTripResult(ss *ServerSession, logger *slog.Logger, res multiRoundTripResponse) error { + if res == nil { + return nil + } + hasInputRequests := res.inputRequests() != nil + + if hasInputRequests && res.hasContent() { + logger.Warn("handler returned both content and inputRequests") + return &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "server bug: result has both content and inputRequests", + } + } + + if clientSupportsMultiRoundTrip(ss) { + // For older clients the resultType is left unset. Input requests will be handled + // by serverMultiRoundTripMiddleware client calls and handler reinvocation. + if hasInputRequests { + res.setResultType(resultTypeInputRequired) + } else { + res.setResultType(resultTypeComplete) + } + } + return nil +} + +func clientSupportsMultiRoundTrip(ss *ServerSession) bool { + protocolVersion := latestProtocolVersion + if iparams := ss.InitializeParams(); iparams != nil { + protocolVersion = iparams.ProtocolVersion + } + return protocolVersion >= protocolVersion20260630 +} + +func clientMultiRoundTripMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + + loadSheddingFailures := 0 + for retries := 1; ; retries++ { + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + mrtrResult, ok := res.(multiRoundTripResponse) + if !ok { + return res, nil + } + reqMap := mrtrResult.inputRequests() + if reqMap == nil { + return res, nil + } + if len(reqMap) == 0 { + loadSheddingFailures++ + } + if loadSheddingFailures >= maxLoadSheddingMultiRoundTripRetries { + return nil, fmt.Errorf("multi-round-trip: exceeded maximum load-shedding retries (%d)", maxLoadSheddingMultiRoundTripRetries) + } + if retries >= maxMultiRoundTripRetries { + return nil, fmt.Errorf("multi-round-trip: exceeded maximum retries (%d)", maxMultiRoundTripRetries) + } + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, nil + } + responses, err := fulfillInputRequests(ctx, cs, reqMap) + if err != nil { + return nil, err + } + setMultiRoundTripRetryParams(req, responses, mrtrResult.requestState()) + } + } + } +} + +// serverMultiRoundTripMiddleware is a receiving middleware for servers that transparently +// handles multi-round-trip for clients on older protocol versions. When a handler returns +// InputRequests and the client does not support multi-round-trip, the middleware fulfills +// the requests by calling the client directly and reinvokes the handler once with the responses. +func serverMultiRoundTripMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + + ss, ok := req.GetSession().(*ServerSession) + if !ok { + return next(ctx, method, req) + } + if clientSupportsMultiRoundTrip(ss) { + return next(ctx, method, req) + } + + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + mrtrResult, ok := res.(multiRoundTripResponse) + if !ok { + return res, nil + } + reqMap := mrtrResult.inputRequests() + if reqMap == nil { + return res, nil + } + if len(reqMap) == 0 { + return nil, fmt.Errorf("the server is busy, retry later") + } + responses, err := fulfillServerInputRequests(ctx, ss, reqMap) + if err != nil { + return nil, err + } + setMultiRoundTripRetryParams(req, responses, mrtrResult.requestState()) + return next(ctx, method, req) + } + } +} + +func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests InputRequestMap) (InputResponseMap, error) { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) + for id, ir := range requests { + g.Go(func() error { + resp, err := fulfillServerInputRequest(ctx, ss, ir) + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("multi-round-trip: %w", err) + } + return responses, nil +} + +func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return ss.Elicit(ctx, p) + case *CreateMessageParams: + return ss.CreateMessageWithTools(ctx, createMessageParamsToWithTools(p)) + case *CreateMessageWithToolsParams: + return ss.CreateMessageWithTools(ctx, p) + case *ListRootsParams: + return ss.ListRoots(ctx, p) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} + +func createMessageParamsToWithTools(p *CreateMessageParams) *CreateMessageWithToolsParams { + var msgs []*SamplingMessageV2 + for _, m := range p.Messages { + msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) + } + return &CreateMessageWithToolsParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } +} + +func setMultiRoundTripRetryParams(req Request, responses InputResponseMap, state string) { + switch p := req.GetParams().(type) { + case *CallToolParams: + p.InputResponses = responses + p.RequestState = state + case *CallToolParamsRaw: + p.InputResponses = responses + p.RequestState = state + case *GetPromptParams: + p.InputResponses = responses + p.RequestState = state + case *ReadResourceParams: + p.InputResponses = responses + p.RequestState = state + } +} + +func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) + for id, ir := range requests { + g.Go(func() error { + resp, err := fulfillInputRequest(ctx, cs, ir) + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("multi round-trip: %w", err) + } + return responses, nil +} + +func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return cs.client.elicit(ctx, newClientRequest(cs, p)) + case *CreateMessageParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: createMessageParamsToWithTools(p)}) + case *CreateMessageWithToolsParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: p}) + case *ListRootsParams: + return cs.client.listRoots(ctx, newClientRequest(cs, p)) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go new file mode 100644 index 00000000..12eba7ea --- /dev/null +++ b/mcp/mrtr_test.go @@ -0,0 +1,546 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "slices" + "sync/atomic" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/jsonschema-go/jsonschema" +) + +func TestMultiRoundTrip_ManualRetry(t *testing.T) { + type deployResult struct { + Deployed bool `json:"deployed"` + Reason string `json:"reason,omitempty"` + } + + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "deploy"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, *deployResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }, nil, nil + } + + resp, ok := req.Params.InputResponses["confirm"] + if !ok { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Please confirm (retry)"}}, + }, nil, nil + } + + if req.Params.RequestState == "" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "no_state"}, nil + } + if elicitResult := resp.(*ElicitResult); elicitResult != nil && elicitResult.Action != "accept" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "cancelled"}, nil + } + + return &CallToolResult{}, &deployResult{Deployed: true}, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + }) + + // Round 1: initiate deployment + res, err := conn.CallTool(ctx, &CallToolParams{Name: "deploy"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if got := len(res.InputRequests); got != 1 { + t.Fatalf("len(res.InputRequests) = %d, want 1", got) + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("res.InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + // Round 2: retry with confirmation + res, err = conn.CallTool(ctx, &CallToolParams{ + Name: "deploy", + InputResponses: InputResponseMap{ + "confirm": &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, + }, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("CallTool() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + + if diff := cmp.Diff(map[string]any{"deployed": true}, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } +} + +func TestMultiRoundTrip_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + tests := []struct { + name string + inputRequests InputRequestMap + wantResult map[string]any + }{ + { + name: "elicit", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Deploy?"}, + }, + wantResult: map[string]any{"ids": []any{"confirm"}}, + }, + { + name: "createMessage", + inputRequests: InputRequestMap{ + "summarize": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "summarize"}}}, + MaxTokens: 100, + }, + }, + wantResult: map[string]any{"ids": []any{"summarize"}}, + }, + { + name: "listRoots", + inputRequests: InputRequestMap{ + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"roots"}}, + }, + { + name: "all three", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "draft": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "write"}}}, + MaxTokens: 50, + }, + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"confirm", "draft", "roots"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + inputRequests := tt.inputRequests + AddTool(srv, &Tool{Name: "act"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil, nil + } + // Collect the IDs of fulfilled responses. + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + return &CallToolResult{}, map[string]any{"ids": ids}, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "response"}, + }, nil + }, + }) + conn.client.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + + res, err := conn.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + + // Sort the expected IDs for stable comparison. + if wantIDs, ok := tt.wantResult["ids"].([]any); ok { + slices.SortFunc(wantIDs, func(a, b any) int { + if a.(string) < b.(string) { + return -1 + } + return 1 + }) + } + + if diff := cmp.Diff(tt.wantResult, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestMultiRoundTrip_MaxRetries(t *testing.T) { + testCases := []struct { + name string + requests InputRequestMap + wantRetries int + }{ + { + name: "load shedding", + requests: InputRequestMap{}, + wantRetries: maxLoadSheddingMultiRoundTripRetries, + }, + { + name: "input request", + requests: InputRequestMap{"confirm": &ElicitParams{Message: "Again?"}}, + wantRetries: maxMultiRoundTripRetries, + }, + } + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + var serverCalls atomic.Int32 + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "loop"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + serverCalls.Add(1) + return &CallToolResult{InputRequests: tc.requests, RequestState: "loop-state"}, nil, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + _, err := conn.CallTool(ctx, &CallToolParams{Name: "loop"}) + if err == nil { + t.Fatal("CallTool() err = nil, want error for exceeded max retries") + } + if serverCalls.Load() != int32(tc.wantRetries) { + t.Errorf("serverCalls = %d, want %d", serverCalls.Load(), tc.wantRetries) + } + }) + } +} + +func TestMultiRoundTrip_ServerMiddleware(t *testing.T) { + // multiRoundTripToolHandler returns a ToolHandler (plain, non-generic) that requests + // the given inputRequests on the first call and returns the fulfilled + // response IDs on the second. + multiRoundTripToolHandler := func(inputRequests InputRequestMap) ToolHandler { + return func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil + } + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + content := &TextContent{Text: fmt.Sprintf("%v", ids)} + return &CallToolResult{Content: []Content{content}}, nil + } + } + + tests := []struct { + name string + inputRequests InputRequestMap + wantText string + }{ + { + name: "elicit via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Sure?"}, + }, + wantText: "[confirm]", + }, + { + name: "elicit and listRoots via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "roots": &ListRootsParams{}, + }, + wantText: "[confirm roots]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddTool( + &Tool{Name: "act", InputSchema: &jsonschema.Schema{Type: "object"}}, + multiRoundTripToolHandler(tt.inputRequests), + ) + + // Connect with an OLD protocol version where multi-round-trip is not supported. + // The server middleware should handle it transparently. + st, ct := NewInMemoryTransports() + ss, err := srv.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + c.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if got := res.Content[0].(*TextContent).Text; got != tt.wantText { + t.Errorf("result text = %q, want %q", got, tt.wantText) + } + }) + } +} + +func TestMultiRoundTrip_GetPrompt_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } + if got := res.Messages[0].Content.(*TextContent).Text; got != "review this code" { + t.Errorf("message text = %q, want %q", got, "review this code") + } +} + +func TestMultiRoundTrip_GetPrompt_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + res, err = conn.GetPrompt(ctx, &GetPromptParams{ + Name: "review", + InputResponses: InputResponseMap{"confirm": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("GetPrompt() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } +} + +func TestMultiRoundTrip_ReadResource_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } + if got := res.Contents[0].Text; got != "resource data" { + t.Errorf("resource text = %q, want %q", got, "resource data") + } +} + +func TestMultiRoundTrip_ReadResource_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["auth"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[auth] type = %T, want *ElicitParams", res.InputRequests["auth"]) + } + + res, err = conn.ReadResource(ctx, &ReadResourceParams{ + URI: "test://data", + InputResponses: InputResponseMap{"auth": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("ReadResource() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } +} + +func mustConnect(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { + t.Helper() + st, ct := NewInMemoryTransports() + ss, err := s.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = ss.Close() + }) + + c := NewClient(testImpl, clientOpts) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = cs.Close() + }) + return cs +} diff --git a/mcp/protocol.go b/mcp/protocol.go index 1646788a..f401d220 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -13,6 +13,178 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" ) +// resultType indicates whether a result is complete or requires further input +// from the client via the multi round-trip request protocol. +type resultType string + +const ( + // resultTypeComplete indicates the result is final. + // This is the default when ResultType is empty. + resultTypeComplete resultType = "complete" + + // resultTypeInputRequired indicates the server needs additional client + // input before it can complete the request. The client should fulfill the + // InputRequests and retry the call with the responses. + resultTypeInputRequired resultType = "input_required" +) + +// InputRequest is a type for parameters that a server can include in the response +// to request input from client (SEP-2322). Implementations are [*ElicitParams], +// [*CreateMessageParams], and [*ListRootsParams]. +type InputRequest interface{ isInputRequest() } + +// InputRequestMap maps server-assigned request IDs to [InputRequest] values. +// It is used in result types to tell the client what input the server needs. +type InputRequestMap map[string]InputRequest + +func (m InputRequestMap) MarshalJSON() ([]byte, error) { + if m == nil { + return json.Marshal(map[string]any(nil)) + } + type wire struct { + Method string `json:"method"` + Params InputRequest `json:"params,omitempty"` + } + typeToMethod := func(v InputRequest) (string, error) { + switch v.(type) { + case *ElicitParams: + return methodElicit, nil + case *CreateMessageParams, *CreateMessageWithToolsParams: + return methodCreateMessage, nil + case *ListRootsParams: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Params: v} + } + return json.Marshal(converted) +} + +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + if rawMap == nil { + return nil + } + result := make(InputRequestMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageWithToolsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputRequest method: %q", raw.Method) + } + } + *m = result + return nil +} + +// InputResponse is a type for results that a client sends back when fulfilling +// a server input request (SEP-2322). Implementations are [*ElicitResult], +// [*CreateMessageResult], and [*ListRootsResult]. +type InputResponse interface{ isInputResponse() } + +// InputResponseMap maps request IDs (from [InputRequestMap]) to [InputResponse] +// values. It is used in params types when retrying a call after an +// input-required result. +type InputResponseMap map[string]InputResponse + +func (m InputResponseMap) MarshalJSON() ([]byte, error) { + type wire struct { + Method string `json:"method"` + Result InputResponse `json:"result,omitempty"` + } + typeToMethod := func(v InputResponse) (string, error) { + switch v.(type) { + case *ElicitResult: + return methodElicit, nil + case *CreateMessageResult, *CreateMessageWithToolsResult: + return methodCreateMessage, nil + case *ListRootsResult: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Result: v} + } + return json.Marshal(converted) +} + +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Result json.RawMessage `json:"result"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + result := make(InputResponseMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageWithToolsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputResponse method: %q", raw.Method) + } + } + *m = result + return nil +} + // Optional annotations for the client. The client can use annotations to inform // how objects are used or displayed. type Annotations struct { @@ -46,6 +218,14 @@ type CallToolParams struct { // Arguments holds the tool arguments. It can hold any value that can be // marshaled to JSON. Arguments any `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -61,6 +241,14 @@ type CallToolParamsRaw struct { // is the responsibility of the tool handler to unmarshal and validate the // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // A CallToolResult is the server's response to a tool call. @@ -107,6 +295,24 @@ type CallToolResult struct { // the Content field. IsError bool `json:"isError,omitempty"` + // InputRequests is a map of server-assigned IDs to input requests. + // Populated only when ResultType is ResultTypeInputRequired. + // The client must fulfill these and echo the IDs back in InputResponses + // when retrying the call. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + + // RequestState is an opaque string the client must echo back when + // retrying after an input-required result. Servers use this to carry + // context between independent requests. + // + // Unauthenticated servers must encrypt, sign and verify this value. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. Empty or ResultTypeComplete means the call succeeded + // normally. ResultTypeInputRequired means the client should fulfill the + // InputRequests and retry the call. + resultType resultType // The error passed to setError, if any. // It is not marshaled, and therefore it is only visible on the server. // Its only use is in server sending middleware, where it can be accessed @@ -145,13 +351,49 @@ func (r *CallToolResult) GetError() error { func (*CallToolResult) isResult() {} -// UnmarshalJSON handles the unmarshalling of content into the Content -// interface. +func (r *CallToolResult) setResultType(rt resultType) { r.resultType = rt } +func (r *CallToolResult) requestState() string { return r.RequestState } +func (r *CallToolResult) inputRequests() map[string]InputRequest { + if r == nil { + return nil + } + return r.InputRequests +} +func (r *CallToolResult) hasContent() bool { + return len(r.Content) > 0 || r.StructuredContent != nil +} + +// NeedsInput reports whether this result requires further client input. +// This is true when the server returned ResultType "input_required". +// When NeedsInput returns true, check InputRequests for the set of +// requests the server needs fulfilled before retrying the call. +// An empty InputRequests with NeedsInput true indicates load-shedding. +func (r *CallToolResult) NeedsInput() bool { return r.resultType == resultTypeInputRequired } + +func (x *CallToolResult) MarshalJSON() ([]byte, error) { + type res CallToolResult // avoid recursion + type wire struct { + res + ResultType resultType `json:"resultType,omitempty"` + InputRequests json.RawMessage `json:"inputRequests,omitempty"` // shadows res.InputRequests + } + w := wire{res: res(*x), ResultType: x.resultType} + if x.InputRequests != nil { + ir, err := json.Marshal(x.InputRequests) + if err != nil { + return nil, err + } + w.InputRequests = ir + } + return json.Marshal(w) +} + func (x *CallToolResult) UnmarshalJSON(data []byte) error { type res CallToolResult // avoid recursion var wire struct { res - Content []*wireContent `json:"content"` + Content []*wireContent `json:"content"` + ResultType resultType `json:"resultType"` } if err := internaljson.Unmarshal(data, &wire); err != nil { return err @@ -160,15 +402,18 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } + wire.res.resultType = wire.ResultType *x = CallToolResult(wire.res) return nil } func (x *CallToolParams) isParams() {} +func (x *CallToolParams) isNil() bool { return x == nil } func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) isNil() bool { return x == nil } func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } @@ -187,6 +432,7 @@ type CancelledParams struct { } func (x *CancelledParams) isParams() {} +func (x *CancelledParams) isNil() bool { return x == nil } func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -374,7 +620,8 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } -func (*CompleteParams) isParams() {} +func (x *CompleteParams) isParams() {} +func (x *CompleteParams) isNil() bool { return x == nil } type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` @@ -422,6 +669,8 @@ type CreateMessageParams struct { } func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) isInputRequest() {} +func (x *CreateMessageParams) isNil() bool { return x == nil } func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -448,6 +697,8 @@ type CreateMessageWithToolsParams struct { } func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) isInputRequest() {} +func (x *CreateMessageWithToolsParams) isNil() bool { return x == nil } func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -547,7 +798,8 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } -func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isInputResponse() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -592,7 +844,8 @@ var createMessageWithToolsResultAllow = map[string]bool{ "tool_use": true, } -func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isInputResponse() {} // MarshalJSON marshals the result. When Content has a single element, it is // marshaled as a single object for compatibility with pre-2025-11-25 @@ -651,9 +904,17 @@ type GetPromptParams struct { Arguments map[string]string `json:"arguments,omitempty"` // The name of the prompt or prompt template. Name string `json:"name"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *GetPromptParams) isParams() {} +func (x *GetPromptParams) isNil() bool { return x == nil } func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -665,10 +926,67 @@ type GetPromptResult struct { // An optional description for the prompt. Description string `json:"description,omitempty"` Messages []*PromptMessage `json:"messages"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for multi-round-trip retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType resultType } func (*GetPromptResult) isResult() {} +func (r *GetPromptResult) setResultType(rt resultType) { r.resultType = rt } +func (r *GetPromptResult) requestState() string { return r.RequestState } +func (r *GetPromptResult) inputRequests() map[string]InputRequest { + if r == nil { + return nil + } + return r.InputRequests +} +func (r *GetPromptResult) hasContent() bool { return len(r.Messages) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *GetPromptResult) NeedsInput() bool { return r.resultType == resultTypeInputRequired } + +func (x *GetPromptResult) MarshalJSON() ([]byte, error) { + type res GetPromptResult + type wire struct { + res + ResultType resultType `json:"resultType,omitempty"` + InputRequests json.RawMessage `json:"inputRequests,omitempty"` // shadows res.InputRequests + } + w := wire{res: res(*x), ResultType: x.resultType} + if x.InputRequests != nil { + ir, err := json.Marshal(x.InputRequests) + if err != nil { + return nil, err + } + w.InputRequests = ir + } + return json.Marshal(w) +} + +func (x *GetPromptResult) UnmarshalJSON(data []byte) error { + type res GetPromptResult + var wire struct { + res + ResultType resultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = GetPromptResult(wire.res) + return nil +} + // InitializeParams is sent by the client to initialize the session. type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -706,6 +1024,7 @@ func (p *initializeParamsV2) toV1() *InitializeParams { } func (x *InitializeParams) isParams() {} +func (x *InitializeParams) isNil() bool { return x == nil } func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -739,6 +1058,7 @@ type InitializedParams struct { } func (x *InitializedParams) isParams() {} +func (x *InitializedParams) isNil() bool { return x == nil } func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -752,6 +1072,7 @@ type ListPromptsParams struct { } func (x *ListPromptsParams) isParams() {} +func (x *ListPromptsParams) isNil() bool { return x == nil } func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -780,6 +1101,7 @@ type ListResourceTemplatesParams struct { } func (x *ListResourceTemplatesParams) isParams() {} +func (x *ListResourceTemplatesParams) isNil() bool { return x == nil } func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -808,6 +1130,7 @@ type ListResourcesParams struct { } func (x *ListResourcesParams) isParams() {} +func (x *ListResourcesParams) isNil() bool { return x == nil } func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -833,6 +1156,8 @@ type ListRootsParams struct { } func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) isInputRequest() {} +func (x *ListRootsParams) isNil() bool { return x == nil } func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -846,7 +1171,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } -func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isInputResponse() {} type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -858,6 +1184,7 @@ type ListToolsParams struct { } func (x *ListToolsParams) isParams() {} +func (x *ListToolsParams) isNil() bool { return x == nil } func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -896,6 +1223,7 @@ type LoggingMessageParams struct { } func (x *LoggingMessageParams) isParams() {} +func (x *LoggingMessageParams) isNil() bool { return x == nil } func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -958,6 +1286,7 @@ type PingParams struct { } func (x *PingParams) isParams() {} +func (x *PingParams) isNil() bool { return x == nil } func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -978,7 +1307,8 @@ type ProgressNotificationParams struct { Total float64 `json:"total,omitempty"` } -func (*ProgressNotificationParams) isParams() {} +func (x *ProgressNotificationParams) isParams() {} +func (x *ProgressNotificationParams) isNil() bool { return x == nil } // IconTheme specifies the theme an icon is designed for. type IconTheme string @@ -1048,6 +1378,7 @@ type PromptListChangedParams struct { } func (x *PromptListChangedParams) isParams() {} +func (x *PromptListChangedParams) isNil() bool { return x == nil } func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1086,9 +1417,17 @@ type ReadResourceParams struct { // The URI of the resource to read. The URI can use any protocol; it is up to // the server how to interpret it. URI string `json:"uri"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *ReadResourceParams) isParams() {} +func (x *ReadResourceParams) isNil() bool { return x == nil } func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1098,10 +1437,67 @@ type ReadResourceResult struct { // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Contents []*ResourceContents `json:"contents"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for multi-round-trip retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType resultType } func (*ReadResourceResult) isResult() {} +func (r *ReadResourceResult) setResultType(rt resultType) { r.resultType = rt } +func (r *ReadResourceResult) requestState() string { return r.RequestState } +func (r *ReadResourceResult) inputRequests() map[string]InputRequest { + if r == nil { + return nil + } + return r.InputRequests +} +func (r *ReadResourceResult) hasContent() bool { return len(r.Contents) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *ReadResourceResult) NeedsInput() bool { return r.resultType == resultTypeInputRequired } + +func (x *ReadResourceResult) MarshalJSON() ([]byte, error) { + type res ReadResourceResult + type wire struct { + res + ResultType resultType `json:"resultType,omitempty"` + InputRequests json.RawMessage `json:"inputRequests,omitempty"` // shadows res.InputRequests + } + w := wire{res: res(*x), ResultType: x.resultType} + if x.InputRequests != nil { + ir, err := json.Marshal(x.InputRequests) + if err != nil { + return nil, err + } + w.InputRequests = ir + } + return json.Marshal(w) +} + +func (x *ReadResourceResult) UnmarshalJSON(data []byte) error { + type res ReadResourceResult + var wire struct { + res + ResultType resultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = ReadResourceResult(wire.res) + return nil +} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -1145,6 +1541,7 @@ type ResourceListChangedParams struct { } func (x *ResourceListChangedParams) isParams() {} +func (x *ResourceListChangedParams) isNil() bool { return x == nil } func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1205,6 +1602,7 @@ type RootsListChangedParams struct { } func (x *RootsListChangedParams) isParams() {} +func (x *RootsListChangedParams) isNil() bool { return x == nil } func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1288,6 +1686,7 @@ type SetLoggingLevelParams struct { } func (x *SetLoggingLevelParams) isParams() {} +func (x *SetLoggingLevelParams) isNil() bool { return x == nil } func (x *SetLoggingLevelParams) GetProgressToken() any { return getProgressToken(x) } func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1346,6 +1745,13 @@ type Tool struct { Icons []Icon `json:"icons,omitempty"` } +// hintomitempty is a compatibility parameter that restores the pre-1.7.0 +// behavior of [ToolAnnotations] JSON marshaling, where false-valued bare bool +// fields (ReadOnlyHint, IdempotentHint) were omitted from the output. +// See the documentation for the mcpgodebug package for instructions on how to +// enable it. +var hintomitempty = mcpgodebug.Value("hintomitempty") + // Additional properties describing a Tool to clients. // // NOTE: all properties in ToolAnnotations are hints. They are not @@ -1368,7 +1774,7 @@ type ToolAnnotations struct { // (This property is meaningful only when ReadOnlyHint == false.) // // Default: false - IdempotentHint bool `json:"idempotentHint,omitempty"` + IdempotentHint bool `json:"idempotentHint"` // If true, this tool may interact with an "open world" of external entities. If // false, the tool's domain of interaction is closed. For example, the world of // a web search tool is open, whereas that of a memory tool is not. @@ -1378,11 +1784,30 @@ type ToolAnnotations struct { // If true, the tool does not modify its environment. // // Default: false - ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + ReadOnlyHint bool `json:"readOnlyHint"` // A human-readable title for the tool. Title string `json:"title,omitempty"` } +// MarshalJSON implements [json.Marshaler] for ToolAnnotations. +// +// To restore the previous behavior where false-valued ReadOnlyHint and +// IdempotentHint were omitted, set MCPGODEBUG=hintomitempty=1. +func (t ToolAnnotations) MarshalJSON() ([]byte, error) { + if hintomitempty == "1" { + type compat struct { + DestructiveHint *bool `json:"destructiveHint,omitempty"` + IdempotentHint bool `json:"idempotentHint,omitempty"` + OpenWorldHint *bool `json:"openWorldHint,omitempty"` + ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + Title string `json:"title,omitempty"` + } + return json.Marshal(compat(t)) + } + type nomethod ToolAnnotations + return json.Marshal(nomethod(t)) +} + type ToolListChangedParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -1390,6 +1815,7 @@ type ToolListChangedParams struct { } func (x *ToolListChangedParams) isParams() {} +func (x *ToolListChangedParams) isNil() bool { return x == nil } func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1403,7 +1829,8 @@ type SubscribeParams struct { URI string `json:"uri"` } -func (*SubscribeParams) isParams() {} +func (x *SubscribeParams) isParams() {} +func (x *SubscribeParams) isNil() bool { return x == nil } // Sent from the client to request cancellation of resources/updated // notifications from the server. This should follow a previous @@ -1416,7 +1843,8 @@ type UnsubscribeParams struct { URI string `json:"uri"` } -func (*UnsubscribeParams) isParams() {} +func (x *UnsubscribeParams) isParams() {} +func (x *UnsubscribeParams) isNil() bool { return x == nil } // A notification from the server to the client, informing it that a resource // has changed and may need to be read again. This should only be sent if the @@ -1429,7 +1857,8 @@ type ResourceUpdatedNotificationParams struct { URI string `json:"uri"` } -func (*ResourceUpdatedNotificationParams) isParams() {} +func (x *ResourceUpdatedNotificationParams) isParams() {} +func (x *ResourceUpdatedNotificationParams) isNil() bool { return x == nil } // TODO(jba): add CompleteRequest and related types. @@ -1468,7 +1897,9 @@ type ElicitParams struct { ElicitationID string `json:"elicitationId,omitempty"` } -func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isInputRequest() {} +func (x *ElicitParams) isNil() bool { return x == nil } func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1488,7 +1919,8 @@ type ElicitResult struct { Content map[string]any `json:"content,omitempty"` } -func (*ElicitResult) isResult() {} +func (*ElicitResult) isResult() {} +func (*ElicitResult) isInputResponse() {} // ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. type ElicitationCompleteParams struct { @@ -1500,18 +1932,21 @@ type ElicitationCompleteParams struct { ElicitationID string `json:"elicitationId"` } -func (*ElicitationCompleteParams) isParams() {} +func (x *ElicitationCompleteParams) isParams() {} +func (x *ElicitationCompleteParams) isNil() bool { return x == nil } -// An Implementation describes the name and version of an MCP implementation, with an optional -// title for UI representation. +// An Implementation describes the name and version of an MCP implementation, with +// optional display metadata. type Implementation struct { // Intended for programmatic or logical use, but used as a display name in past // specs or fallback (if title isn't present). Name string `json:"name"` // Intended for UI and end-user contexts — optimized to be human-readable and // easily understood, even by those unfamiliar with domain-specific terminology. - Title string `json:"title,omitempty"` - Version string `json:"version"` + Title string `json:"title,omitempty"` + // A human-readable description of the implementation. + Description string `json:"description,omitempty"` + Version string `json:"version"` // WebsiteURL for the server, if any. WebsiteURL string `json:"websiteUrl,omitempty"` // Icons for the Server, if any. @@ -1630,3 +2065,17 @@ const ( notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" ) + +// Per-request _meta field names for the >= 2026-06-30 protocol version. +// +// These keys appear inside a Params._meta map and carry information that +// previously came from the initialization handshake (SEP-2575). +const ( + // MetaKeyProtocolVersion identifies the MCP protocol version that the + // request follows. + MetaKeyProtocolVersion = "io.modelcontextprotocol/protocolVersion" + // MetaKeyClientInfo carries the client's [Implementation]. + MetaKeyClientInfo = "io.modelcontextprotocol/clientInfo" + // MetaKeyClientCapabilities carries the client's [ClientCapabilities]. + MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" +) diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index 751d0812..edf3623d 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -1151,6 +1151,71 @@ func TestToWithTools_Conversion(t *testing.T) { } } +func TestInputRequestMapJSON(t *testing.T) { + t.Run("nil is omitted from JSON", func(t *testing.T) { + result := CallToolResult{Content: []Content{&TextContent{Text: "ok"}}} + data, err := json.Marshal(&result) + if err != nil { + t.Fatal(err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if _, ok := raw["inputRequests"]; ok { + t.Errorf("nil InputRequests should be omitted, got %s", raw["inputRequests"]) + } + }) + + t.Run("non-nil empty round-trips", func(t *testing.T) { + result := CallToolResult{ + Content: []Content{&TextContent{Text: "ok"}}, + InputRequests: InputRequestMap{}, + } + data, err := json.Marshal(&result) + if err != nil { + t.Fatal(err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if string(raw["inputRequests"]) != "{}" { + t.Errorf("empty InputRequests should marshal to {}, got %s", raw["inputRequests"]) + } + var got CallToolResult + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.InputRequests == nil { + t.Error("empty InputRequests should round-trip as non-nil") + } + }) + + t.Run("populated round-trips", func(t *testing.T) { + result := CallToolResult{ + Content: []Content{&TextContent{Text: "ok"}}, + InputRequests: InputRequestMap{ + "r1": &ElicitParams{Message: "confirm?"}, + }, + } + data, err := json.Marshal(&result) + if err != nil { + t.Fatal(err) + } + var got CallToolResult + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.InputRequests == nil { + t.Fatal("InputRequests should not be nil after round-trip") + } + if _, ok := got.InputRequests["r1"]; !ok { + t.Error("expected key r1 in InputRequests") + } + }) +} + func TestContentUnmarshal(t *testing.T) { // Verify that types with a Content field round-trip properly. roundtrip := func(in, out any) { @@ -1194,3 +1259,60 @@ func TestContentUnmarshal(t *testing.T) { var gotpm PromptMessage roundtrip(pm, &gotpm) } + +func TestToolAnnotations_MarshalJSON(t *testing.T) { + boolPtr := func(b bool) *bool { return &b } + + tests := []struct { + name string + in ToolAnnotations + want string + }{ + { + name: "ZeroValue", + in: ToolAnnotations{}, + want: `{"idempotentHint":false,"readOnlyHint":false}`, + }, + { + name: "AllFalse", + in: ToolAnnotations{ + DestructiveHint: boolPtr(false), + IdempotentHint: false, + OpenWorldHint: boolPtr(false), + ReadOnlyHint: false, + }, + want: `{"destructiveHint":false,"idempotentHint":false,"openWorldHint":false,"readOnlyHint":false}`, + }, + { + name: "AllTrue", + in: ToolAnnotations{ + DestructiveHint: boolPtr(true), + IdempotentHint: true, + OpenWorldHint: boolPtr(true), + ReadOnlyHint: true, + Title: "my tool", + }, + want: `{"destructiveHint":true,"idempotentHint":true,"openWorldHint":true,"readOnlyHint":true,"title":"my tool"}`, + }, + { + name: "MixedValues", + in: ToolAnnotations{ + ReadOnlyHint: true, + Title: "read tool", + }, + want: `{"idempotentHint":false,"readOnlyHint":true,"title":"read tool"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.in) + if err != nil { + t.Fatalf("json.Marshal(%v) failed: %v", tt.in, err) + } + if diff := cmp.Diff(tt.want, string(got)); diff != "" { + t.Errorf("json.Marshal() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/mcp/server.go b/mcp/server.go index 183226d1..693bd391 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -79,6 +79,13 @@ type ServerOptions struct { // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // KeepAliveFailureThreshold is the number of consecutive keepalive ping + // failures tolerated before the session is closed. A value of 0 or 1 + // closes the session on the first failure (the default). Higher values + // align with the spec's "multiple failed pings MAY trigger a connection + // reset" guidance, letting a transient miss pass without tearing down an + // otherwise live session. Has no effect unless KeepAlive is non-zero. + KeepAliveFailureThreshold int // Function called when a client session subscribes to a resource. SubscribeHandler func(context.Context, *SubscribeRequest) error // Function called when a client session unsubscribes from a resource. @@ -187,7 +194,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { opts.Logger = ensureLogger(nil) } - return &Server{ + s := &Server{ impl: impl, opts: opts, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), @@ -199,6 +206,8 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), } + s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) + return s } // AddPrompt adds a [Prompt] to the server, or replaces one with the same name. @@ -370,8 +379,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa } // Marshal the output and put the RawMessage in the StructuredContent field. + // Skip when the handler returned input requests (multi round-trip): content and + // inputRequests are mutually exclusive on the wire. var outval any = out - if elemZero != nil { + if res.InputRequests != nil { + outval = nil + } else if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. // Instead, use the zero value of the unpointered type. var z Out @@ -742,7 +755,13 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } - return prompt.handler(ctx, req) + res, err := prompt.handler(ctx, req) + if err == nil && res != nil { + if err := handleMultiRoundTripResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + } + return res, err } func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { @@ -775,10 +794,15 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR } } res, err := st.handler(ctx, req) - if err == nil && res != nil && res.Content == nil { - res2 := *res - res2.Content = []Content{} // avoid "null" - res = &res2 + if err == nil && res != nil { + if err := handleMultiRoundTripResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.Content == nil && res.resultType != resultTypeInputRequired { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } } return res, err } @@ -826,6 +850,12 @@ func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*R if err != nil { return nil, err } + if err := handleMultiRoundTripResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.resultType == resultTypeInputRequired { + return res, nil + } if res == nil || res.Contents == nil { return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) } @@ -1450,13 +1480,32 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, initialized := ss.state.InitializeParams != nil ss.mu.Unlock() - // From the spec: - // "The client SHOULD NOT send requests other than pings before the server - // has responded to the initialize request." + // Per-request protocol detection (SEP-2575): if the request carries + // `io.modelcontextprotocol/protocolVersion` in its `_meta` field, it + // follows the new sessionless protocol. The initialization gate is + // skipped for such requests. + validatedMeta, perRequestErr := validateRequestMeta(req) + if perRequestErr != nil { + return nil, perRequestErr + } + + if !initialized && validatedMeta.usesNewProtocol && validatedMeta.initializeParams != nil { + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = validatedMeta.initializeParams + }) + } + switch req.Method { case methodInitialize, methodPing, notificationInitialized: + if validatedMeta.usesNewProtocol { + ss.server.opts.Logger.Error("method removed in the new protocol", "method", req.Method) + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: fmt.Sprintf("%q is not supported in the new protocol", req.Method), + } + } default: - if !initialized { + if !initialized && !validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method) return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } @@ -1488,9 +1537,17 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } + var wasInit bool ss.updateState(func(state *ServerSessionState) { - state.InitializeParams = params + wasInit = state.InitializeParams != nil + if !wasInit { + state.InitializeParams = params + } }) + if wasInit { + ss.server.opts.Logger.Error("duplicate initialize request") + return nil, fmt.Errorf("duplicate %q received", methodInitialize) + } s := ss.server return &InitializeResult{ @@ -1555,7 +1612,7 @@ func (ss *ServerSession) Wait() error { // startKeepalive starts the keepalive mechanism for this server session. func (ss *ServerSession) startKeepalive(interval time.Duration) { - startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.Logger) + startKeepalive(ss, interval, ss.server.opts.KeepAliveFailureThreshold, &ss.keepaliveCancel, ss.server.opts.Logger) } // pageToken is the internal structure for the opaque pagination cursor. diff --git a/mcp/server_test.go b/mcp/server_test.go index 2937ea2b..7cc780a9 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "log" "log/slog" @@ -19,6 +20,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type testItem struct { @@ -825,6 +827,92 @@ func TestClientRootCapabilities(t *testing.T) { } } +func TestServerRejectsDuplicateInitialize(t *testing.T) { + ctx := context.Background() + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + cTransport, sTransport := NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cConn, err := cTransport.Connect(ctx) + if err != nil { + t.Fatal(err) + } + defer cConn.Close() + + firstParams := json.RawMessage(`{ + "protocolVersion": "2025-11-25", + "clientInfo": {"name": "first-client", "version": "1.0.0"} + }`) + firstReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(1), methodInitialize, firstParams) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, firstReq); err != nil { + t.Fatalf("first initialize write failed: %v", err) + } + msg, err := cConn.Read(ctx) + if err != nil { + t.Fatalf("first initialize read failed: %v", err) + } + resp, ok := msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected Response, got %T", msg) + } + if resp.Error != nil { + t.Fatalf("first initialize failed: %v", resp.Error) + } + + initializedReq, err := jsonrpc2.NewNotification(notificationInitialized, &InitializedParams{}) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, initializedReq); err != nil { + t.Fatalf("initialized notification write failed: %v", err) + } + + secondParams := json.RawMessage(`{ + "protocolVersion": "2024-11-05", + "clientInfo": {"name": "second-client", "version": "2.0.0"} + }`) + secondReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(2), methodInitialize, secondParams) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, secondReq); err != nil { + t.Fatalf("second initialize write failed: %v", err) + } + msg, err = cConn.Read(ctx) + if err != nil { + t.Fatalf("second initialize read failed: %v", err) + } + resp, ok = msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected Response, got %T", msg) + } + if resp.Error == nil { + t.Fatal("second initialize unexpectedly succeeded") + } + if !strings.Contains(resp.Error.Error(), `duplicate "initialize" received`) { + t.Fatalf("second initialize error = %v, want duplicate initialize", resp.Error) + } + + got := ss.InitializeParams() + if got == nil { + t.Fatal("InitializeParams is nil") + } + if got.ProtocolVersion != "2025-11-25" { + t.Fatalf("ProtocolVersion = %q, want first initialize value", got.ProtocolVersion) + } + if got.ClientInfo == nil || got.ClientInfo.Name != "first-client" { + t.Fatalf("ClientInfo = %#v, want first initialize value", got.ClientInfo) + } +} + // TODO: move this to tool_test.go func TestToolForSchemas(t *testing.T) { // Validate that toolForErr handles schemas properly. @@ -1007,3 +1095,167 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +// SEP-2575 removes the initialization handshake. An `initialize` request +// that opts into the new protocol via `_meta.protocolVersion` must be +// rejected with `Method not found` (-32601). +func TestServerSessionHandle_RejectsInitializeOnNewProtocol(t *testing.T) { + tests := []struct { + name string + params any + wantReject bool + }{ + { + name: "initialize with new-protocol _meta is rejected", + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "protocolVersion": protocolVersion20260630, + }, + wantReject: true, + }, + { + name: "initialize without _meta is allowed (old protocol)", + params: map[string]any{ + "protocolVersion": protocolVersion20251125, + }, + wantReject: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: methodInitialize, + Params: mustMarshal(tc.params), + } + _, err = ss.handle(context.Background(), req) + if tc.wantReject { + if err == nil { + t.Fatal("expected error rejecting initialize, got nil") + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("error type = %T, want *jsonrpc.Error so the wire returns the right code", err) + } + if jerr.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("error code = %d, want %d (CodeMethodNotFound = -32601)", jerr.Code, jsonrpc.CodeMethodNotFound) + } + if !strings.Contains(jerr.Message, "initialize") { + t.Errorf("error message %q does not mention %q", jerr.Message, "initialize") + } + } else { + // Old-protocol initialize should be dispatched normally; any + // CodeMethodNotFound here would mean the rejection branch + // fired incorrectly. + var jerr *jsonrpc.Error + if errors.As(err, &jerr) && jerr.Code == jsonrpc.CodeMethodNotFound { + t.Errorf("old-protocol initialize was incorrectly rejected: %v", err) + } + } + }) + } + + t.Run("rejection error encodes to wire as code -32601", func(t *testing.T) { + // Belt-and-braces check that the error type produced by handle() + // actually serializes to JSON-RPC code -32601, not a bare 0. + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: methodInitialize, + Params: mustMarshal(map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "protocolVersion": protocolVersion20260630, + }), + } + _, handleErr := ss.handle(context.Background(), req) + if handleErr == nil { + t.Fatal("expected rejection error, got nil") + } + data, encErr := jsonrpc.EncodeMessage(&jsonrpc.Response{ID: id, Error: handleErr.(*jsonrpc.Error)}) + if encErr != nil { + t.Fatal(encErr) + } + var wire struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(data, &wire); err != nil { + t.Fatal(err) + } + if wire.Error.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("wire error code = %d, want %d; full response = %s", wire.Error.Code, jsonrpc.CodeMethodNotFound, data) + } + }) +} + +// TestServerSessionHandle_RejectsRemovedMethodsOnNewProtocol verifies that +// the methods removed by SEP-2575 (`initialize`, `notifications/initialized`, +// `ping`) all return Method not found when the request opts into the new +// protocol via `_meta.protocolVersion`. +func TestServerSessionHandle_RejectsRemovedMethodsOnNewProtocol(t *testing.T) { + newProtoMeta := map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + } + + tests := []struct { + name string + method string + }{ + {"initialize", methodInitialize}, + {"ping", methodPing}, + {"notifications/initialized", notificationInitialized}, + } + + for _, tc := range tests { + t.Run(tc.name+" rejected on new protocol", func(t *testing.T) { + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: tc.method, + Params: mustMarshal(newProtoMeta), + } + _, err = ss.handle(context.Background(), req) + if err == nil { + t.Fatalf("method %q on new protocol: got nil error, want CodeMethodNotFound", tc.method) + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("error type = %T, want *jsonrpc.Error", err) + } + if jerr.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("method %q: code = %d, want %d", tc.method, jerr.Code, jsonrpc.CodeMethodNotFound) + } + if !strings.Contains(jerr.Message, tc.method) { + t.Errorf("method %q: message %q does not mention method name", tc.method, jerr.Message) + } + }) + } +} diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..afa566aa 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -465,6 +465,70 @@ func setProgressToken(p Params, pt any) { m[progressTokenKey] = pt } +// extractRequestMeta performs a lightweight partial unmarshal of the `_meta` +// field from a JSON-RPC request's raw params. +func extractRequestMeta(rawParams json.RawMessage) Meta { + if len(rawParams) == 0 { + return nil + } + var meta struct { + Meta Meta `json:"_meta"` + } + if err := internaljson.Unmarshal(rawParams, &meta); err != nil { + return nil + } + return meta.Meta +} + +type validatedMeta struct { + usesNewProtocol bool + initializeParams *InitializeParams +} + +// validateRequestMeta inspects a JSON-RPC request to detect whether it follows +// the >= 2026-06-30 protocol via the `_meta` field. +// If the request has no _meta, or no protocolVersion in _meta, it returns a non-nil +// validatedMeta with usesNewProtocol set to false, and a nil error. +// If the request has a protocolVersion in _meta: +// - For notifications, it returns usesNewProtocol set to true and a nil initializeParams. +// - For call requests, it validates the presence of clientInfo and clientCapabilities in _meta. +// If either is missing or invalid, it returns nil and a non-nil error. Otherwise, it returns +// usesNewProtocol set to true and the populated initializeParams. +func validateRequestMeta(req *jsonrpc.Request) (*validatedMeta, error) { + meta := extractRequestMeta(req.Params) + if meta == nil { + return &validatedMeta{usesNewProtocol: false, initializeParams: nil}, nil + } + protocolVersion, ok := meta[MetaKeyProtocolVersion].(string) + if !ok { + return &validatedMeta{usesNewProtocol: false, initializeParams: nil}, nil + } + // Notifications do not carry full client identity. In new protocol, only cancel notification + // is allowed in STDIO. + if !req.IsCall() { + return &validatedMeta{usesNewProtocol: true, initializeParams: nil}, nil + } + clientInfo, ok := decodeMetaValue[*Implementation](meta, MetaKeyClientInfo) + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("missing or invalid _meta field %q", MetaKeyClientInfo), + } + } + capabilities, ok := decodeMetaValue[*ClientCapabilities](meta, MetaKeyClientCapabilities) + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("missing or invalid _meta field %q", MetaKeyClientCapabilities), + } + } + return &validatedMeta{usesNewProtocol: true, initializeParams: &InitializeParams{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ClientInfo: clientInfo, + }}, nil +} + // A Request is a method request with parameters and additional information, such as the session. // Request is implemented by [*ClientRequest] and [*ServerRequest]. type Request interface { @@ -525,6 +589,94 @@ func (r *ServerRequest[P]) GetParams() Params { return r.Params } func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } +// ProtocolVersion returns the protocol version negotiated for this request. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams] established during the +// initialize handshake. +func (r *ServerRequest[P]) ProtocolVersion() string { + if m := getRequestMeta(r); m != nil { + if v, ok := m[MetaKeyProtocolVersion].(string); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.ProtocolVersion + } + } + return "" +} + +// ClientInfo returns the [Implementation] identifying the calling client. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams]. +func (r *ServerRequest[P]) ClientInfo() *Implementation { + if m := getRequestMeta(r); m != nil { + if v, ok := decodeMetaValue[*Implementation](m, MetaKeyClientInfo); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.ClientInfo + } + } + return nil +} + +// ClientCapabilities returns the [ClientCapabilities] of the calling client. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams]. +func (r *ServerRequest[P]) ClientCapabilities() *ClientCapabilities { + if m := getRequestMeta(r); m != nil { + if v, ok := decodeMetaValue[*ClientCapabilities](m, MetaKeyClientCapabilities); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.Capabilities + } + } + return nil +} + +// getRequestMeta returns the raw `_meta` map from the request's params, or +// nil if the params are absent. +func getRequestMeta[P Params](r *ServerRequest[P]) map[string]any { + // In practice P is a pointer type implementing Params. + if any(r.Params) == nil || r.Params.isNil() { + return nil + } + return r.Params.GetMeta() +} + +// decodeMetaValue decodes a typed value out of a `_meta` map. Values may +// arrive either as the typed Go value (when constructed in-process) or as +// the generic JSON map produced by encoding/json after wire transit. In the +// latter case, the value is re-encoded and decoded into the target type. +func decodeMetaValue[T any](m map[string]any, key string) (T, bool) { + var zero T + raw, ok := m[key] + if !ok || raw == nil { + return zero, false + } + if v, ok := raw.(T); ok { + return v, true + } + var v T + if err := remarshal(raw, &v); err != nil { + return zero, false + } + return v, true +} + func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { return &ServerRequest[P]{Session: s, Params: p} } @@ -542,6 +694,9 @@ type Params interface { // isParams discourages implementation of Params outside of this package. isParams() + + // isNil returns true if the underlying value is nil. + isNil() bool } // RequestParams is a parameter (input) type for an MCP request. @@ -596,9 +751,20 @@ type keepaliveSession interface { // It assigns the cancel function to the provided cancelPtr and starts a goroutine // that sends ping messages at the specified interval. // -// logger must be non-nil; ping failures (which terminate the keepalive loop and -// close the session) are reported via logger so they are not silently dropped. -func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, logger *slog.Logger) { +// failureThreshold is the number of consecutive ping failures tolerated before +// the session is closed; a value below 1 is treated as 1 (close on the first +// failure). A successful ping resets the counter. This mirrors the spec's +// "multiple failed pings MAY trigger a connection reset" language, letting a +// transient miss pass without tearing down an otherwise live session. +// +// logger must be non-nil; ping failures (both the tolerated ones and the final +// one that closes the session) are reported via logger so they are not silently +// dropped. +func startKeepalive(session keepaliveSession, interval time.Duration, failureThreshold int, cancelPtr *context.CancelFunc, logger *slog.Logger) { + if failureThreshold < 1 { + failureThreshold = 1 + } + ctx, cancel := context.WithCancel(context.Background()) // Assign cancel function before starting goroutine to avoid race condition. // We cannot return it because the caller may need to cancel during the @@ -609,6 +775,7 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr ticker := time.NewTicker(interval) defer ticker.Stop() + consecutiveFailures := 0 for { select { case <-ctx.Done(): @@ -617,17 +784,32 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) err := session.Ping(pingCtx, nil) pingCancel() - if err != nil { - if errors.Is(err, jsonrpc2.ErrMethodNotFound) { - // Peer doesn't support ping, stop the keepalive process. - return - } - // Ping failed; log it before closing the session so the - // failure is observable to operators. See #218. - logger.Error("keepalive ping failed; closing session", "error", err) - _ = session.Close() + if err == nil { + consecutiveFailures = 0 + continue + } + if errors.Is(err, jsonrpc2.ErrMethodNotFound) { + // Peer doesn't support ping, stop the keepalive process. return } + consecutiveFailures++ + if consecutiveFailures < failureThreshold { + // Tolerate transient failures below the threshold; log so + // the misses are still observable to operators. See #218. + logger.Warn("keepalive ping failed; tolerating below threshold", + "error", err, + "consecutiveFailures", consecutiveFailures, + "failureThreshold", failureThreshold) + continue + } + // Threshold reached; log before closing the session so the + // failure is observable to operators. See #218. + logger.Error("keepalive ping failed; closing session", + "error", err, + "consecutiveFailures", consecutiveFailures, + "failureThreshold", failureThreshold) + _ = session.Close() + return } } }() diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 23818f87..065d00b0 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,6 +4,280 @@ package mcp +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +func TestValidateRequestMeta(t *testing.T) { + tests := []struct { + name string + method string + isNotification bool + params any + wantUsesNew bool + wantErrContains string + }{ + { + name: "no params: old protocol", + method: methodListTools, + params: nil, + wantUsesNew: false, + }, + { + name: "no _meta: old protocol", + method: methodCallTool, + params: map[string]any{"name": "x"}, + wantUsesNew: false, + }, + { + name: "_meta without protocolVersion: old protocol", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{"otherKey": "v"}, + "name": "x", + }, + wantUsesNew: false, + }, + { + name: "new protocol with all required fields", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "name": "x", + }, + wantUsesNew: true, + }, + { + name: "new protocol missing clientInfo", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientCapabilities: map[string]any{}, + }, + "name": "x", + }, + wantUsesNew: false, + wantErrContains: MetaKeyClientInfo, + }, + { + name: "new protocol missing clientCapabilities", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + }, + "name": "x", + }, + wantUsesNew: false, + wantErrContains: MetaKeyClientCapabilities, + }, + { + name: "notifications exempt from required fields", + method: notificationCancelled, + isNotification: true, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + }, + "requestId": "r1", + }, + wantUsesNew: true, + }, + { + name: "malformed _meta is ignored", + method: methodCallTool, + params: json.RawMessage(`{"_meta": "not an object", "name": "x"}`), + wantUsesNew: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var raw json.RawMessage + switch p := tc.params.(type) { + case json.RawMessage: + raw = p + default: + raw = mustMarshal(tc.params) + } + req := &jsonrpc.Request{Method: tc.method, Params: raw} + if !tc.isNotification { + req.ID = jsonrpc.ID{} + // Give the request an ID by parsing one. + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req.ID = id + } + + vmeta, err := validateRequestMeta(req) + usesNew := vmeta != nil && vmeta.usesNewProtocol + if usesNew != tc.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", usesNew, tc.wantUsesNew) + } + if tc.wantErrContains == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErrContains) + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("expected *jsonrpc.Error, got %T: %v", err, err) + } + if jerr.Code != jsonrpc.CodeInvalidParams { + t.Errorf("error code = %d, want %d", jerr.Code, jsonrpc.CodeInvalidParams) + } + if !strings.Contains(jerr.Message, tc.wantErrContains) { + t.Errorf("error message %q does not contain %q", jerr.Message, tc.wantErrContains) + } + }) + } +} + +func TestServerRequest_PerRequestAccessors(t *testing.T) { + // A request carrying the new-protocol _meta fields populates the + // accessors with values from _meta. + caps := &ClientCapabilities{Sampling: &SamplingCapabilities{}} + info := &Implementation{Name: "c", Version: "1"} + params := &CallToolParamsRaw{ + Meta: Meta{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: info, + MetaKeyClientCapabilities: caps, + }, + Name: "x", + } + req := &ServerRequest[*CallToolParamsRaw]{Params: params} + if got := req.ProtocolVersion(); got != protocolVersion20260630 { + t.Errorf("ProtocolVersion = %q, want %q", got, protocolVersion20260630) + } + if got := req.ClientInfo(); got == nil || got.Name != "c" { + t.Errorf("ClientInfo = %+v, want Name=c", got) + } + if got := req.ClientCapabilities(); got == nil || got.Sampling == nil { + t.Errorf("ClientCapabilities = %+v, want non-nil Sampling", got) + } +} + +func TestServerRequest_PerRequestAccessors_FromJSON(t *testing.T) { + // Values arriving over the wire are JSON maps; the accessors should + // re-decode them into typed Go values. + raw := json.RawMessage(`{ + "_meta": { + "io.modelcontextprotocol/protocolVersion": "2026-06-30", + "io.modelcontextprotocol/clientInfo": {"name": "wire-client", "version": "9"}, + "io.modelcontextprotocol/clientCapabilities": {"sampling": {}} + }, + "name": "tool" + }`) + var params CallToolParamsRaw + if err := json.Unmarshal(raw, ¶ms); err != nil { + t.Fatal(err) + } + req := &ServerRequest[*CallToolParamsRaw]{Params: ¶ms} + if got, want := req.ProtocolVersion(), protocolVersion20260630; got != want { + t.Errorf("ProtocolVersion = %q, want %q", got, want) + } + gotInfo := req.ClientInfo() + wantInfo := &Implementation{Name: "wire-client", Version: "9"} + if diff := cmp.Diff(wantInfo, gotInfo); diff != "" { + t.Errorf("ClientInfo mismatch (-want +got):\n%s", diff) + } + gotCaps := req.ClientCapabilities() + if gotCaps == nil || gotCaps.Sampling == nil { + t.Errorf("ClientCapabilities = %+v, want non-nil Sampling", gotCaps) + } +} + +func TestServerRequest_PerRequestAccessors_FallbackToInitializeParams(t *testing.T) { + // With no _meta on the request, accessors must fall back to the + // session's InitializeParams (the old-protocol path). + ss := &ServerSession{} + ss.state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion20251125, + ClientInfo: &Implementation{Name: "old", Version: "0"}, + Capabilities: &ClientCapabilities{Elicitation: &ElicitationCapabilities{}}, + } + req := &ServerRequest[*CallToolParamsRaw]{ + Session: ss, + Params: &CallToolParamsRaw{Name: "x"}, + } + if got, want := req.ProtocolVersion(), protocolVersion20251125; got != want { + t.Errorf("ProtocolVersion fallback = %q, want %q", got, want) + } + if got := req.ClientInfo(); got == nil || got.Name != "old" { + t.Errorf("ClientInfo fallback = %+v, want Name=old", got) + } + if got := req.ClientCapabilities(); got == nil || got.Elicitation == nil { + t.Errorf("ClientCapabilities fallback = %+v, want non-nil Elicitation", got) + } +} + +func TestServerRequest_PerRequestAccessors_Empty(t *testing.T) { + // With no _meta and no session, accessors return zero values. + req := &ServerRequest[*CallToolParamsRaw]{ + Params: &CallToolParamsRaw{Name: "x"}, + } + if got := req.ProtocolVersion(); got != "" { + t.Errorf("ProtocolVersion = %q, want empty", got) + } + if got := req.ClientInfo(); got != nil { + t.Errorf("ClientInfo = %+v, want nil", got) + } + if got := req.ClientCapabilities(); got != nil { + t.Errorf("ClientCapabilities = %+v, want nil", got) + } +} + +func TestImplementationDescriptionJSON(t *testing.T) { + impl := &Implementation{ + Name: "greeter", + Title: "Greeter", + Description: "Example server for greeting tools", + Version: "v1.0.0", + } + got, err := json.Marshal(impl) + if err != nil { + t.Fatal(err) + } + want := `{"name":"greeter","title":"Greeter","description":"Example server for greeting tools","version":"v1.0.0"}` + if string(got) != want { + t.Fatalf("Implementation JSON = %s, want %s", got, want) + } + + var roundTrip Implementation + if err := json.Unmarshal(got, &roundTrip); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(impl, &roundTrip); diff != "" { + t.Fatalf("Implementation round trip mismatch (-want +got):\n%s", diff) + } + + got, err = json.Marshal(&Implementation{Name: "greeter", Version: "v1.0.0"}) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(got), "description") { + t.Fatalf("empty description should be omitted, got %s", got) + } +} + // TODO(v0.3.0): rewrite this test. // func TestToolValidate(t *testing.T) { // // Check that the tool returned from NewServerTool properly validates its input schema. diff --git a/mcp/streamable.go b/mcp/streamable.go index d3f3f4fa..0f4e65b8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -343,8 +343,14 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. return } + connectOpts, usesNewProtocol, err := h.ephemeralConnectOpts(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var sessionID string - if legacySessions { + if legacySessions && !usesNewProtocol { sessionID = req.Header.Get(sessionIDHeader) if sessionID == "" { sessionID = server.opts.GetSessionID() @@ -359,11 +365,6 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. logger: h.opts.Logger, } - connectOpts, err := h.ephemeralConnectOpts(req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } session, err := connectStreamable(req.Context(), server, transport, connectOpts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) @@ -389,10 +390,17 @@ func (h *StreamableHTTPHandler) serveStatelessLegacyDELETE(w http.ResponseWriter } // ephemeralConnectOpts peeks at the request body to determine whether it -// contains an initialize or initialized message. If not, default session state -// is constructed so that the session doesn't reject the request. +// contains an initialize or initialized message or whether the protocol version +// header indicates a protocol version >= 2026-06-30 (SEP-2575). +// +// For old-protocol requests, default session state is synthesized so that +// the session's init gate doesn't reject the request. +// // It is used for both stateless servers and stateful servers with no session ID. -func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*ServerSessionOptions, error) { +// +// The returned usesNewProtocol bool reports whether the protocol version +// header indicates a protocol version >= 2026-06-30 (SEP-2575). +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, err error) { protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 @@ -401,7 +409,7 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*Server var hasInitialize, hasInitialized bool body, err := io.ReadAll(req.Body) if err != nil { - return nil, fmt.Errorf("failed to read body") + return nil, false, fmt.Errorf("failed to read body") } req.Body.Close() req.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -415,23 +423,28 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*Server case notificationInitialized: hasInitialized = true } + if protocolVersion >= protocolVersion20260630 { + usesNewProtocol = true + } } } } state := new(ServerSessionState) - if !hasInitialize { + // Only synthesize fake InitializeParams/InitializedParams for old-protocol + // requests. + if !hasInitialize && !usesNewProtocol { state.InitializeParams = &InitializeParams{ ProtocolVersion: protocolVersion, } } - if !hasInitialized { + if !hasInitialized && !usesNewProtocol { state.InitializedParams = new(InitializedParams) } state.LogLevel = "info" return &ServerSessionOptions{ State: state, - }, nil + }, usesNewProtocol, nil } func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { @@ -576,7 +589,7 @@ func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *ht // that arrives before a session exists (e.g. initialize or ping) on a // server configured this way. if sessionID == "" { - connectOpts, err := h.ephemeralConnectOpts(req) + connectOpts, _, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1279,6 +1292,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false var initializeProtocolVersion string + headerVersion := protocolVersionFromContext(req.Context()) for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -1296,6 +1310,41 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques initializeProtocolVersion = params.ProtocolVersion } } + // SEP-2575: requests carrying `_meta.protocolVersion` require the + // Mcp-Protocol-Version HTTP header to be present and to match the + // per-request `_meta.protocolVersion` value. + // The new (>= 2026-06-30) protocol is supported on the HTTP transport + // only when [StreamableHTTPOptions.Stateless] is true. + // + // TODO: this validation can be moved within validateMcpHeaders. + var metaVersion string + if meta := extractRequestMeta(jreq.Params); meta != nil { + metaVersion, _ = meta[MetaKeyProtocolVersion].(string) + } + if protocolVersion >= protocolVersion20260630 || metaVersion != "" { + if !c.stateless { + http.Error(w, fmt.Sprintf( + "Bad Request: protocol version %q is only supported on stateless HTTP servers (set StreamableHTTPOptions.Stateless = true)", + protocolVersion), + http.StatusBadRequest) + return + } + if headerVersion == "" { + http.Error(w, fmt.Sprintf( + "Bad Request: %s header is required for requests carrying %q", + protocolVersionHeader, MetaKeyProtocolVersion), + http.StatusBadRequest) + return + } + if headerVersion != metaVersion { + http.Error(w, fmt.Sprintf( + "Bad Request: %s header %q does not match request %s %q", + protocolVersionHeader, headerVersion, + MetaKeyProtocolVersion, metaVersion), + http.StatusBadRequest) + return + } + } // Include metadata for all requests (including notifications). jreq.Extra = &RequestExtra{ TokenInfo: tokenInfo, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..53806da9 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -769,6 +769,12 @@ func req(id int64, method string, params any) *jsonrpc.Request { return r } +func completeCallToolResult() *CallToolResult { + r := &CallToolResult{Content: []Content{}} + r.resultType = resultTypeComplete + return r +} + func resp(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), @@ -1926,39 +1932,18 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { return &CallToolResult{}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() - initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders}) - initResp := resp(1, &InitializeResult{ - Capabilities: &ServerCapabilities{ - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: minVersionForStandardHeaders, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - }, nil) - - initialize := streamableRequest{ - method: "POST", - messages: []jsonrpc.Message{initReq}, - wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{initResp}, - wantSessionID: true, - } - initialized := streamableRequest{ - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {notificationInitialized}, - }, - messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, - wantStatusCode: http.StatusAccepted, + testMeta := Meta{ + MetaKeyProtocolVersion: minVersionForStandardHeaders, + MetaKeyClientInfo: map[string]any{"name": "testClient", "version": "v1.0.0"}, + MetaKeyClientCapabilities: map[string]any{}, } testStreamableHandler(t, handler, []streamableRequest{ - initialize, - initialized, { method: "POST", headers: http.Header{ @@ -1966,9 +1951,9 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(2, completeCallToolResult(), nil)}, }, { method: "POST", @@ -1977,7 +1962,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"prompts/get"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Method header value", }, @@ -1988,7 +1973,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"wrong-tool"}, }, - messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Name header value", }, @@ -1999,7 +1984,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"TOOLS/CALL"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Method header value", }, @@ -2010,9 +1995,9 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(6, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2023,11 +2008,12 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { paramHeaderPrefix + "Region": {"us-west1"}, }, messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, })}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(7, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2038,6 +2024,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { paramHeaderPrefix + "Region": {"eu-central1"}, }, messages: []jsonrpc.Message{req(8, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1"}, })}, @@ -2052,6 +2039,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { nameHeader: {"execute_sql"}, }, messages: []jsonrpc.Message{req(9, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1"}, })}, @@ -2061,6 +2049,68 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }) } +// TODO: Remove this once client operations will automatically inject metadata in the requests +func injectMetaToRequest(req *http.Request) error { + if req.Body == nil { + return nil + } + body, err := io.ReadAll(req.Body) + if err != nil { + return err + } + req.Body.Close() + + var val any + if err := json.Unmarshal(body, &val); err == nil { + var method string + if m, ok := val.(map[string]any); ok { + method, _ = m["method"].(string) + } else if list, ok := val.([]any); ok && len(list) > 0 { + if m, ok := list[0].(map[string]any); ok { + method, _ = m["method"].(string) + } + } + + if method == "initialize" || method == "notifications/initialized" || strings.HasPrefix(method, "notifications/") { + req.Header.Set(protocolVersionHeader, "2025-11-25") + } else { + req.Header.Set(protocolVersionHeader, minVersionForStandardHeaders) + + var msgs []map[string]any + if m, ok := val.(map[string]any); ok { + msgs = []map[string]any{m} + } else if list, ok := val.([]any); ok { + for _, item := range list { + if m, ok := item.(map[string]any); ok { + msgs = append(msgs, m) + } + } + } + + for _, m := range msgs { + params, _ := m["params"].(map[string]any) + if params == nil { + params = make(map[string]any) + m["params"] = params + } + meta, _ := params["_meta"].(map[string]any) + if meta == nil { + meta = make(map[string]any) + params["_meta"] = meta + } + meta[MetaKeyProtocolVersion] = minVersionForStandardHeaders + meta[MetaKeyClientInfo] = map[string]any{"name": "testClient", "version": "v1.0.0"} + meta[MetaKeyClientCapabilities] = map[string]any{} + } + body, _ = json.Marshal(val) + } + } + + req.Body = io.NopCloser(bytes.NewReader(body)) + req.ContentLength = int64(len(body)) + return nil +} + // TestStreamableMcpHeaderValidationErrorFormat verifies that header // validation errors return a JSON-RPC error with code -32001 and // Content-Type application/json, per SEP-2243. @@ -2076,7 +2126,9 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { return &CallToolResult{}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) @@ -2088,6 +2140,9 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } var originalMethodHeader string if req.Header.Get(methodHeader) == "tools/call" { originalMethodHeader = req.Header.Get(methodHeader) @@ -2237,7 +2292,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() @@ -2245,6 +2302,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { var capturedHeaders http.Header customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } if req.Header.Get(methodHeader) == "tools/call" { capturedHeaders = req.Header.Clone() } @@ -2347,14 +2407,28 @@ func TestStreamableFilterValidToolsIntegration(t *testing.T) { InputSchema: &jsonschema.Schema{Type: "object"}, }, noop) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } + return http.DefaultTransport.RoundTrip(req) + }), + } + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) ctx := context.Background() - session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + session, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + }, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) if err != nil { t.Fatal(err) } @@ -3207,3 +3281,314 @@ func TestStandaloneSSEEmitsCommentForHTTP2Flush(t *testing.T) { t.Fatal("timed out waiting for first SSE bytes; the standalone SSE stream must emit a DATA frame immediately so HTTP/2 reverse proxies don't buffer the HEADERS frame") } } + +// newProtocolBody builds a raw JSON body for a tools/call request that +// carries the >= 2026-06-30 per-request _meta fields. +func newProtocolBody(t *testing.T, toolName string, args any) []byte { + t.Helper() + rawArgs, err := json.Marshal(args) + if err != nil { + t.Fatal(err) + } + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "new-proto-client", "version": "9.9"}, + MetaKeyClientCapabilities: map[string]any{"sampling": map[string]any{}}, + }, + "name": toolName, + "arguments": json.RawMessage(rawArgs), + }, + }) + if err != nil { + t.Fatal(err) + } + return body +} + +func TestEphemeralConnectOpts(t *testing.T) { + mkReq := func(body []byte) *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + r.Header.Set("Content-Type", "application/json") + return r + } + + h := &StreamableHTTPHandler{opts: StreamableHTTPOptions{}} + + oldProtocolBody, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{"name": "x", "arguments": map[string]any{}}, + }) + if err != nil { + t.Fatal(err) + } + initializeBody, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": methodInitialize, + "params": map[string]any{"protocolVersion": protocolVersion20250618}, + }) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + body []byte + wantUsesNew bool + wantInitializeParams bool + wantInitializedParams bool + }{ + { + name: "new-protocol request: no synthetic state", + body: newProtocolBody(t, "x", struct{}{}), + wantUsesNew: true, + wantInitializeParams: false, + wantInitializedParams: false, + }, + { + name: "old-protocol request: synthetic state populated", + body: oldProtocolBody, + wantUsesNew: false, + wantInitializeParams: true, + wantInitializedParams: true, + }, + { + name: "initialize request: no synthetic InitializeParams", + body: initializeBody, + wantUsesNew: false, + wantInitializeParams: false, + wantInitializedParams: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mkReq(tt.body) + var pver string + if tt.wantUsesNew { + pver = protocolVersion20260630 + } else { + pver = protocolVersion20250326 + } + req.Header.Set(protocolVersionHeader, pver) + req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, pver)) + opts, usesNew, err := h.ephemeralConnectOpts(req) + if err != nil { + t.Fatal(err) + } + if usesNew != tt.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", usesNew, tt.wantUsesNew) + } + if got := opts.State.InitializeParams != nil; got != tt.wantInitializeParams { + t.Errorf("InitializeParams non-nil = %v, want %v (value = %+v)", + got, tt.wantInitializeParams, opts.State.InitializeParams) + } + if got := opts.State.InitializedParams != nil; got != tt.wantInitializedParams { + t.Errorf("InitializedParams non-nil = %v, want %v (value = %+v)", + got, tt.wantInitializedParams, opts.State.InitializedParams) + } + }) + } +} + +// statelessHandlerCapture builds a stateless server with a single tool whose +// handler captures everything we want to assert about the per-request view of +// the session and the new-protocol accessors. +type statelessHandlerCapture struct { + mu sync.Mutex + sessionInitParams *InitializeParams + reqProtocolVersion string + reqClientInfo *Implementation + reqClientCapabilities *ClientCapabilities +} + +func TestStreamableStateless_NewProtocolSession_NoFakeInit(t *testing.T) { + // SEP-2575: the MCP-Protocol-Version header is mandatory for new-protocol + // requests and must be a supported version. The 2026-06-30 version is + // not yet in the global list, so register it for the duration of the test. + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + capture := &statelessHandlerCapture{} + mcpServer := NewServer(testImpl, nil) + AddTool(mcpServer, &Tool{Name: "capture", Description: "captures request info"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + capture.mu.Lock() + defer capture.mu.Unlock() + capture.sessionInitParams = req.Session.InitializeParams() + capture.reqProtocolVersion = req.ProtocolVersion() + capture.reqClientInfo = req.ClientInfo() + capture.reqClientCapabilities = req.ClientCapabilities() + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return mcpServer }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + body := newProtocolBody(t, "capture", struct{}{}) + httpReq, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + httpReq.Header.Set(protocolVersionHeader, protocolVersion20260630) + // >= 2026-06-30 also requires the Mcp-Method and Mcp-Name standard + // headers (see streamable_headers.go). + httpReq.Header.Set(methodHeader, "tools/call") + httpReq.Header.Set(nameHeader, "capture") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) + } + + capture.mu.Lock() + defer capture.mu.Unlock() + if capture.sessionInitParams == nil { + t.Errorf("Session.InitializeParams() is nil, want populated initializeParams for new-protocol session") + } else { + if got, want := capture.sessionInitParams.ProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("Session.InitializeParams().ProtocolVersion = %q, want %q", got, want) + } + if got, want := capture.sessionInitParams.ClientInfo.Name, "new-proto-client"; got != want { + t.Errorf("Session.InitializeParams().ClientInfo.Name = %q, want %q", got, want) + } + } + if got, want := capture.reqProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("req.ProtocolVersion() = %q, want %q", got, want) + } + if capture.reqClientInfo == nil || capture.reqClientInfo.Name != "new-proto-client" { + t.Errorf("req.ClientInfo() = %+v, want Name=new-proto-client", capture.reqClientInfo) + } + if capture.reqClientCapabilities == nil || capture.reqClientCapabilities.Sampling == nil { + t.Errorf("req.ClientCapabilities() = %+v, want non-nil Sampling", capture.reqClientCapabilities) + } +} + +// TestStreamableStateful_RejectsNewProtocol verifies that a stateful HTTP +// server rejects requests carrying _meta.protocolVersion (i.e. >= 2026-06-30 +// requests) with HTTP 400. The new protocol is +// supported on HTTP only when StreamableHTTPOptions.Stateless=true. +func TestStreamableStateful_RejectsNewProtocol(t *testing.T) { + // Make 2026-06-30 a "known" version so that the request reaches servePOST + // (otherwise the early header validation at ServeHTTP rejects it). + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // Initialize a legacy session first. + initBody := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`) + initReq, err := http.NewRequest(http.MethodPost, httpServer.URL, initBody) + if err != nil { + t.Fatal(err) + } + initReq.Header.Set("Content-Type", "application/json") + initReq.Header.Set("Accept", "application/json, text/event-stream") + initResp, err := http.DefaultClient.Do(initReq) + if err != nil { + t.Fatal(err) + } + io.Copy(io.Discard, initResp.Body) + initResp.Body.Close() + sessionID := initResp.Header.Get(sessionIDHeader) + if sessionID == "" { + t.Fatalf("initialize response missing %s header", sessionIDHeader) + } + + // Drive the existing session with a new-protocol request whose header and + // body agree. The cross-check passes; the stateful-rejection check fires. + body := newProtocolBody(t, "noop", struct{}{}) + req, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(sessionIDHeader, sessionID) + req.Header.Set(protocolVersionHeader, protocolVersion20260630) + req.Header.Set(methodHeader, "tools/call") + req.Header.Set(nameHeader, "noop") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body = %s", resp.StatusCode, respBody) + } + if !strings.Contains(string(respBody), "stateless") { + t.Errorf("body = %q, want a message mentioning 'stateless'", respBody) + } +} + +// TestStreamableStateless_AcceptsNewProtocol is the positive control: +// confirms that a stateless server still accepts new-protocol requests +// (the rejection in TestStreamableStateful_RejectsNewProtocol must not +// fire on Stateless: true). +func TestStreamableStateless_AcceptsNewProtocol(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + body := newProtocolBody(t, "noop", struct{}{}) + req, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(protocolVersionHeader, protocolVersion20260630) + req.Header.Set(methodHeader, "tools/call") + req.Header.Set(nameHeader, "noop") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) + } +} diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 255aca92..711b17a4 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -12,7 +12,8 @@ import ( "errors" "fmt" "net/http" - "strings" + + "github.com/modelcontextprotocol/go-sdk/internal/authutil" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -114,6 +115,13 @@ type AuthServerMeta struct { // ClientIDMetadataDocumentSupported is a boolean indicating whether the authorization server // supports client ID metadata documents. ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` + + // AuthorizationResponseIssParameterSupported indicates whether the authorization server + // provides the "iss" parameter in authorization responses per [RFC 9207]. + // When true, clients must verify the "iss" parameter is present and matches the Issuer field. + // + // [RFC 9207]: https://www.rfc-editor.org/rfc/rfc9207 + AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"` } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata @@ -146,8 +154,7 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http. } return nil, fmt.Errorf("%v", err) // Do not expose error types. } - if strings.TrimRight(asm.Issuer, "/") != strings.TrimRight(issuer, "/") { - // Validate the Issuer field (see RFC 8414, section 3.3). + if !authutil.IssuersEqual(asm.Issuer, issuer) { return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuer) } diff --git a/oauthex/client.go b/oauthex/client.go index e8f99182..1b19eed5 100644 --- a/oauthex/client.go +++ b/oauthex/client.go @@ -18,6 +18,15 @@ type ClientCredentials struct { // This is the most common authentication method for confidential clients. // OPTIONAL. If not provided, the client is treated as a public client. ClientSecretAuth *ClientSecretAuth + + // Issuer is the issuer identifier of the authorization server these + // credentials are registered with. Pre-registered credentials are bound + // to a specific authorization server; when set, an error is returned if + // the discovered authorization server does not match, per SEP-2352. + // The comparison ignores a single trailing slash, matching the + // tolerance applied during RFC 8414 Section 3.3 metadata validation. + // OPTIONAL. + Issuer string } // ClientSecretAuth holds client secret authentication credentials. diff --git a/oauthex/client_test.go b/oauthex/client_test.go index b78e9c8b..34d8188c 100644 --- a/oauthex/client_test.go +++ b/oauthex/client_test.go @@ -73,7 +73,7 @@ func TestClientCredentials_ValidateCoversAllAuthFields(t *testing.T) { var pointerFields int for i := range typ.NumField() { f := typ.Field(i) - if f.Name == "ClientID" { + if f.Name == "ClientID" || f.Name == "Issuer" { continue } if f.Type.Kind() != reflect.Ptr {