From 7261e1582a7a4b5c1db963675a423bf74355930d Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Fri, 22 May 2026 11:53:11 +0200 Subject: [PATCH 01/10] auth: add AllowMissingExpiration option to RequireBearerTokenOptions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some IdPs emit session-bound bearer tokens that do not carry a standalone `exp` claim — the token's lifetime is bounded by an external session and is not advertised in-band. Resource servers integrating with such IdPs need to opt in to validating the rest of the token (scopes, signature via the verifier callback) without requiring the expiration field to be present. Adds an AllowMissingExpiration bool to RequireBearerTokenOptions. Default false preserves the existing strict behaviour. When true, a TokenInfo with a zero Expiration is accepted; non-zero expirations are still checked for elapsed validity. Extends TestVerify with a "no expiration with AllowMissingExpiration accepts" case mirroring the existing strict-reject case. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/auth.go | 22 +++++++++++++++++++--- auth/auth_test.go | 5 +++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 40fa259f..f2058700 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -47,6 +47,21 @@ type RequireBearerTokenOptions struct { ResourceMetadataURL string // The required scopes. Scopes []string + // AllowMissingExpiration opts the middleware out of the + // `tokenInfo.Expiration.IsZero()` reject. Default false preserves the + // existing strict behaviour (every TokenInfo must carry an Expiration). + // + // Some IdPs emit session-bound bearer tokens that do not carry a standalone + // `exp` claim — the token's lifetime is bounded by an external session and + // is not advertised in-band. Resource servers integrating with such IdPs + // need to opt in to validating the rest of the token (scopes, signature + // via the verifier callback, etc.) without requiring the expiration field + // to be present. + // + // When enabled, the verifier is still responsible for any session-level + // validity check it can perform; this option only relaxes the middleware's + // own expiration enforcement. + AllowMissingExpiration bool } type tokenInfoKey struct{} @@ -131,9 +146,10 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO // Check expiration. if tokenInfo.Expiration.IsZero() { - return nil, "token missing expiration", http.StatusUnauthorized - } - if tokenInfo.Expiration.Before(time.Now()) { + if opts == nil || !opts.AllowMissingExpiration { + return nil, "token missing expiration", http.StatusUnauthorized + } + } else if tokenInfo.Expiration.Before(time.Now()) { return nil, "token expired", http.StatusUnauthorized } return tokenInfo, "", 0 diff --git a/auth/auth_test.go b/auth/auth_test.go index 4028c907..fe523a14 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -62,6 +62,11 @@ func TestVerify(t *testing.T) { "no expiration", nil, "Bearer noexp", "token missing expiration", 401, }, + { + "no expiration with AllowMissingExpiration accepts", + &RequireBearerTokenOptions{AllowMissingExpiration: true}, "Bearer noexp", + "", 0, + }, { "expired", nil, "Bearer expired", "token expired", 401, From bed83e104b39b62b834ecf1b0838c4013b2d1043 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Mon, 25 May 2026 21:15:26 +0800 Subject: [PATCH 02/10] fix: reject duplicate initialize requests (#962) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - reject a second `initialize` request on an already initialized server session - keep the original `ServerSession.InitializeParams()` instead of replacing it with later client parameters - add a raw JSON-RPC regression test for the duplicate-initialize path Fixes #961. ## To verify - `go test ./mcp -run TestServerRejectsDuplicateInitialize -count=1` - `go test ./mcp -count=1` - `go test ./...` --------- Co-authored-by: guglielmoc --- mcp/server.go | 10 +++++- mcp/server_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/mcp/server.go b/mcp/server.go index 183226d1..7526ea7b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1488,9 +1488,17 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams) } + var wasInit bool ss.updateState(func(state *ServerSessionState) { - state.InitializeParams = params + wasInit = state.InitializeParams != nil + if !wasInit { + state.InitializeParams = params + } }) + if wasInit { + ss.server.opts.Logger.Error("duplicate initialize request") + return nil, fmt.Errorf("duplicate %q received", methodInitialize) + } s := ss.server return &InitializeResult{ diff --git a/mcp/server_test.go b/mcp/server_test.go index 2937ea2b..dcc7fb0e 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -825,6 +825,92 @@ func TestClientRootCapabilities(t *testing.T) { } } +func TestServerRejectsDuplicateInitialize(t *testing.T) { + ctx := context.Background() + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + cTransport, sTransport := NewInMemoryTransports() + ss, err := server.Connect(ctx, sTransport, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cConn, err := cTransport.Connect(ctx) + if err != nil { + t.Fatal(err) + } + defer cConn.Close() + + firstParams := json.RawMessage(`{ + "protocolVersion": "2025-11-25", + "clientInfo": {"name": "first-client", "version": "1.0.0"} + }`) + firstReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(1), methodInitialize, firstParams) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, firstReq); err != nil { + t.Fatalf("first initialize write failed: %v", err) + } + msg, err := cConn.Read(ctx) + if err != nil { + t.Fatalf("first initialize read failed: %v", err) + } + resp, ok := msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected Response, got %T", msg) + } + if resp.Error != nil { + t.Fatalf("first initialize failed: %v", resp.Error) + } + + initializedReq, err := jsonrpc2.NewNotification(notificationInitialized, &InitializedParams{}) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, initializedReq); err != nil { + t.Fatalf("initialized notification write failed: %v", err) + } + + secondParams := json.RawMessage(`{ + "protocolVersion": "2024-11-05", + "clientInfo": {"name": "second-client", "version": "2.0.0"} + }`) + secondReq, err := jsonrpc2.NewCall(jsonrpc2.Int64ID(2), methodInitialize, secondParams) + if err != nil { + t.Fatal(err) + } + if err := cConn.Write(ctx, secondReq); err != nil { + t.Fatalf("second initialize write failed: %v", err) + } + msg, err = cConn.Read(ctx) + if err != nil { + t.Fatalf("second initialize read failed: %v", err) + } + resp, ok = msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected Response, got %T", msg) + } + if resp.Error == nil { + t.Fatal("second initialize unexpectedly succeeded") + } + if !strings.Contains(resp.Error.Error(), `duplicate "initialize" received`) { + t.Fatalf("second initialize error = %v, want duplicate initialize", resp.Error) + } + + got := ss.InitializeParams() + if got == nil { + t.Fatal("InitializeParams is nil") + } + if got.ProtocolVersion != "2025-11-25" { + t.Fatalf("ProtocolVersion = %q, want first initialize value", got.ProtocolVersion) + } + if got.ClientInfo == nil || got.ClientInfo.Name != "first-client" { + t.Fatalf("ClientInfo = %#v, want first initialize value", got.ClientInfo) + } +} + // TODO: move this to tool_test.go func TestToolForSchemas(t *testing.T) { // Validate that toolForErr handles schemas properly. From c044cdb734d8d321d596f5b6077072798dce4ad6 Mon Sep 17 00:00:00 2001 From: Pavel Bazin Date: Mon, 25 May 2026 10:52:15 -0300 Subject: [PATCH 03/10] fix: do not `omitempty` `ReadOnlyHint` in `ToolAnnotations` (#908) `ReadOnlyHint` is a `bool` with omitempty, so the zero value (false) is indistinguishable from unset, and **drops out** of marshaled JSON. The current behavior causes an issue with the OpenAI MCP app submission because they require explicit hints. Consumers that explicitly set `ReadOnlyHint: false` on write tools lose the field on the wire. Removing `omitempty` ensures `false` is always serialized, which matches the MCP spec default. --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Maciej Kisiel Co-authored-by: guglielmoc --- docs/mcpgodebug.md | 6 ++++ docs/rough_edges.md | 5 +++ internal/docs/mcpgodebug.src.md | 6 ++++ internal/docs/rough_edges.src.md | 5 +++ mcp/protocol.go | 30 +++++++++++++++-- mcp/protocol_test.go | 57 ++++++++++++++++++++++++++++++++ 6 files changed, 107 insertions(+), 2 deletions(-) diff --git a/docs/mcpgodebug.md b/docs/mcpgodebug.md index 36eddb52..f1e62373 100644 --- a/docs/mcpgodebug.md +++ b/docs/mcpgodebug.md @@ -27,6 +27,12 @@ Options listed below were added and will be removed in the 1.9.0 version of the Params), restoring the previous behavior. The default behavior was changed to align with SEP-2164 and the JSON-RPC specification. +- `hintomitempty` added. If set to `1`, `ToolAnnotations` JSON marshaling + will omit `ReadOnlyHint` and `IdempotentHint` when their value is `false`, + restoring the previous behavior. The default behavior was changed to always + serialize these fields, since their Go types are bare `bool` (not `*bool`) + and omitting `false` made it indistinguishable from unset. + - `allowsessionsinstateless` added. If set to `1`, stateless streamable HTTP servers will read the `Mcp-Session-Id` request header (or generate one via `GetSessionID`), set it on response headers, and accept `DELETE` requests, diff --git a/docs/rough_edges.md b/docs/rough_edges.md index 7ee18b1a..c1a618df 100644 --- a/docs/rough_edges.md +++ b/docs/rough_edges.md @@ -63,3 +63,8 @@ v2. - `StreamableHTTPOptions.CrossOriginProtection` should not have been part of the SDK API. Cross-origin protection is a general HTTP concern, not specific to MCP, and can be applied as standard HTTP middleware. + +- `ToolAnnotations` (`mcp/protocol.go`) should have all fields typed as `*bool` + for full control to define what is being sent over the wire. Different + MCP clients have different requirements, and some of them require all fields + to be explicitly set to either `true` or `false`. diff --git a/internal/docs/mcpgodebug.src.md b/internal/docs/mcpgodebug.src.md index b5f5a781..88639a26 100644 --- a/internal/docs/mcpgodebug.src.md +++ b/internal/docs/mcpgodebug.src.md @@ -26,6 +26,12 @@ Options listed below were added and will be removed in the 1.9.0 version of the Params), restoring the previous behavior. The default behavior was changed to align with SEP-2164 and the JSON-RPC specification. +- `hintomitempty` added. If set to `1`, `ToolAnnotations` JSON marshaling + will omit `ReadOnlyHint` and `IdempotentHint` when their value is `false`, + restoring the previous behavior. The default behavior was changed to always + serialize these fields, since their Go types are bare `bool` (not `*bool`) + and omitting `false` made it indistinguishable from unset. + - `allowsessionsinstateless` added. If set to `1`, stateless streamable HTTP servers will read the `Mcp-Session-Id` request header (or generate one via `GetSessionID`), set it on response headers, and accept `DELETE` requests, diff --git a/internal/docs/rough_edges.src.md b/internal/docs/rough_edges.src.md index d573e141..90751f74 100644 --- a/internal/docs/rough_edges.src.md +++ b/internal/docs/rough_edges.src.md @@ -62,3 +62,8 @@ v2. - `StreamableHTTPOptions.CrossOriginProtection` should not have been part of the SDK API. Cross-origin protection is a general HTTP concern, not specific to MCP, and can be applied as standard HTTP middleware. + +- `ToolAnnotations` (`mcp/protocol.go`) should have all fields typed as `*bool` + for full control to define what is being sent over the wire. Different + MCP clients have different requirements, and some of them require all fields + to be explicitly set to either `true` or `false`. \ No newline at end of file diff --git a/mcp/protocol.go b/mcp/protocol.go index 1646788a..0af1ec57 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -1346,6 +1346,13 @@ type Tool struct { Icons []Icon `json:"icons,omitempty"` } +// hintomitempty is a compatibility parameter that restores the pre-1.7.0 +// behavior of [ToolAnnotations] JSON marshaling, where false-valued bare bool +// fields (ReadOnlyHint, IdempotentHint) were omitted from the output. +// See the documentation for the mcpgodebug package for instructions on how to +// enable it. +var hintomitempty = mcpgodebug.Value("hintomitempty") + // Additional properties describing a Tool to clients. // // NOTE: all properties in ToolAnnotations are hints. They are not @@ -1368,7 +1375,7 @@ type ToolAnnotations struct { // (This property is meaningful only when ReadOnlyHint == false.) // // Default: false - IdempotentHint bool `json:"idempotentHint,omitempty"` + IdempotentHint bool `json:"idempotentHint"` // If true, this tool may interact with an "open world" of external entities. If // false, the tool's domain of interaction is closed. For example, the world of // a web search tool is open, whereas that of a memory tool is not. @@ -1378,11 +1385,30 @@ type ToolAnnotations struct { // If true, the tool does not modify its environment. // // Default: false - ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + ReadOnlyHint bool `json:"readOnlyHint"` // A human-readable title for the tool. Title string `json:"title,omitempty"` } +// MarshalJSON implements [json.Marshaler] for ToolAnnotations. +// +// To restore the previous behavior where false-valued ReadOnlyHint and +// IdempotentHint were omitted, set MCPGODEBUG=hintomitempty=1. +func (t ToolAnnotations) MarshalJSON() ([]byte, error) { + if hintomitempty == "1" { + type compat struct { + DestructiveHint *bool `json:"destructiveHint,omitempty"` + IdempotentHint bool `json:"idempotentHint,omitempty"` + OpenWorldHint *bool `json:"openWorldHint,omitempty"` + ReadOnlyHint bool `json:"readOnlyHint,omitempty"` + Title string `json:"title,omitempty"` + } + return json.Marshal(compat(t)) + } + type nomethod ToolAnnotations + return json.Marshal(nomethod(t)) +} + type ToolListChangedParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index 751d0812..2bab644e 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -1194,3 +1194,60 @@ func TestContentUnmarshal(t *testing.T) { var gotpm PromptMessage roundtrip(pm, &gotpm) } + +func TestToolAnnotations_MarshalJSON(t *testing.T) { + boolPtr := func(b bool) *bool { return &b } + + tests := []struct { + name string + in ToolAnnotations + want string + }{ + { + name: "ZeroValue", + in: ToolAnnotations{}, + want: `{"idempotentHint":false,"readOnlyHint":false}`, + }, + { + name: "AllFalse", + in: ToolAnnotations{ + DestructiveHint: boolPtr(false), + IdempotentHint: false, + OpenWorldHint: boolPtr(false), + ReadOnlyHint: false, + }, + want: `{"destructiveHint":false,"idempotentHint":false,"openWorldHint":false,"readOnlyHint":false}`, + }, + { + name: "AllTrue", + in: ToolAnnotations{ + DestructiveHint: boolPtr(true), + IdempotentHint: true, + OpenWorldHint: boolPtr(true), + ReadOnlyHint: true, + Title: "my tool", + }, + want: `{"destructiveHint":true,"idempotentHint":true,"openWorldHint":true,"readOnlyHint":true,"title":"my tool"}`, + }, + { + name: "MixedValues", + in: ToolAnnotations{ + ReadOnlyHint: true, + Title: "read tool", + }, + want: `{"idempotentHint":false,"readOnlyHint":true,"title":"read tool"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.in) + if err != nil { + t.Fatalf("json.Marshal(%v) failed: %v", tt.in, err) + } + if diff := cmp.Diff(tt.want, string(got)); diff != "" { + t.Errorf("json.Marshal() mismatch (-want +got):\n%s", diff) + } + }) + } +} From 4cbdd6a514a8f5718a40f1baaec8a0af99717950 Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Tue, 26 May 2026 16:25:51 +0200 Subject: [PATCH 04/10] mcp: Implement stateless server (SEP-2275) (#965) ## Description This PR lays the foundational server-side groundwork for the `>= 2026-06-30` sessionless and stateless feature introduced by [SEP-2575](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2575) and SEP-2567, tracked in [design/stateless.md](https://github.com/modelcontextprotocol/go-sdk/blob/614460a2253e7772eff6c78f142bfa0428530dc5/design/stateless.md). ### SEP-2575: Stateless MCP * **Per-request protocol detection in `ServerSession.handle()`** Unmarshal `_meta` from the raw JSON-RPC params determines whether each request follows the new protocol. * **Per-request typed accessors on `ServerRequest[P]`** ```go func (r *ServerRequest[P]) ProtocolVersion() string func (r *ServerRequest[P]) ClientInfo() *Implementation func (r *ServerRequest[P]) ClientCapabilities() *ClientCapabilities ``` * **Reject client->server `initialize`, `initialized` and `ping` for new-protocol requests** * **Per-request `_meta` field name constants** Three constants for the wire-protocol field names (`MetaKeyProtocolVersion`, `MetaKeyClientInfo`, `MetaKeyClientCapabilities`) * **Stop synthesizing fake `InitializeParams` for new-protocol requests** Fixes: #966 --- mcp/protocol.go | 56 ++++- mcp/server.go | 27 ++- mcp/server_test.go | 166 +++++++++++++++ mcp/shared.go | 155 ++++++++++++++ mcp/shared_test.go | 241 ++++++++++++++++++++++ mcp/streamable.go | 77 +++++-- mcp/streamable_test.go | 453 +++++++++++++++++++++++++++++++++++++---- 7 files changed, 1113 insertions(+), 62 deletions(-) diff --git a/mcp/protocol.go b/mcp/protocol.go index 0af1ec57..fcba83f1 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -165,10 +165,12 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { } func (x *CallToolParams) isParams() {} +func (x *CallToolParams) isNil() bool { return x == nil } func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *CallToolParamsRaw) isParams() {} +func (x *CallToolParamsRaw) isNil() bool { return x == nil } func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) } @@ -187,6 +189,7 @@ type CancelledParams struct { } func (x *CancelledParams) isParams() {} +func (x *CancelledParams) isNil() bool { return x == nil } func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -374,7 +377,8 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } -func (*CompleteParams) isParams() {} +func (x *CompleteParams) isParams() {} +func (x *CompleteParams) isNil() bool { return x == nil } type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` @@ -422,6 +426,7 @@ type CreateMessageParams struct { } func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) isNil() bool { return x == nil } func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -448,6 +453,7 @@ type CreateMessageWithToolsParams struct { } func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) isNil() bool { return x == nil } func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -654,6 +660,7 @@ type GetPromptParams struct { } func (x *GetPromptParams) isParams() {} +func (x *GetPromptParams) isNil() bool { return x == nil } func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -706,6 +713,7 @@ func (p *initializeParamsV2) toV1() *InitializeParams { } func (x *InitializeParams) isParams() {} +func (x *InitializeParams) isNil() bool { return x == nil } func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -739,6 +747,7 @@ type InitializedParams struct { } func (x *InitializedParams) isParams() {} +func (x *InitializedParams) isNil() bool { return x == nil } func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -752,6 +761,7 @@ type ListPromptsParams struct { } func (x *ListPromptsParams) isParams() {} +func (x *ListPromptsParams) isNil() bool { return x == nil } func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -780,6 +790,7 @@ type ListResourceTemplatesParams struct { } func (x *ListResourceTemplatesParams) isParams() {} +func (x *ListResourceTemplatesParams) isNil() bool { return x == nil } func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -808,6 +819,7 @@ type ListResourcesParams struct { } func (x *ListResourcesParams) isParams() {} +func (x *ListResourcesParams) isNil() bool { return x == nil } func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -833,6 +845,7 @@ type ListRootsParams struct { } func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) isNil() bool { return x == nil } func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -858,6 +871,7 @@ type ListToolsParams struct { } func (x *ListToolsParams) isParams() {} +func (x *ListToolsParams) isNil() bool { return x == nil } func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -896,6 +910,7 @@ type LoggingMessageParams struct { } func (x *LoggingMessageParams) isParams() {} +func (x *LoggingMessageParams) isNil() bool { return x == nil } func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -958,6 +973,7 @@ type PingParams struct { } func (x *PingParams) isParams() {} +func (x *PingParams) isNil() bool { return x == nil } func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -978,7 +994,8 @@ type ProgressNotificationParams struct { Total float64 `json:"total,omitempty"` } -func (*ProgressNotificationParams) isParams() {} +func (x *ProgressNotificationParams) isParams() {} +func (x *ProgressNotificationParams) isNil() bool { return x == nil } // IconTheme specifies the theme an icon is designed for. type IconTheme string @@ -1048,6 +1065,7 @@ type PromptListChangedParams struct { } func (x *PromptListChangedParams) isParams() {} +func (x *PromptListChangedParams) isNil() bool { return x == nil } func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1089,6 +1107,7 @@ type ReadResourceParams struct { } func (x *ReadResourceParams) isParams() {} +func (x *ReadResourceParams) isNil() bool { return x == nil } func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1145,6 +1164,7 @@ type ResourceListChangedParams struct { } func (x *ResourceListChangedParams) isParams() {} +func (x *ResourceListChangedParams) isNil() bool { return x == nil } func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1205,6 +1225,7 @@ type RootsListChangedParams struct { } func (x *RootsListChangedParams) isParams() {} +func (x *RootsListChangedParams) isNil() bool { return x == nil } func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1288,6 +1309,7 @@ type SetLoggingLevelParams struct { } func (x *SetLoggingLevelParams) isParams() {} +func (x *SetLoggingLevelParams) isNil() bool { return x == nil } func (x *SetLoggingLevelParams) GetProgressToken() any { return getProgressToken(x) } func (x *SetLoggingLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1416,6 +1438,7 @@ type ToolListChangedParams struct { } func (x *ToolListChangedParams) isParams() {} +func (x *ToolListChangedParams) isNil() bool { return x == nil } func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1429,7 +1452,8 @@ type SubscribeParams struct { URI string `json:"uri"` } -func (*SubscribeParams) isParams() {} +func (x *SubscribeParams) isParams() {} +func (x *SubscribeParams) isNil() bool { return x == nil } // Sent from the client to request cancellation of resources/updated // notifications from the server. This should follow a previous @@ -1442,7 +1466,8 @@ type UnsubscribeParams struct { URI string `json:"uri"` } -func (*UnsubscribeParams) isParams() {} +func (x *UnsubscribeParams) isParams() {} +func (x *UnsubscribeParams) isNil() bool { return x == nil } // A notification from the server to the client, informing it that a resource // has changed and may need to be read again. This should only be sent if the @@ -1455,7 +1480,8 @@ type ResourceUpdatedNotificationParams struct { URI string `json:"uri"` } -func (*ResourceUpdatedNotificationParams) isParams() {} +func (x *ResourceUpdatedNotificationParams) isParams() {} +func (x *ResourceUpdatedNotificationParams) isNil() bool { return x == nil } // TODO(jba): add CompleteRequest and related types. @@ -1494,7 +1520,8 @@ type ElicitParams struct { ElicitationID string `json:"elicitationId,omitempty"` } -func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isNil() bool { return x == nil } func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1526,7 +1553,8 @@ type ElicitationCompleteParams struct { ElicitationID string `json:"elicitationId"` } -func (*ElicitationCompleteParams) isParams() {} +func (x *ElicitationCompleteParams) isParams() {} +func (x *ElicitationCompleteParams) isNil() bool { return x == nil } // An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. @@ -1656,3 +1684,17 @@ const ( notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" ) + +// Per-request _meta field names for the >= 2026-06-30 protocol version. +// +// These keys appear inside a Params._meta map and carry information that +// previously came from the initialization handshake (SEP-2575). +const ( + // MetaKeyProtocolVersion identifies the MCP protocol version that the + // request follows. + MetaKeyProtocolVersion = "io.modelcontextprotocol/protocolVersion" + // MetaKeyClientInfo carries the client's [Implementation]. + MetaKeyClientInfo = "io.modelcontextprotocol/clientInfo" + // MetaKeyClientCapabilities carries the client's [ClientCapabilities]. + MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" +) diff --git a/mcp/server.go b/mcp/server.go index 7526ea7b..8d24147e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1450,13 +1450,32 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, initialized := ss.state.InitializeParams != nil ss.mu.Unlock() - // From the spec: - // "The client SHOULD NOT send requests other than pings before the server - // has responded to the initialize request." + // Per-request protocol detection (SEP-2575): if the request carries + // `io.modelcontextprotocol/protocolVersion` in its `_meta` field, it + // follows the new sessionless protocol. The initialization gate is + // skipped for such requests. + validatedMeta, perRequestErr := validateRequestMeta(req) + if perRequestErr != nil { + return nil, perRequestErr + } + + if !initialized && validatedMeta.usesNewProtocol && validatedMeta.initializeParams != nil { + ss.updateState(func(state *ServerSessionState) { + state.InitializeParams = validatedMeta.initializeParams + }) + } + switch req.Method { case methodInitialize, methodPing, notificationInitialized: + if validatedMeta.usesNewProtocol { + ss.server.opts.Logger.Error("method removed in the new protocol", "method", req.Method) + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: fmt.Sprintf("%q is not supported in the new protocol", req.Method), + } + } default: - if !initialized { + if !initialized && !validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method invalid during initialization", "method", req.Method) return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } diff --git a/mcp/server_test.go b/mcp/server_test.go index dcc7fb0e..7cc780a9 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "log" "log/slog" @@ -19,6 +20,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) type testItem struct { @@ -1093,3 +1095,167 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +// SEP-2575 removes the initialization handshake. An `initialize` request +// that opts into the new protocol via `_meta.protocolVersion` must be +// rejected with `Method not found` (-32601). +func TestServerSessionHandle_RejectsInitializeOnNewProtocol(t *testing.T) { + tests := []struct { + name string + params any + wantReject bool + }{ + { + name: "initialize with new-protocol _meta is rejected", + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "protocolVersion": protocolVersion20260630, + }, + wantReject: true, + }, + { + name: "initialize without _meta is allowed (old protocol)", + params: map[string]any{ + "protocolVersion": protocolVersion20251125, + }, + wantReject: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: methodInitialize, + Params: mustMarshal(tc.params), + } + _, err = ss.handle(context.Background(), req) + if tc.wantReject { + if err == nil { + t.Fatal("expected error rejecting initialize, got nil") + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("error type = %T, want *jsonrpc.Error so the wire returns the right code", err) + } + if jerr.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("error code = %d, want %d (CodeMethodNotFound = -32601)", jerr.Code, jsonrpc.CodeMethodNotFound) + } + if !strings.Contains(jerr.Message, "initialize") { + t.Errorf("error message %q does not mention %q", jerr.Message, "initialize") + } + } else { + // Old-protocol initialize should be dispatched normally; any + // CodeMethodNotFound here would mean the rejection branch + // fired incorrectly. + var jerr *jsonrpc.Error + if errors.As(err, &jerr) && jerr.Code == jsonrpc.CodeMethodNotFound { + t.Errorf("old-protocol initialize was incorrectly rejected: %v", err) + } + } + }) + } + + t.Run("rejection error encodes to wire as code -32601", func(t *testing.T) { + // Belt-and-braces check that the error type produced by handle() + // actually serializes to JSON-RPC code -32601, not a bare 0. + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: methodInitialize, + Params: mustMarshal(map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "protocolVersion": protocolVersion20260630, + }), + } + _, handleErr := ss.handle(context.Background(), req) + if handleErr == nil { + t.Fatal("expected rejection error, got nil") + } + data, encErr := jsonrpc.EncodeMessage(&jsonrpc.Response{ID: id, Error: handleErr.(*jsonrpc.Error)}) + if encErr != nil { + t.Fatal(encErr) + } + var wire struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(data, &wire); err != nil { + t.Fatal(err) + } + if wire.Error.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("wire error code = %d, want %d; full response = %s", wire.Error.Code, jsonrpc.CodeMethodNotFound, data) + } + }) +} + +// TestServerSessionHandle_RejectsRemovedMethodsOnNewProtocol verifies that +// the methods removed by SEP-2575 (`initialize`, `notifications/initialized`, +// `ping`) all return Method not found when the request opts into the new +// protocol via `_meta.protocolVersion`. +func TestServerSessionHandle_RejectsRemovedMethodsOnNewProtocol(t *testing.T) { + newProtoMeta := map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + } + + tests := []struct { + name string + method string + }{ + {"initialize", methodInitialize}, + {"ping", methodPing}, + {"notifications/initialized", notificationInitialized}, + } + + for _, tc := range tests { + t.Run(tc.name+" rejected on new protocol", func(t *testing.T) { + ss := &ServerSession{server: NewServer(testImpl, nil)} + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req := &jsonrpc.Request{ + ID: id, + Method: tc.method, + Params: mustMarshal(newProtoMeta), + } + _, err = ss.handle(context.Background(), req) + if err == nil { + t.Fatalf("method %q on new protocol: got nil error, want CodeMethodNotFound", tc.method) + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("error type = %T, want *jsonrpc.Error", err) + } + if jerr.Code != jsonrpc.CodeMethodNotFound { + t.Errorf("method %q: code = %d, want %d", tc.method, jerr.Code, jsonrpc.CodeMethodNotFound) + } + if !strings.Contains(jerr.Message, tc.method) { + t.Errorf("method %q: message %q does not mention method name", tc.method, jerr.Message) + } + }) + } +} diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..1caacac3 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -465,6 +465,70 @@ func setProgressToken(p Params, pt any) { m[progressTokenKey] = pt } +// extractRequestMeta performs a lightweight partial unmarshal of the `_meta` +// field from a JSON-RPC request's raw params. +func extractRequestMeta(rawParams json.RawMessage) Meta { + if len(rawParams) == 0 { + return nil + } + var meta struct { + Meta Meta `json:"_meta"` + } + if err := internaljson.Unmarshal(rawParams, &meta); err != nil { + return nil + } + return meta.Meta +} + +type validatedMeta struct { + usesNewProtocol bool + initializeParams *InitializeParams +} + +// validateRequestMeta inspects a JSON-RPC request to detect whether it follows +// the >= 2026-06-30 protocol via the `_meta` field. +// If the request has no _meta, or no protocolVersion in _meta, it returns a non-nil +// validatedMeta with usesNewProtocol set to false, and a nil error. +// If the request has a protocolVersion in _meta: +// - For notifications, it returns usesNewProtocol set to true and a nil initializeParams. +// - For call requests, it validates the presence of clientInfo and clientCapabilities in _meta. +// If either is missing or invalid, it returns nil and a non-nil error. Otherwise, it returns +// usesNewProtocol set to true and the populated initializeParams. +func validateRequestMeta(req *jsonrpc.Request) (*validatedMeta, error) { + meta := extractRequestMeta(req.Params) + if meta == nil { + return &validatedMeta{usesNewProtocol: false, initializeParams: nil}, nil + } + protocolVersion, ok := meta[MetaKeyProtocolVersion].(string) + if !ok { + return &validatedMeta{usesNewProtocol: false, initializeParams: nil}, nil + } + // Notifications do not carry full client identity. In new protocol, only cancel notification + // is allowed in STDIO. + if !req.IsCall() { + return &validatedMeta{usesNewProtocol: true, initializeParams: nil}, nil + } + clientInfo, ok := decodeMetaValue[*Implementation](meta, MetaKeyClientInfo) + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("missing or invalid _meta field %q", MetaKeyClientInfo), + } + } + capabilities, ok := decodeMetaValue[*ClientCapabilities](meta, MetaKeyClientCapabilities) + if !ok { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: fmt.Sprintf("missing or invalid _meta field %q", MetaKeyClientCapabilities), + } + } + return &validatedMeta{usesNewProtocol: true, initializeParams: &InitializeParams{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ClientInfo: clientInfo, + }}, nil +} + // A Request is a method request with parameters and additional information, such as the session. // Request is implemented by [*ClientRequest] and [*ServerRequest]. type Request interface { @@ -525,6 +589,94 @@ func (r *ServerRequest[P]) GetParams() Params { return r.Params } func (r *ClientRequest[P]) GetExtra() *RequestExtra { return nil } func (r *ServerRequest[P]) GetExtra() *RequestExtra { return r.Extra } +// ProtocolVersion returns the protocol version negotiated for this request. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams] established during the +// initialize handshake. +func (r *ServerRequest[P]) ProtocolVersion() string { + if m := getRequestMeta(r); m != nil { + if v, ok := m[MetaKeyProtocolVersion].(string); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.ProtocolVersion + } + } + return "" +} + +// ClientInfo returns the [Implementation] identifying the calling client. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams]. +func (r *ServerRequest[P]) ClientInfo() *Implementation { + if m := getRequestMeta(r); m != nil { + if v, ok := decodeMetaValue[*Implementation](m, MetaKeyClientInfo); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.ClientInfo + } + } + return nil +} + +// ClientCapabilities returns the [ClientCapabilities] of the calling client. +// +// For requests following the >= 2026-06-30 protocol, the value is read from +// the per-request `_meta` field. For older protocol requests, the value falls +// back to the session-level [InitializeParams]. +func (r *ServerRequest[P]) ClientCapabilities() *ClientCapabilities { + if m := getRequestMeta(r); m != nil { + if v, ok := decodeMetaValue[*ClientCapabilities](m, MetaKeyClientCapabilities); ok { + return v + } + } + if r.Session != nil { + if p := r.Session.InitializeParams(); p != nil { + return p.Capabilities + } + } + return nil +} + +// getRequestMeta returns the raw `_meta` map from the request's params, or +// nil if the params are absent. +func getRequestMeta[P Params](r *ServerRequest[P]) map[string]any { + // In practice P is a pointer type implementing Params. + if any(r.Params) == nil || r.Params.isNil() { + return nil + } + return r.Params.GetMeta() +} + +// decodeMetaValue decodes a typed value out of a `_meta` map. Values may +// arrive either as the typed Go value (when constructed in-process) or as +// the generic JSON map produced by encoding/json after wire transit. In the +// latter case, the value is re-encoded and decoded into the target type. +func decodeMetaValue[T any](m map[string]any, key string) (T, bool) { + var zero T + raw, ok := m[key] + if !ok || raw == nil { + return zero, false + } + if v, ok := raw.(T); ok { + return v, true + } + var v T + if err := remarshal(raw, &v); err != nil { + return zero, false + } + return v, true +} + func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { return &ServerRequest[P]{Session: s, Params: p} } @@ -542,6 +694,9 @@ type Params interface { // isParams discourages implementation of Params outside of this package. isParams() + + // isNil returns true if the underlying value is nil. + isNil() bool } // RequestParams is a parameter (input) type for an MCP request. diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 23818f87..e4f563d1 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,6 +4,247 @@ package mcp +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +func TestValidateRequestMeta(t *testing.T) { + tests := []struct { + name string + method string + isNotification bool + params any + wantUsesNew bool + wantErrContains string + }{ + { + name: "no params: old protocol", + method: methodListTools, + params: nil, + wantUsesNew: false, + }, + { + name: "no _meta: old protocol", + method: methodCallTool, + params: map[string]any{"name": "x"}, + wantUsesNew: false, + }, + { + name: "_meta without protocolVersion: old protocol", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{"otherKey": "v"}, + "name": "x", + }, + wantUsesNew: false, + }, + { + name: "new protocol with all required fields", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + MetaKeyClientCapabilities: map[string]any{}, + }, + "name": "x", + }, + wantUsesNew: true, + }, + { + name: "new protocol missing clientInfo", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientCapabilities: map[string]any{}, + }, + "name": "x", + }, + wantUsesNew: false, + wantErrContains: MetaKeyClientInfo, + }, + { + name: "new protocol missing clientCapabilities", + method: methodCallTool, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "c", "version": "1"}, + }, + "name": "x", + }, + wantUsesNew: false, + wantErrContains: MetaKeyClientCapabilities, + }, + { + name: "notifications exempt from required fields", + method: notificationCancelled, + isNotification: true, + params: map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + }, + "requestId": "r1", + }, + wantUsesNew: true, + }, + { + name: "malformed _meta is ignored", + method: methodCallTool, + params: json.RawMessage(`{"_meta": "not an object", "name": "x"}`), + wantUsesNew: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var raw json.RawMessage + switch p := tc.params.(type) { + case json.RawMessage: + raw = p + default: + raw = mustMarshal(tc.params) + } + req := &jsonrpc.Request{Method: tc.method, Params: raw} + if !tc.isNotification { + req.ID = jsonrpc.ID{} + // Give the request an ID by parsing one. + id, err := jsonrpc.MakeID("test") + if err != nil { + t.Fatal(err) + } + req.ID = id + } + + vmeta, err := validateRequestMeta(req) + usesNew := vmeta != nil && vmeta.usesNewProtocol + if usesNew != tc.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", usesNew, tc.wantUsesNew) + } + if tc.wantErrContains == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErrContains) + } + var jerr *jsonrpc.Error + if !errors.As(err, &jerr) { + t.Fatalf("expected *jsonrpc.Error, got %T: %v", err, err) + } + if jerr.Code != jsonrpc.CodeInvalidParams { + t.Errorf("error code = %d, want %d", jerr.Code, jsonrpc.CodeInvalidParams) + } + if !strings.Contains(jerr.Message, tc.wantErrContains) { + t.Errorf("error message %q does not contain %q", jerr.Message, tc.wantErrContains) + } + }) + } +} + +func TestServerRequest_PerRequestAccessors(t *testing.T) { + // A request carrying the new-protocol _meta fields populates the + // accessors with values from _meta. + caps := &ClientCapabilities{Sampling: &SamplingCapabilities{}} + info := &Implementation{Name: "c", Version: "1"} + params := &CallToolParamsRaw{ + Meta: Meta{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: info, + MetaKeyClientCapabilities: caps, + }, + Name: "x", + } + req := &ServerRequest[*CallToolParamsRaw]{Params: params} + if got := req.ProtocolVersion(); got != protocolVersion20260630 { + t.Errorf("ProtocolVersion = %q, want %q", got, protocolVersion20260630) + } + if got := req.ClientInfo(); got == nil || got.Name != "c" { + t.Errorf("ClientInfo = %+v, want Name=c", got) + } + if got := req.ClientCapabilities(); got == nil || got.Sampling == nil { + t.Errorf("ClientCapabilities = %+v, want non-nil Sampling", got) + } +} + +func TestServerRequest_PerRequestAccessors_FromJSON(t *testing.T) { + // Values arriving over the wire are JSON maps; the accessors should + // re-decode them into typed Go values. + raw := json.RawMessage(`{ + "_meta": { + "io.modelcontextprotocol/protocolVersion": "2026-06-30", + "io.modelcontextprotocol/clientInfo": {"name": "wire-client", "version": "9"}, + "io.modelcontextprotocol/clientCapabilities": {"sampling": {}} + }, + "name": "tool" + }`) + var params CallToolParamsRaw + if err := json.Unmarshal(raw, ¶ms); err != nil { + t.Fatal(err) + } + req := &ServerRequest[*CallToolParamsRaw]{Params: ¶ms} + if got, want := req.ProtocolVersion(), protocolVersion20260630; got != want { + t.Errorf("ProtocolVersion = %q, want %q", got, want) + } + gotInfo := req.ClientInfo() + wantInfo := &Implementation{Name: "wire-client", Version: "9"} + if diff := cmp.Diff(wantInfo, gotInfo); diff != "" { + t.Errorf("ClientInfo mismatch (-want +got):\n%s", diff) + } + gotCaps := req.ClientCapabilities() + if gotCaps == nil || gotCaps.Sampling == nil { + t.Errorf("ClientCapabilities = %+v, want non-nil Sampling", gotCaps) + } +} + +func TestServerRequest_PerRequestAccessors_FallbackToInitializeParams(t *testing.T) { + // With no _meta on the request, accessors must fall back to the + // session's InitializeParams (the old-protocol path). + ss := &ServerSession{} + ss.state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion20251125, + ClientInfo: &Implementation{Name: "old", Version: "0"}, + Capabilities: &ClientCapabilities{Elicitation: &ElicitationCapabilities{}}, + } + req := &ServerRequest[*CallToolParamsRaw]{ + Session: ss, + Params: &CallToolParamsRaw{Name: "x"}, + } + if got, want := req.ProtocolVersion(), protocolVersion20251125; got != want { + t.Errorf("ProtocolVersion fallback = %q, want %q", got, want) + } + if got := req.ClientInfo(); got == nil || got.Name != "old" { + t.Errorf("ClientInfo fallback = %+v, want Name=old", got) + } + if got := req.ClientCapabilities(); got == nil || got.Elicitation == nil { + t.Errorf("ClientCapabilities fallback = %+v, want non-nil Elicitation", got) + } +} + +func TestServerRequest_PerRequestAccessors_Empty(t *testing.T) { + // With no _meta and no session, accessors return zero values. + req := &ServerRequest[*CallToolParamsRaw]{ + Params: &CallToolParamsRaw{Name: "x"}, + } + if got := req.ProtocolVersion(); got != "" { + t.Errorf("ProtocolVersion = %q, want empty", got) + } + if got := req.ClientInfo(); got != nil { + t.Errorf("ClientInfo = %+v, want nil", got) + } + if got := req.ClientCapabilities(); got != nil { + t.Errorf("ClientCapabilities = %+v, want nil", got) + } +} + // TODO(v0.3.0): rewrite this test. // func TestToolValidate(t *testing.T) { // // Check that the tool returned from NewServerTool properly validates its input schema. diff --git a/mcp/streamable.go b/mcp/streamable.go index d3f3f4fa..0f4e65b8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -343,8 +343,14 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. return } + connectOpts, usesNewProtocol, err := h.ephemeralConnectOpts(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var sessionID string - if legacySessions { + if legacySessions && !usesNewProtocol { sessionID = req.Header.Get(sessionIDHeader) if sessionID == "" { sessionID = server.opts.GetSessionID() @@ -359,11 +365,6 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. logger: h.opts.Logger, } - connectOpts, err := h.ephemeralConnectOpts(req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } session, err := connectStreamable(req.Context(), server, transport, connectOpts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) @@ -389,10 +390,17 @@ func (h *StreamableHTTPHandler) serveStatelessLegacyDELETE(w http.ResponseWriter } // ephemeralConnectOpts peeks at the request body to determine whether it -// contains an initialize or initialized message. If not, default session state -// is constructed so that the session doesn't reject the request. +// contains an initialize or initialized message or whether the protocol version +// header indicates a protocol version >= 2026-06-30 (SEP-2575). +// +// For old-protocol requests, default session state is synthesized so that +// the session's init gate doesn't reject the request. +// // It is used for both stateless servers and stateful servers with no session ID. -func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*ServerSessionOptions, error) { +// +// The returned usesNewProtocol bool reports whether the protocol version +// header indicates a protocol version >= 2026-06-30 (SEP-2575). +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, err error) { protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 @@ -401,7 +409,7 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*Server var hasInitialize, hasInitialized bool body, err := io.ReadAll(req.Body) if err != nil { - return nil, fmt.Errorf("failed to read body") + return nil, false, fmt.Errorf("failed to read body") } req.Body.Close() req.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -415,23 +423,28 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*Server case notificationInitialized: hasInitialized = true } + if protocolVersion >= protocolVersion20260630 { + usesNewProtocol = true + } } } } state := new(ServerSessionState) - if !hasInitialize { + // Only synthesize fake InitializeParams/InitializedParams for old-protocol + // requests. + if !hasInitialize && !usesNewProtocol { state.InitializeParams = &InitializeParams{ ProtocolVersion: protocolVersion, } } - if !hasInitialized { + if !hasInitialized && !usesNewProtocol { state.InitializedParams = new(InitializedParams) } state.LogLevel = "info" return &ServerSessionOptions{ State: state, - }, nil + }, usesNewProtocol, nil } func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { @@ -576,7 +589,7 @@ func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *ht // that arrives before a session exists (e.g. initialize or ping) on a // server configured this way. if sessionID == "" { - connectOpts, err := h.ephemeralConnectOpts(req) + connectOpts, _, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1279,6 +1292,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false var initializeProtocolVersion string + headerVersion := protocolVersionFromContext(req.Context()) for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -1296,6 +1310,41 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques initializeProtocolVersion = params.ProtocolVersion } } + // SEP-2575: requests carrying `_meta.protocolVersion` require the + // Mcp-Protocol-Version HTTP header to be present and to match the + // per-request `_meta.protocolVersion` value. + // The new (>= 2026-06-30) protocol is supported on the HTTP transport + // only when [StreamableHTTPOptions.Stateless] is true. + // + // TODO: this validation can be moved within validateMcpHeaders. + var metaVersion string + if meta := extractRequestMeta(jreq.Params); meta != nil { + metaVersion, _ = meta[MetaKeyProtocolVersion].(string) + } + if protocolVersion >= protocolVersion20260630 || metaVersion != "" { + if !c.stateless { + http.Error(w, fmt.Sprintf( + "Bad Request: protocol version %q is only supported on stateless HTTP servers (set StreamableHTTPOptions.Stateless = true)", + protocolVersion), + http.StatusBadRequest) + return + } + if headerVersion == "" { + http.Error(w, fmt.Sprintf( + "Bad Request: %s header is required for requests carrying %q", + protocolVersionHeader, MetaKeyProtocolVersion), + http.StatusBadRequest) + return + } + if headerVersion != metaVersion { + http.Error(w, fmt.Sprintf( + "Bad Request: %s header %q does not match request %s %q", + protocolVersionHeader, headerVersion, + MetaKeyProtocolVersion, metaVersion), + http.StatusBadRequest) + return + } + } // Include metadata for all requests (including notifications). jreq.Extra = &RequestExtra{ TokenInfo: tokenInfo, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..e566bf25 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1926,39 +1926,18 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { return &CallToolResult{}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() - initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders}) - initResp := resp(1, &InitializeResult{ - Capabilities: &ServerCapabilities{ - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: minVersionForStandardHeaders, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - }, nil) - - initialize := streamableRequest{ - method: "POST", - messages: []jsonrpc.Message{initReq}, - wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{initResp}, - wantSessionID: true, - } - initialized := streamableRequest{ - method: "POST", - headers: http.Header{ - protocolVersionHeader: {minVersionForStandardHeaders}, - methodHeader: {notificationInitialized}, - }, - messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, - wantStatusCode: http.StatusAccepted, + testMeta := Meta{ + MetaKeyProtocolVersion: minVersionForStandardHeaders, + MetaKeyClientInfo: map[string]any{"name": "testClient", "version": "v1.0.0"}, + MetaKeyClientCapabilities: map[string]any{}, } testStreamableHandler(t, handler, []streamableRequest{ - initialize, - initialized, { method: "POST", headers: http.Header{ @@ -1966,7 +1945,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, }, @@ -1977,7 +1956,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"prompts/get"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Method header value", }, @@ -1988,7 +1967,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"wrong-tool"}, }, - messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Name header value", }, @@ -1999,7 +1978,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"TOOLS/CALL"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusBadRequest, wantBodyContaining: "Mcp-Method header value", }, @@ -2010,7 +1989,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { methodHeader: {"tools/call"}, nameHeader: {"my-tool"}, }, - messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, }, @@ -2023,6 +2002,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { paramHeaderPrefix + "Region": {"us-west1"}, }, messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, })}, @@ -2038,6 +2018,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { paramHeaderPrefix + "Region": {"eu-central1"}, }, messages: []jsonrpc.Message{req(8, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1"}, })}, @@ -2052,6 +2033,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { nameHeader: {"execute_sql"}, }, messages: []jsonrpc.Message{req(9, "tools/call", &CallToolParams{ + Meta: testMeta, Name: "execute_sql", Arguments: map[string]any{"region": "us-west1"}, })}, @@ -2061,6 +2043,68 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }) } +// TODO: Remove this once client operations will automatically inject metadata in the requests +func injectMetaToRequest(req *http.Request) error { + if req.Body == nil { + return nil + } + body, err := io.ReadAll(req.Body) + if err != nil { + return err + } + req.Body.Close() + + var val any + if err := json.Unmarshal(body, &val); err == nil { + var method string + if m, ok := val.(map[string]any); ok { + method, _ = m["method"].(string) + } else if list, ok := val.([]any); ok && len(list) > 0 { + if m, ok := list[0].(map[string]any); ok { + method, _ = m["method"].(string) + } + } + + if method == "initialize" || method == "notifications/initialized" || strings.HasPrefix(method, "notifications/") { + req.Header.Set(protocolVersionHeader, "2025-11-25") + } else { + req.Header.Set(protocolVersionHeader, minVersionForStandardHeaders) + + var msgs []map[string]any + if m, ok := val.(map[string]any); ok { + msgs = []map[string]any{m} + } else if list, ok := val.([]any); ok { + for _, item := range list { + if m, ok := item.(map[string]any); ok { + msgs = append(msgs, m) + } + } + } + + for _, m := range msgs { + params, _ := m["params"].(map[string]any) + if params == nil { + params = make(map[string]any) + m["params"] = params + } + meta, _ := params["_meta"].(map[string]any) + if meta == nil { + meta = make(map[string]any) + params["_meta"] = meta + } + meta[MetaKeyProtocolVersion] = minVersionForStandardHeaders + meta[MetaKeyClientInfo] = map[string]any{"name": "testClient", "version": "v1.0.0"} + meta[MetaKeyClientCapabilities] = map[string]any{} + } + body, _ = json.Marshal(val) + } + } + + req.Body = io.NopCloser(bytes.NewReader(body)) + req.ContentLength = int64(len(body)) + return nil +} + // TestStreamableMcpHeaderValidationErrorFormat verifies that header // validation errors return a JSON-RPC error with code -32001 and // Content-Type application/json, per SEP-2243. @@ -2076,7 +2120,9 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { return &CallToolResult{}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) @@ -2088,6 +2134,9 @@ func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } var originalMethodHeader string if req.Header.Get(methodHeader) == "tools/call" { originalMethodHeader = req.Header.Get(methodHeader) @@ -2237,7 +2286,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil }) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() @@ -2245,6 +2296,9 @@ func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { var capturedHeaders http.Header customClient := &http.Client{ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } if req.Header.Get(methodHeader) == "tools/call" { capturedHeaders = req.Header.Clone() } @@ -2347,14 +2401,28 @@ func TestStreamableFilterValidToolsIntegration(t *testing.T) { InputSchema: &jsonschema.Schema{Type: "object"}, }, noop) - handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ + Stateless: true, + }) defer handler.closeAll() httpServer := httptest.NewServer(mustNotPanic(t, handler)) defer httpServer.Close() + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if err := injectMetaToRequest(req); err != nil { + return nil, err + } + return http.DefaultTransport.RoundTrip(req) + }), + } + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) ctx := context.Background() - session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + session, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + }, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) if err != nil { t.Fatal(err) } @@ -3207,3 +3275,314 @@ func TestStandaloneSSEEmitsCommentForHTTP2Flush(t *testing.T) { t.Fatal("timed out waiting for first SSE bytes; the standalone SSE stream must emit a DATA frame immediately so HTTP/2 reverse proxies don't buffer the HEADERS frame") } } + +// newProtocolBody builds a raw JSON body for a tools/call request that +// carries the >= 2026-06-30 per-request _meta fields. +func newProtocolBody(t *testing.T, toolName string, args any) []byte { + t.Helper() + rawArgs, err := json.Marshal(args) + if err != nil { + t.Fatal(err) + } + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "_meta": map[string]any{ + MetaKeyProtocolVersion: protocolVersion20260630, + MetaKeyClientInfo: map[string]any{"name": "new-proto-client", "version": "9.9"}, + MetaKeyClientCapabilities: map[string]any{"sampling": map[string]any{}}, + }, + "name": toolName, + "arguments": json.RawMessage(rawArgs), + }, + }) + if err != nil { + t.Fatal(err) + } + return body +} + +func TestEphemeralConnectOpts(t *testing.T) { + mkReq := func(body []byte) *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + r.Header.Set("Content-Type", "application/json") + return r + } + + h := &StreamableHTTPHandler{opts: StreamableHTTPOptions{}} + + oldProtocolBody, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{"name": "x", "arguments": map[string]any{}}, + }) + if err != nil { + t.Fatal(err) + } + initializeBody, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": methodInitialize, + "params": map[string]any{"protocolVersion": protocolVersion20250618}, + }) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + body []byte + wantUsesNew bool + wantInitializeParams bool + wantInitializedParams bool + }{ + { + name: "new-protocol request: no synthetic state", + body: newProtocolBody(t, "x", struct{}{}), + wantUsesNew: true, + wantInitializeParams: false, + wantInitializedParams: false, + }, + { + name: "old-protocol request: synthetic state populated", + body: oldProtocolBody, + wantUsesNew: false, + wantInitializeParams: true, + wantInitializedParams: true, + }, + { + name: "initialize request: no synthetic InitializeParams", + body: initializeBody, + wantUsesNew: false, + wantInitializeParams: false, + wantInitializedParams: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := mkReq(tt.body) + var pver string + if tt.wantUsesNew { + pver = protocolVersion20260630 + } else { + pver = protocolVersion20250326 + } + req.Header.Set(protocolVersionHeader, pver) + req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, pver)) + opts, usesNew, err := h.ephemeralConnectOpts(req) + if err != nil { + t.Fatal(err) + } + if usesNew != tt.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", usesNew, tt.wantUsesNew) + } + if got := opts.State.InitializeParams != nil; got != tt.wantInitializeParams { + t.Errorf("InitializeParams non-nil = %v, want %v (value = %+v)", + got, tt.wantInitializeParams, opts.State.InitializeParams) + } + if got := opts.State.InitializedParams != nil; got != tt.wantInitializedParams { + t.Errorf("InitializedParams non-nil = %v, want %v (value = %+v)", + got, tt.wantInitializedParams, opts.State.InitializedParams) + } + }) + } +} + +// statelessHandlerCapture builds a stateless server with a single tool whose +// handler captures everything we want to assert about the per-request view of +// the session and the new-protocol accessors. +type statelessHandlerCapture struct { + mu sync.Mutex + sessionInitParams *InitializeParams + reqProtocolVersion string + reqClientInfo *Implementation + reqClientCapabilities *ClientCapabilities +} + +func TestStreamableStateless_NewProtocolSession_NoFakeInit(t *testing.T) { + // SEP-2575: the MCP-Protocol-Version header is mandatory for new-protocol + // requests and must be a supported version. The 2026-06-30 version is + // not yet in the global list, so register it for the duration of the test. + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + capture := &statelessHandlerCapture{} + mcpServer := NewServer(testImpl, nil) + AddTool(mcpServer, &Tool{Name: "capture", Description: "captures request info"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + capture.mu.Lock() + defer capture.mu.Unlock() + capture.sessionInitParams = req.Session.InitializeParams() + capture.reqProtocolVersion = req.ProtocolVersion() + capture.reqClientInfo = req.ClientInfo() + capture.reqClientCapabilities = req.ClientCapabilities() + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return mcpServer }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + body := newProtocolBody(t, "capture", struct{}{}) + httpReq, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + httpReq.Header.Set(protocolVersionHeader, protocolVersion20260630) + // >= 2026-06-30 also requires the Mcp-Method and Mcp-Name standard + // headers (see streamable_headers.go). + httpReq.Header.Set(methodHeader, "tools/call") + httpReq.Header.Set(nameHeader, "capture") + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) + } + + capture.mu.Lock() + defer capture.mu.Unlock() + if capture.sessionInitParams == nil { + t.Errorf("Session.InitializeParams() is nil, want populated initializeParams for new-protocol session") + } else { + if got, want := capture.sessionInitParams.ProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("Session.InitializeParams().ProtocolVersion = %q, want %q", got, want) + } + if got, want := capture.sessionInitParams.ClientInfo.Name, "new-proto-client"; got != want { + t.Errorf("Session.InitializeParams().ClientInfo.Name = %q, want %q", got, want) + } + } + if got, want := capture.reqProtocolVersion, protocolVersion20260630; got != want { + t.Errorf("req.ProtocolVersion() = %q, want %q", got, want) + } + if capture.reqClientInfo == nil || capture.reqClientInfo.Name != "new-proto-client" { + t.Errorf("req.ClientInfo() = %+v, want Name=new-proto-client", capture.reqClientInfo) + } + if capture.reqClientCapabilities == nil || capture.reqClientCapabilities.Sampling == nil { + t.Errorf("req.ClientCapabilities() = %+v, want non-nil Sampling", capture.reqClientCapabilities) + } +} + +// TestStreamableStateful_RejectsNewProtocol verifies that a stateful HTTP +// server rejects requests carrying _meta.protocolVersion (i.e. >= 2026-06-30 +// requests) with HTTP 400. The new protocol is +// supported on HTTP only when StreamableHTTPOptions.Stateless=true. +func TestStreamableStateful_RejectsNewProtocol(t *testing.T) { + // Make 2026-06-30 a "known" version so that the request reaches servePOST + // (otherwise the early header validation at ServeHTTP rejects it). + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // Initialize a legacy session first. + initBody := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`) + initReq, err := http.NewRequest(http.MethodPost, httpServer.URL, initBody) + if err != nil { + t.Fatal(err) + } + initReq.Header.Set("Content-Type", "application/json") + initReq.Header.Set("Accept", "application/json, text/event-stream") + initResp, err := http.DefaultClient.Do(initReq) + if err != nil { + t.Fatal(err) + } + io.Copy(io.Discard, initResp.Body) + initResp.Body.Close() + sessionID := initResp.Header.Get(sessionIDHeader) + if sessionID == "" { + t.Fatalf("initialize response missing %s header", sessionIDHeader) + } + + // Drive the existing session with a new-protocol request whose header and + // body agree. The cross-check passes; the stateful-rejection check fires. + body := newProtocolBody(t, "noop", struct{}{}) + req, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(sessionIDHeader, sessionID) + req.Header.Set(protocolVersionHeader, protocolVersion20260630) + req.Header.Set(methodHeader, "tools/call") + req.Header.Set(nameHeader, "noop") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400; body = %s", resp.StatusCode, respBody) + } + if !strings.Contains(string(respBody), "stateless") { + t.Errorf("body = %q, want a message mentioning 'stateless'", respBody) + } +} + +// TestStreamableStateless_AcceptsNewProtocol is the positive control: +// confirms that a stateless server still accepts new-protocol requests +// (the rejection in TestStreamableStateful_RejectsNewProtocol must not +// fire on Stateless: true). +func TestStreamableStateless_AcceptsNewProtocol(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "noop"}, + func(ctx context.Context, req *CallToolRequest, args struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil, nil + }) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + body := newProtocolBody(t, "noop", struct{}{}) + req, err := http.NewRequest(http.MethodPost, httpServer.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set(protocolVersionHeader, protocolVersion20260630) + req.Header.Set(methodHeader, "tools/call") + req.Header.Set(nameHeader, "noop") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want 200; body = %s", resp.StatusCode, respBody) + } +} From 2d47cc96646020b446391d0b039e1ea5ec06414f Mon Sep 17 00:00:00 2001 From: Max Gerber <89937743+max-stytch@users.noreply.github.com> Date: Wed, 27 May 2026 02:39:08 -0700 Subject: [PATCH 05/10] auth: issuer mix-up mitigation (#859) This PR functions as a Reference implementation of [SEP-2468](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2468) / [RFC9207](https://datatracker.ietf.org/doc/rfc9207/). This PR hardens the MCP OAuth Client functionality against Mix-Up attacks: > Mix-up attacks aim to steal an authorization code or access token by > tricking the client into sending the authorization code or access > token to the attacker instead of the honest authorization or resource > server This PR hardens the client by adding support for a new `iss` parameter in authorization responses: - Authorization Servers broadcast support for the `iss` parameter via the `authorization_response_iss_parameter_supported` metadata parameter - If the parameter is supported, clients expect to receive the `iss` parameter in the authorization response - Clients compare the `iss` parameter in the authorization response to the `Issuer` parameter in the authorization metadata. The two must match exactly for the response to be processed. Fixes #941 --- auth/authorization_code.go | 30 ++++++++++ auth/authorization_code_test.go | 55 +++++++++++++++++++ conformance/everything-client/main.go | 6 ++ docs/protocol.md | 23 +++++++- examples/auth/client/main.go | 1 + examples/server/auth-middleware/go.mod | 6 +- examples/server/auth-middleware/go.sum | 2 + examples/server/rate-limiting/go.mod | 2 +- internal/docs/protocol.src.md | 22 +++++++- .../oauthtest/fake_authorization_server.go | 6 +- oauthex/auth_meta.go | 7 +++ 11 files changed, 151 insertions(+), 9 deletions(-) diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 47541f99..144f0ad5 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -50,6 +50,12 @@ type AuthorizationResult struct { Code string // State string returned by the authorization server. State string + // Iss is the issuer identifier returned by the authorization server in the + // authorization response per [RFC 9207]. The AuthorizationCodeFetcher should + // populate this from the "iss" query parameter in the redirect URI if present. + // + // [RFC 9207]: https://www.rfc-editor.org/rfc/rfc9207 + Iss string } // AuthorizationArgs is the input to [AuthorizationCodeFetcher]. @@ -318,6 +324,9 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ // Purposefully leaving the error unwrappable so it can be handled by the caller. return err } + if err := validateIssuerResponse(authRes.Iss, asm.Issuer, asm.AuthorizationResponseIssParameterSupported); err != nil { + return err + } err = h.exchangeAuthorizationCode(ctx, cfg, authRes, prm.Resource) if err != nil { @@ -560,6 +569,27 @@ func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg }, nil } +// validateIssuerResponse validates the "iss" parameter in an authorization response +// per [RFC 9207]. +// +// [RFC 9207]: https://www.rfc-editor.org/rfc/rfc9207 +func validateIssuerResponse(iss, expectedIssuer string, issParameterSupported bool) error { + if issParameterSupported { + if iss == "" { + return fmt.Errorf("authorization server advertises RFC 9207 iss parameter support but none was received in the authorization response") + } + if iss != expectedIssuer { + return fmt.Errorf("authorization response issuer %q does not match expected issuer %q", iss, expectedIssuer) + } + } else { + if iss != "" { + return fmt.Errorf("authorization server does not advertise RFC 9207 iss parameter support but iss was received in the authorization response") + } + } + + return nil +} + // exchangeAuthorizationCode exchanges the authorization code for a token // and stores it in a token source. func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index 12520054..d4de06ce 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -78,6 +78,7 @@ func TestAuthorize(t *testing.T) { return &AuthorizationResult{ Code: location.Query().Get("code"), State: location.Query().Get("state"), + Iss: location.Query().Get("iss"), }, nil }, }) @@ -176,6 +177,7 @@ func TestAuthorize_ScopeAccumulation(t *testing.T) { return &AuthorizationResult{ Code: loc.Query().Get("code"), State: loc.Query().Get("state"), + Iss: loc.Query().Get("iss"), }, nil }, }) @@ -736,6 +738,59 @@ func TestDynamicRegistration(t *testing.T) { } } +func TestValidateIssuerResponse(t *testing.T) { + const expectedIssuer = "https://auth.example.com" + + tests := []struct { + name string + iss string + issSupported bool + wantErr bool + wantErrContains string + }{ + { + name: "ValidIss", + iss: expectedIssuer, + issSupported: true, + }, + { + name: "WrongIss", + iss: "https://attacker.example.com", + issSupported: true, + wantErr: true, + wantErrContains: "does not match expected issuer", + }, + { + name: "MissingIssWhenRequired", + iss: "", + issSupported: true, + wantErr: true, + wantErrContains: "RFC 9207", + }, + { + name: "MissingIssWhenNotRequired", + iss: "", + issSupported: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateIssuerResponse(tt.iss, expectedIssuer, tt.issSupported) + if tt.wantErr { + if err == nil { + t.Fatalf("validateIssuerResponse() = nil, want error containing %q", tt.wantErrContains) + } + if !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("validateIssuerResponse() error = %q, want it to contain %q", err.Error(), tt.wantErrContains) + } + } else if err != nil { + t.Fatalf("validateIssuerResponse() unexpected error = %v", err) + } + }) + } +} + func TestInferApplicationType(t *testing.T) { tests := []struct { name string diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index 947b095e..55a55ade 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -66,6 +66,11 @@ func init() { "auth/token-endpoint-auth-basic", "auth/token-endpoint-auth-post", "auth/token-endpoint-auth-none", + "auth/iss-supported", + "auth/iss-not-advertised", + "auth/iss-supported-missing", + "auth/iss-wrong-issuer", + "auth/iss-unexpected", } for _, scenario := range authScenarios { registerScenario(scenario, runAuthClient) @@ -232,6 +237,7 @@ func fetchAuthorizationCodeAndState(ctx context.Context, args *auth.Authorizatio return &auth.AuthorizationResult{ Code: locURL.Query().Get("code"), State: locURL.Query().Get("state"), + Iss: locURL.Query().Get("iss"), }, nil } diff --git a/docs/protocol.md b/docs/protocol.md index cfcdb855..f45d1dfc 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -15,6 +15,7 @@ 1. [Token Passthrough](#token-passthrough) 1. [Server-Side Request Forgery](#server-side-request-forgery) 1. [Session Hijacking](#session-hijacking) + 1. [Issuer Mix-Up](#issuer-mix-up) 1. [Utilities](#utilities) 1. [Cancellation](#cancellation) 1. [Ping](#ping) @@ -322,6 +323,7 @@ This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) - [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) - [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) +- [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207) Authorization Server Issuer Identification To use it, configure the handler and assign it to the transport: @@ -333,11 +335,12 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - // Open the args.URL in a browser and return the resulting code and state. + // Open the args.URL in a browser and return the resulting code, state, and iss. // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{Code: code, State: state}, nil + iss := ... // "iss" query parameter from the redirect URI (RFC 9207) + return &auth.AuthorizationResult{Code: code, State: state, Iss: iss}, nil }, }) @@ -490,6 +493,22 @@ sets `UserID` on the returned `TokenInfo`, the streamable transport will: `TokenInfo.UserID` to enable this protection. This prevents an attacker with a valid token from hijacking another user's session by guessing or obtaining their session ID. +### Issuer Mix-Up + +The [mitigation](https://www.rfc-editor.org/rfc/rfc9207) against issuer mix-up attacks is +implemented per [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207). The SDK client validates +the `iss` parameter in authorization responses to ensure they originated from the expected +authorization server: + +- If `iss` is present in the redirect URI, the SDK verifies it matches the issuer from the + authorization server's metadata. A mismatch results in an error. +- If `iss` is absent but the authorization server advertises + `authorization_response_iss_parameter_supported: true` in its [RFC 8414](https://www.rfc-editor.org/rfc/rfc8414) + metadata, the SDK rejects the response with an error. + +The `AuthorizationCodeFetcher` is responsible for extracting the `iss` query parameter from +the redirect URI and returning it in [`AuthorizationResult.Iss`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationResult). + ## Utilities ### Cancellation diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go index f514ebae..7e5d4645 100644 --- a/examples/auth/client/main.go +++ b/examples/auth/client/main.go @@ -36,6 +36,7 @@ func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { r.authChan <- &auth.AuthorizationResult{ Code: req.URL.Query().Get("code"), State: req.URL.Query().Get("state"), + Iss: req.URL.Query().Get("iss"), } fmt.Fprint(w, "Authentication successful. You can close this window.") }) diff --git a/examples/server/auth-middleware/go.mod b/examples/server/auth-middleware/go.mod index 8690256a..f1bfc378 100644 --- a/examples/server/auth-middleware/go.mod +++ b/examples/server/auth-middleware/go.mod @@ -1,16 +1,16 @@ module auth-middleware-example -go 1.23.0 +go 1.25.0 require ( - github.com/golang-jwt/jwt/v5 v5.2.2 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/modelcontextprotocol/go-sdk v0.3.0 ) require ( github.com/google/jsonschema-go v0.4.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect - golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect ) replace github.com/modelcontextprotocol/go-sdk => ../../../ diff --git a/examples/server/auth-middleware/go.sum b/examples/server/auth-middleware/go.sum index d257e104..a14b1344 100644 --- a/examples/server/auth-middleware/go.sum +++ b/examples/server/auth-middleware/go.sum @@ -1,5 +1,6 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= @@ -9,5 +10,6 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/examples/server/rate-limiting/go.mod b/examples/server/rate-limiting/go.mod index 61c8788c..dc375543 100644 --- a/examples/server/rate-limiting/go.mod +++ b/examples/server/rate-limiting/go.mod @@ -1,6 +1,6 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting -go 1.23.0 +go 1.25.0 require ( github.com/modelcontextprotocol/go-sdk v0.3.0 diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 593742c7..4044a6fd 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -247,6 +247,7 @@ This handler supports: - [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) - [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) - [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) +- [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207) Authorization Server Issuer Identification To use it, configure the handler and assign it to the transport: @@ -258,11 +259,12 @@ authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandle // PreregisteredClientConfig: ... // DynamicClientRegistrationConfig: ... AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { - // Open the args.URL in a browser and return the resulting code and state. + // Open the args.URL in a browser and return the resulting code, state, and iss. // See full example in examples/auth/client/main.go. code := ... state := ... - return &auth.AuthorizationResult{Code: code, State: state}, nil + iss := ... // "iss" query parameter from the redirect URI (RFC 9207) + return &auth.AuthorizationResult{Code: code, State: state, Iss: iss}, nil }, }) @@ -415,6 +417,22 @@ sets `UserID` on the returned `TokenInfo`, the streamable transport will: `TokenInfo.UserID` to enable this protection. This prevents an attacker with a valid token from hijacking another user's session by guessing or obtaining their session ID. +### Issuer Mix-Up + +The [mitigation](https://www.rfc-editor.org/rfc/rfc9207) against issuer mix-up attacks is +implemented per [RFC 9207](https://www.rfc-editor.org/rfc/rfc9207). The SDK client validates +the `iss` parameter in authorization responses to ensure they originated from the expected +authorization server: + +- If `iss` is present in the redirect URI, the SDK verifies it matches the issuer from the + authorization server's metadata. A mismatch results in an error. +- If `iss` is absent but the authorization server advertises + `authorization_response_iss_parameter_supported: true` in its [RFC 8414](https://www.rfc-editor.org/rfc/rfc8414) + metadata, the SDK rejects the response with an error. + +The `AuthorizationCodeFetcher` is responsible for extracting the `iss` query parameter from +the redirect URI and returning it in [`AuthorizationResult.Iss`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationResult). + ## Utilities ### Cancellation diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go index b27e00ed..c180f862 100644 --- a/internal/oauthtest/fake_authorization_server.go +++ b/internal/oauthtest/fake_authorization_server.go @@ -15,6 +15,7 @@ import ( "maps" "net/http" "net/http/httptest" + "net/url" "slices" "testing" @@ -178,6 +179,8 @@ func (s *FakeAuthorizationServer) handleMetadata(w http.ResponseWriter, r *http. CodeChallengeMethodsSupported: []string{"S256"}, ClientIDMetadataDocumentSupported: cimdSupported, TokenEndpointAuthMethodsSupported: []string{"client_secret_post", "client_secret_basic"}, + // Advertise RFC 9207 support: the authorize endpoint includes "iss" in responses. + AuthorizationResponseIssParameterSupported: true, } // Set CORS headers for cross-origin client discovery. w.Header().Set("Access-Control-Allow-Origin", "*") @@ -267,8 +270,9 @@ func (s *FakeAuthorizationServer) handleAuthorize(w http.ResponseWriter, r *http } state := r.URL.Query().Get("state") + issuer := s.URL() + s.config.IssuerPath - redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, state) + redirectURL := fmt.Sprintf("%s?code=%s&state=%s&iss=%s", redirectURI, code, state, url.QueryEscape(issuer)) http.Redirect(w, r, redirectURL, http.StatusFound) } diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 255aca92..7ebf6e22 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -114,6 +114,13 @@ type AuthServerMeta struct { // ClientIDMetadataDocumentSupported is a boolean indicating whether the authorization server // supports client ID metadata documents. ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` + + // AuthorizationResponseIssParameterSupported indicates whether the authorization server + // provides the "iss" parameter in authorization responses per [RFC 9207]. + // When true, clients must verify the "iss" parameter is present and matches the Issuer field. + // + // [RFC 9207]: https://www.rfc-editor.org/rfc/rfc9207 + AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"` } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata From 189a85ad92fff4a5c7fef29532d30fcaa2cf10d3 Mon Sep 17 00:00:00 2001 From: Yaroslav Date: Fri, 29 May 2026 09:29:34 +0200 Subject: [PATCH 06/10] feat: multi-round-trip request implementation (SEP-2322) (#950) ### Context Design (`design/mrtr.md`) and implementation proposal for Multi Round-Trip Requests (MRTR) per [SEP-2322](https://github.com/CaitieM20/modelcontextprotocol/blob/de6d76fba3078eda957dadb3cec51ca8ab851b5c/seps/2322-MRTR.md). ### Changes - Introduce `InputRequest` and `InputResponse` sealed interfaces with corresponding `InputRequestMap` and `InputResponseMap` types for custom map JSON codec implementation. - Extend `CallToolParams`, `GetPromptParams`, and `ReadResourceParams` with `InputResponses`, `RequestState` and an unexported `resultType` ("complete" or "input_required") for retry round-trips. The type is set by the SDK based on input requests presence. - Add client-side MRTR middleware enabled by default that automatically fulfills input requests and retries. - Added MRTROptions on ClientOptions to allow disabling the middleware and configuring the number of retries. - Add server-side MRTR middleware for backward compatibility: transparently bridges input requests to direct Elicit/CreateMessage/ListRoots calls for older clients --- design/mrtr.md | 264 +++++++++++++++++++ go.mod | 1 + go.sum | 2 + mcp/client.go | 10 +- mcp/content_nil_test.go | 3 +- mcp/mcp_test.go | 5 +- mcp/mrtr.go | 268 ++++++++++++++++++++ mcp/mrtr_test.go | 546 ++++++++++++++++++++++++++++++++++++++++ mcp/protocol.go | 397 ++++++++++++++++++++++++++++- mcp/protocol_test.go | 65 +++++ mcp/server.go | 37 ++- mcp/streamable_test.go | 12 +- 12 files changed, 1587 insertions(+), 23 deletions(-) create mode 100644 design/mrtr.md create mode 100644 mcp/mrtr.go create mode 100644 mcp/mrtr_test.go diff --git a/design/mrtr.md b/design/mrtr.md new file mode 100644 index 00000000..16321a46 --- /dev/null +++ b/design/mrtr.md @@ -0,0 +1,264 @@ +## Context + +A proposal for implementing Multi Round-Trip Requests +(MRTR) as defined in [SEP-2322](https://github.com/CaitieM20/modelcontextprotocol/blob/de6d76fba3078eda957dadb3cec51ca8ab851b5c/seps/2322-MRTR.md). + +In the new protocol version servers can't initiate requests to clients, but when a server requires additional input for completing `tools/call`, `prompts/get`, or `resources/read` it can return an incomplete result along with a set of `inputRequests`. The client fulfills them locally and retries the same call with `inputResponses` attached. + +## Goals + +**Must have:** + +* Backward compatibility. +* Correct representation on the wire. + +**Nice to have:** + +* Minimal changes to the exported API surface. +* Hard for server implementers to construct an invalid payload. +* Simple input request handling for clients. +* Protocol-version-independent code. +* Consistency with the rest of the SDK. + +## Proposal + +`ServerSession` methods return an error for new-version protocol connections. + +`InputRequest`/`InputResponse` is introduced as a sealed-interface: +```go +// Implemented by *ElicitParams, *CreateMessageParams, *ListRootsParams +type InputRequest interface{ isInputRequest() } + +type InputRequestMap map[string]InputRequest +// MarshalJSON encodes as map[string]struct{ Method string; Params InputRequest } +func (m InputRequestMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Params InputRequest } +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { ... } + +// Implemented by *ElicitResult, *CreateMessageResult, *ListRootsResult. +type InputResponse interface{ isInputResult() } + +type InputResponseMap map[string]InputResponse +// MarshalJSON encodes as map[string]struct{ Method string; Result InputResponse } +func (m InputResponseMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Result InputResponse } +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { ... } +``` + +All affected methods' `*Params` are extended with `InputResponseMap` and `RequestState` fields: +```go +type CallToolParams struct { + ... + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + RequestState string `json:"requestState,omitempty"` +} +// Same for GetPromptParams, ReadResourceParams +``` + +`InputRequests` and `RequestState` fields are added directly to `CallToolResult`, `GetPromptResult`, and `ReadResourceResult` as exported. +Result type discriminator (completed, input_required) is unexported so that SDK users don't need to set it to the correct constant in addition to setting either `Content` or `InputRequests`. Handler execution result is validated and augmented before marshaling: +```go +type CallToolResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` + resultType string // set by the SDK and used in MarshalJSON() +} +// Same for GetPromptResult, ReadResourceResult. +``` +Alternatively, the field could only exist on `wire struct`, but this would make us return `complete` to older clients or empty string to newer clients, because there's no access to negotiated protocol version in `MarshalJSON`. + +Servers request additional input by constructing a correct struct literal: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return &mcp.CallToolResult{ + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` +The SDK validates at runtime that a handler does not return both content and `InputRequests` — doing so logs a warning and returns a `CodeInternalError` JSON-RPC error. + +An unexported receiving middleware is installed on the server for backward compatibility with older clients. When a handler returns `InputRequests` and the connected client uses a protocol version that does not support MRTR, the middleware fulfills the requests by calling `ServerSession.Elicit`/`CreateMessage`/`ListRoots` on the client directly and reinvokes the handler once with the collected `InputResponses`. If any of these calls fail, the entire request fails. Input requests are fulfilled concurrently. This lets server developers write protocol-version-independent code. + +An unexported sending middleware is installed on the client, which similarly to `urlElicitationMiddleware` will automatically invoke handlers for the corresponding methods on incomplete results and retry the original request. Clients have an option to disable it and write a retry loop manually using `NeedsInput()`: +```go +type MultiRoundTripOptions struct { + Disabled bool +} + +client := mcp.NewClient(impl, &mcp.ClientOptions{MultiRoundTrip: &mcp.MultiRoundTripOptions{Disabled: true}}) +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if result.NeedsInput() { ... } +``` + +`NeedsInput()` checks the unexported `resultType` field rather than `InputRequests`, correctly handling the load-shedding case where the server returns `input_required` with an empty map. + +**Pros** + +This is arguably the simplest and the most transparent approach which is also closest to the spec. +What gets explicitly set on the server can be observed on the wire and on the client. +The opt-out client middleware follows the principle of the least surprise for app developers. If client method handlers were provided they will continue to be invoked regardless of the protocol version in use. The `Disabled` option lets "power-users" build any custom handling logic. +The server middleware makes handler code protocol-version-independent — the same handler works for both old and new clients. + +**Cons** + +The biggest downside of the proposal is that server developers can construct incorrect responses (both content and input requests) and this will only be validated at runtime. + +## Alternatives considered + +### Unexported fields + +MRTR fields can be unexported, accessible only through getters, constructible only through constructor functions, and handled explicitly in custom `(Unm|M)arshalJSON`. This will make it impossible for developers to construct incorrect responses and for clients to perform an erroneous `len(result.InputRequests) > 0` check in the load-shedding case. +```go +type CallToolResult struct { + ... + inputRequests InputRequestMap + requestState string + resultType string +} + +func (r *CallToolResult) InputRequests() (InputRequestMap, bool) { ... } + +// InputRequiredResult struct exists for backward-compatibility in case of new fields being needed for input request results. +type InputRequiredResult struct { + InputRequests InputRequestMap + RequestState string +} + +// RequireInput constructs a tool call, prompt or resource result with input requests set. +// mrtrResult provides methods for setting private fields on these types. +func RequireInput[T any, TP interface { *T; mrtrResult }](r InputRequiredResult) TP { ... } +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return mcp.RequireInput[mcp.CallToolResult](mcp.InputRequiredResult{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }), nil, nil + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if requests, ok := result.InputRequests(); ok { ... } +``` + +The biggest downside of this approach is the obscure data model with hidden fields. An incomplete `mcp.CallToolResult` looks like an uninitialized struct until `InputRequests` method result is examined. +In addition to this, the verbose `RequireInput` syntax (no auto type inference from assignment target) does not look idiomatic and fits poorly into the existing SDK APIs. + +--- + +### `InputRequiredError` type + +We could explore a different data channel - `error` return value. This would give us the natural "happy path is when all inputs are provided" flow on the server side, and good result interpretability on the client side (impossible to confuse with a successful response). +The new error could be converted to the correct wire representation at the marshaling stage. +```go +type InputRequiredError struct { + InputRequests InputRequestMap + RequestState string +} + +func (e *InputRequiredError) Error() string { + return fmt.Sprintf("input required: %d request(s)", len(e.InputRequests)) +} +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return nil, zero, &mcp.InputRequiredError{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + } + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +var inputReqErr *mcp.InputRequiredError +if errors.As(err, &inputReqErr) { ... } +``` + +The downsides of this approach are: +* The drift from the protocol, where MRTR is not an error flow. +* Obscure "customError -> non-error protocol type on wirte -> customError" data lifecycle. +* Things get confusing for error-processing middleware. + +--- + +### New functions + +We could introduce new functions with a different handler signature where the return type is a sealed interface. This would give us compiler-enforced correctness for values constructed by tool handlers and clients would be forced to unpack `mcp.RoundTripCallToolResult` and make a concious decision for how to handle it. +```go +type RoundTripToolHandler func(context.Context, *CallToolRequest) (RoundTripCallToolResult, error) +type RoundTripToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (RoundTripCallToolResult, Out, error) + +// RoundTripCallToolResult is implemented by CallToolResult and IncompleteResult +type RoundTripCallToolResult interface { isMRTRResult() } + +type IncompleteResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` +} + +func (s *Server) AddRoundTripTool(t *Tool, h RoundTripToolHandler) +func AddRoundTripTool[In, Out any](s *Server, t *Tool, h RoundTripToolHandlerFor[In, Out]) +``` + +`Server.AddTool` wraps the old `ToolHandler` into a `RoundTripToolHandler` to update its function signature: +```go +mcp.AddRoundTripTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (mcp.RoundTripCallToolResult, MyOut, error) { + if needsInput(in) { + return &mcp.IncompleteResult{ + ResultType: mcp.ResultTypeInputRequired, + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +The downsides of this approach are: +* SEP suggests `ResultType` will potentially be extended with new values, `RoundTrip` in new function names will not allow us to cleanly extend the sealed interface with new types. But an overly generic name for new functions will make the API use-case less clear. +* Different code +* SDK takes the same action (puts it on the wire) regardless of the returned type, it exists only for enforcing correctness of the user code. +* Exported API surface bloat: +7 exported functions. + +--- + +### Exported Middleware + +We could flip "unexported MRTR middleware with opt-out option" to "exported middleware with opt-in requirement". +```go +func AutoMRTR(opts *MultiRoundTripOptions) Middleware { ... } +type MultiRoundTripOptions struct { + MaxRetries int +} +client := mcp.NewClient(impl, nil) +client.AddSendingMiddleware(mcp.AutoMRTR(&mcp.MultiRoundTripOptions{ + MaxRetries: 5, +})) +``` +This would change semantics of `*Handler` fields - depending on the protocol version in use, an extra initialization step will be required for them to "take effect". + +--- + +### Server API protocol version bridging + +Converting `ServerSession.Elicit`/`CreateMessage`/`ListRoots` calls into MRTR wire format transparently (suspend the handler, return `input_required`, resume on retry). Rejected because of a significant implementation effort and the fact that it contradicts the design goal of MRTR where servers shouldn't hold resources between round trips, and it should be possible for a retry to arrive on any server instance in a multi-server deployment. + diff --git a/go.mod b/go.mod index 860d6fa0..3287a957 100644 --- a/go.mod +++ b/go.mod @@ -15,5 +15,6 @@ require ( require ( github.com/segmentio/asm v1.1.3 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect ) diff --git a/go.sum b/go.sum index 377a7b11..c13454aa 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..9f6f2955 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -58,13 +58,17 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { opts.Logger = ensureLogger(nil) } - return &Client{ + c := &Client{ impl: impl, opts: opts, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } + if opts.MultiRoundTrip == nil || !opts.MultiRoundTrip.Disabled { + c.AddSendingMiddleware(clientMultiRoundTripMiddleware()) + } + return c } // ClientOptions configures the behavior of the client. @@ -154,6 +158,10 @@ type ClientOptions struct { ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) LoggingMessageHandler func(context.Context, *LoggingMessageRequest) ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) + // MultiRoundTrip configures the automatic MultiRoundTrip (Multi Round-Trip Requests) middleware. + // By default (nil), the middleware is enabled with default settings. + // Set Disabled to true to opt out of automatic MultiRoundTrip handling. + MultiRoundTrip *MultiRoundTripOptions // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index ba263e6f..86bb86ac 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -223,4 +224,4 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(mcp.CallToolResult{}, mcp.GetPromptResult{}, mcp.ReadResourceResult{})} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..bad35086 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -206,7 +207,7 @@ func TestEndToEnd(t *testing.T) { Role: "user", }}, } - if diff := cmp.Diff(wantReview, gotReview); diff != "" { + if diff := cmp.Diff(wantReview, gotReview, ctrCmpOpts...); diff != "" { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } @@ -2371,4 +2372,4 @@ func TestSetErrorPreservesContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} diff --git a/mcp/mrtr.go b/mcp/mrtr.go new file mode 100644 index 00000000..8ad5f1ae --- /dev/null +++ b/mcp/mrtr.go @@ -0,0 +1,268 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/sync/errgroup" +) + +const maxMultiRoundTripRetries = 10 +const maxLoadSheddingMultiRoundTripRetries = 3 + +// MultiRoundTripOptions configures the client-side multi round-trip request (SEP-2322) +// middleware. The middleware is enabled by default and automatically fulfills input +// requests from the server by invoking the appropriate client handlers and +// retrying the original call. +type MultiRoundTripOptions struct { + // Disabled prevents the automatic multi-round-tirp middleware from being installed. + // When true, the client returns input-required results directly and callers must + // handle the retry loop themselves using [CallToolResult.NeedsInput], + // [GetPromptResult.NeedsInput], or [ReadResourceResult.NeedsInput]. + Disabled bool +} + +type multiRoundTripResponse interface { + setResultType(resultType) + inputRequests() map[string]InputRequest + requestState() string + hasContent() bool +} + +func handleMultiRoundTripResult(ss *ServerSession, logger *slog.Logger, res multiRoundTripResponse) error { + if res == nil { + return nil + } + hasInputRequests := res.inputRequests() != nil + + if hasInputRequests && res.hasContent() { + logger.Warn("handler returned both content and inputRequests") + return &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "server bug: result has both content and inputRequests", + } + } + + if clientSupportsMultiRoundTrip(ss) { + // For older clients the resultType is left unset. Input requests will be handled + // by serverMultiRoundTripMiddleware client calls and handler reinvocation. + if hasInputRequests { + res.setResultType(resultTypeInputRequired) + } else { + res.setResultType(resultTypeComplete) + } + } + return nil +} + +func clientSupportsMultiRoundTrip(ss *ServerSession) bool { + protocolVersion := latestProtocolVersion + if iparams := ss.InitializeParams(); iparams != nil { + protocolVersion = iparams.ProtocolVersion + } + return protocolVersion >= protocolVersion20260630 +} + +func clientMultiRoundTripMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + + loadSheddingFailures := 0 + for retries := 1; ; retries++ { + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + mrtrResult, ok := res.(multiRoundTripResponse) + if !ok { + return res, nil + } + reqMap := mrtrResult.inputRequests() + if reqMap == nil { + return res, nil + } + if len(reqMap) == 0 { + loadSheddingFailures++ + } + if loadSheddingFailures >= maxLoadSheddingMultiRoundTripRetries { + return nil, fmt.Errorf("multi-round-trip: exceeded maximum load-shedding retries (%d)", maxLoadSheddingMultiRoundTripRetries) + } + if retries >= maxMultiRoundTripRetries { + return nil, fmt.Errorf("multi-round-trip: exceeded maximum retries (%d)", maxMultiRoundTripRetries) + } + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, nil + } + responses, err := fulfillInputRequests(ctx, cs, reqMap) + if err != nil { + return nil, err + } + setMultiRoundTripRetryParams(req, responses, mrtrResult.requestState()) + } + } + } +} + +// serverMultiRoundTripMiddleware is a receiving middleware for servers that transparently +// handles multi-round-trip for clients on older protocol versions. When a handler returns +// InputRequests and the client does not support multi-round-trip, the middleware fulfills +// the requests by calling the client directly and reinvokes the handler once with the responses. +func serverMultiRoundTripMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + + ss, ok := req.GetSession().(*ServerSession) + if !ok { + return next(ctx, method, req) + } + if clientSupportsMultiRoundTrip(ss) { + return next(ctx, method, req) + } + + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + mrtrResult, ok := res.(multiRoundTripResponse) + if !ok { + return res, nil + } + reqMap := mrtrResult.inputRequests() + if reqMap == nil { + return res, nil + } + if len(reqMap) == 0 { + return nil, fmt.Errorf("the server is busy, retry later") + } + responses, err := fulfillServerInputRequests(ctx, ss, reqMap) + if err != nil { + return nil, err + } + setMultiRoundTripRetryParams(req, responses, mrtrResult.requestState()) + return next(ctx, method, req) + } + } +} + +func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests InputRequestMap) (InputResponseMap, error) { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) + for id, ir := range requests { + g.Go(func() error { + resp, err := fulfillServerInputRequest(ctx, ss, ir) + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("multi-round-trip: %w", err) + } + return responses, nil +} + +func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return ss.Elicit(ctx, p) + case *CreateMessageParams: + return ss.CreateMessageWithTools(ctx, createMessageParamsToWithTools(p)) + case *CreateMessageWithToolsParams: + return ss.CreateMessageWithTools(ctx, p) + case *ListRootsParams: + return ss.ListRoots(ctx, p) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} + +func createMessageParamsToWithTools(p *CreateMessageParams) *CreateMessageWithToolsParams { + var msgs []*SamplingMessageV2 + for _, m := range p.Messages { + msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) + } + return &CreateMessageWithToolsParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } +} + +func setMultiRoundTripRetryParams(req Request, responses InputResponseMap, state string) { + switch p := req.GetParams().(type) { + case *CallToolParams: + p.InputResponses = responses + p.RequestState = state + case *CallToolParamsRaw: + p.InputResponses = responses + p.RequestState = state + case *GetPromptParams: + p.InputResponses = responses + p.RequestState = state + case *ReadResourceParams: + p.InputResponses = responses + p.RequestState = state + } +} + +func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) + for id, ir := range requests { + g.Go(func() error { + resp, err := fulfillInputRequest(ctx, cs, ir) + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("multi round-trip: %w", err) + } + return responses, nil +} + +func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return cs.client.elicit(ctx, newClientRequest(cs, p)) + case *CreateMessageParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: createMessageParamsToWithTools(p)}) + case *CreateMessageWithToolsParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: p}) + case *ListRootsParams: + return cs.client.listRoots(ctx, newClientRequest(cs, p)) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go new file mode 100644 index 00000000..12eba7ea --- /dev/null +++ b/mcp/mrtr_test.go @@ -0,0 +1,546 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "slices" + "sync/atomic" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/jsonschema-go/jsonschema" +) + +func TestMultiRoundTrip_ManualRetry(t *testing.T) { + type deployResult struct { + Deployed bool `json:"deployed"` + Reason string `json:"reason,omitempty"` + } + + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "deploy"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, *deployResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }, nil, nil + } + + resp, ok := req.Params.InputResponses["confirm"] + if !ok { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Please confirm (retry)"}}, + }, nil, nil + } + + if req.Params.RequestState == "" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "no_state"}, nil + } + if elicitResult := resp.(*ElicitResult); elicitResult != nil && elicitResult.Action != "accept" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "cancelled"}, nil + } + + return &CallToolResult{}, &deployResult{Deployed: true}, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + }) + + // Round 1: initiate deployment + res, err := conn.CallTool(ctx, &CallToolParams{Name: "deploy"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if got := len(res.InputRequests); got != 1 { + t.Fatalf("len(res.InputRequests) = %d, want 1", got) + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("res.InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + // Round 2: retry with confirmation + res, err = conn.CallTool(ctx, &CallToolParams{ + Name: "deploy", + InputResponses: InputResponseMap{ + "confirm": &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, + }, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("CallTool() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + + if diff := cmp.Diff(map[string]any{"deployed": true}, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } +} + +func TestMultiRoundTrip_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + tests := []struct { + name string + inputRequests InputRequestMap + wantResult map[string]any + }{ + { + name: "elicit", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Deploy?"}, + }, + wantResult: map[string]any{"ids": []any{"confirm"}}, + }, + { + name: "createMessage", + inputRequests: InputRequestMap{ + "summarize": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "summarize"}}}, + MaxTokens: 100, + }, + }, + wantResult: map[string]any{"ids": []any{"summarize"}}, + }, + { + name: "listRoots", + inputRequests: InputRequestMap{ + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"roots"}}, + }, + { + name: "all three", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "draft": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "write"}}}, + MaxTokens: 50, + }, + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"confirm", "draft", "roots"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + inputRequests := tt.inputRequests + AddTool(srv, &Tool{Name: "act"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil, nil + } + // Collect the IDs of fulfilled responses. + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + return &CallToolResult{}, map[string]any{"ids": ids}, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "response"}, + }, nil + }, + }) + conn.client.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + + res, err := conn.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + + // Sort the expected IDs for stable comparison. + if wantIDs, ok := tt.wantResult["ids"].([]any); ok { + slices.SortFunc(wantIDs, func(a, b any) int { + if a.(string) < b.(string) { + return -1 + } + return 1 + }) + } + + if diff := cmp.Diff(tt.wantResult, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestMultiRoundTrip_MaxRetries(t *testing.T) { + testCases := []struct { + name string + requests InputRequestMap + wantRetries int + }{ + { + name: "load shedding", + requests: InputRequestMap{}, + wantRetries: maxLoadSheddingMultiRoundTripRetries, + }, + { + name: "input request", + requests: InputRequestMap{"confirm": &ElicitParams{Message: "Again?"}}, + wantRetries: maxMultiRoundTripRetries, + }, + } + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + var serverCalls atomic.Int32 + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "loop"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + serverCalls.Add(1) + return &CallToolResult{InputRequests: tc.requests, RequestState: "loop-state"}, nil, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + _, err := conn.CallTool(ctx, &CallToolParams{Name: "loop"}) + if err == nil { + t.Fatal("CallTool() err = nil, want error for exceeded max retries") + } + if serverCalls.Load() != int32(tc.wantRetries) { + t.Errorf("serverCalls = %d, want %d", serverCalls.Load(), tc.wantRetries) + } + }) + } +} + +func TestMultiRoundTrip_ServerMiddleware(t *testing.T) { + // multiRoundTripToolHandler returns a ToolHandler (plain, non-generic) that requests + // the given inputRequests on the first call and returns the fulfilled + // response IDs on the second. + multiRoundTripToolHandler := func(inputRequests InputRequestMap) ToolHandler { + return func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil + } + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + content := &TextContent{Text: fmt.Sprintf("%v", ids)} + return &CallToolResult{Content: []Content{content}}, nil + } + } + + tests := []struct { + name string + inputRequests InputRequestMap + wantText string + }{ + { + name: "elicit via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Sure?"}, + }, + wantText: "[confirm]", + }, + { + name: "elicit and listRoots via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "roots": &ListRootsParams{}, + }, + wantText: "[confirm roots]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddTool( + &Tool{Name: "act", InputSchema: &jsonschema.Schema{Type: "object"}}, + multiRoundTripToolHandler(tt.inputRequests), + ) + + // Connect with an OLD protocol version where multi-round-trip is not supported. + // The server middleware should handle it transparently. + st, ct := NewInMemoryTransports() + ss, err := srv.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + c.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if got := res.Content[0].(*TextContent).Text; got != tt.wantText { + t.Errorf("result text = %q, want %q", got, tt.wantText) + } + }) + } +} + +func TestMultiRoundTrip_GetPrompt_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } + if got := res.Messages[0].Content.(*TextContent).Text; got != "review this code" { + t.Errorf("message text = %q, want %q", got, "review this code") + } +} + +func TestMultiRoundTrip_GetPrompt_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + res, err = conn.GetPrompt(ctx, &GetPromptParams{ + Name: "review", + InputResponses: InputResponseMap{"confirm": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("GetPrompt() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } +} + +func TestMultiRoundTrip_ReadResource_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } + if got := res.Contents[0].Text; got != "resource data" { + t.Errorf("resource text = %q, want %q", got, "resource data") + } +} + +func TestMultiRoundTrip_ReadResource_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnect(t, srv, &ClientOptions{ + MultiRoundTrip: &MultiRoundTripOptions{Disabled: true}, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["auth"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[auth] type = %T, want *ElicitParams", res.InputRequests["auth"]) + } + + res, err = conn.ReadResource(ctx, &ReadResourceParams{ + URI: "test://data", + InputResponses: InputResponseMap{"auth": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("ReadResource() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } +} + +func mustConnect(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { + t.Helper() + st, ct := NewInMemoryTransports() + ss, err := s.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = ss.Close() + }) + + c := NewClient(testImpl, clientOpts) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = cs.Close() + }) + return cs +} diff --git a/mcp/protocol.go b/mcp/protocol.go index fcba83f1..bebdc196 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -13,6 +13,178 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" ) +// resultType indicates whether a result is complete or requires further input +// from the client via the multi round-trip request protocol. +type resultType string + +const ( + // resultTypeComplete indicates the result is final. + // This is the default when ResultType is empty. + resultTypeComplete resultType = "complete" + + // resultTypeInputRequired indicates the server needs additional client + // input before it can complete the request. The client should fulfill the + // InputRequests and retry the call with the responses. + resultTypeInputRequired resultType = "input_required" +) + +// InputRequest is a type for parameters that a server can include in the response +// to request input from client (SEP-2322). Implementations are [*ElicitParams], +// [*CreateMessageParams], and [*ListRootsParams]. +type InputRequest interface{ isInputRequest() } + +// InputRequestMap maps server-assigned request IDs to [InputRequest] values. +// It is used in result types to tell the client what input the server needs. +type InputRequestMap map[string]InputRequest + +func (m InputRequestMap) MarshalJSON() ([]byte, error) { + if m == nil { + return json.Marshal(map[string]any(nil)) + } + type wire struct { + Method string `json:"method"` + Params InputRequest `json:"params,omitempty"` + } + typeToMethod := func(v InputRequest) (string, error) { + switch v.(type) { + case *ElicitParams: + return methodElicit, nil + case *CreateMessageParams, *CreateMessageWithToolsParams: + return methodCreateMessage, nil + case *ListRootsParams: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Params: v} + } + return json.Marshal(converted) +} + +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + if rawMap == nil { + return nil + } + result := make(InputRequestMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageWithToolsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputRequest method: %q", raw.Method) + } + } + *m = result + return nil +} + +// InputResponse is a type for results that a client sends back when fulfilling +// a server input request (SEP-2322). Implementations are [*ElicitResult], +// [*CreateMessageResult], and [*ListRootsResult]. +type InputResponse interface{ isInputResponse() } + +// InputResponseMap maps request IDs (from [InputRequestMap]) to [InputResponse] +// values. It is used in params types when retrying a call after an +// input-required result. +type InputResponseMap map[string]InputResponse + +func (m InputResponseMap) MarshalJSON() ([]byte, error) { + type wire struct { + Method string `json:"method"` + Result InputResponse `json:"result,omitempty"` + } + typeToMethod := func(v InputResponse) (string, error) { + switch v.(type) { + case *ElicitResult: + return methodElicit, nil + case *CreateMessageResult, *CreateMessageWithToolsResult: + return methodCreateMessage, nil + case *ListRootsResult: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Result: v} + } + return json.Marshal(converted) +} + +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Result json.RawMessage `json:"result"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + result := make(InputResponseMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageWithToolsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputResponse method: %q", raw.Method) + } + } + *m = result + return nil +} + // Optional annotations for the client. The client can use annotations to inform // how objects are used or displayed. type Annotations struct { @@ -46,6 +218,14 @@ type CallToolParams struct { // Arguments holds the tool arguments. It can hold any value that can be // marshaled to JSON. Arguments any `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -61,6 +241,14 @@ type CallToolParamsRaw struct { // is the responsibility of the tool handler to unmarshal and validate the // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // A CallToolResult is the server's response to a tool call. @@ -107,6 +295,24 @@ type CallToolResult struct { // the Content field. IsError bool `json:"isError,omitempty"` + // InputRequests is a map of server-assigned IDs to input requests. + // Populated only when ResultType is ResultTypeInputRequired. + // The client must fulfill these and echo the IDs back in InputResponses + // when retrying the call. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + + // RequestState is an opaque string the client must echo back when + // retrying after an input-required result. Servers use this to carry + // context between independent requests. + // + // Unauthenticated servers must encrypt, sign and verify this value. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. Empty or ResultTypeComplete means the call succeeded + // normally. ResultTypeInputRequired means the client should fulfill the + // InputRequests and retry the call. + resultType resultType // The error passed to setError, if any. // It is not marshaled, and therefore it is only visible on the server. // Its only use is in server sending middleware, where it can be accessed @@ -145,13 +351,49 @@ func (r *CallToolResult) GetError() error { func (*CallToolResult) isResult() {} -// UnmarshalJSON handles the unmarshalling of content into the Content -// interface. +func (r *CallToolResult) setResultType(rt resultType) { r.resultType = rt } +func (r *CallToolResult) requestState() string { return r.RequestState } +func (r *CallToolResult) inputRequests() map[string]InputRequest { + if r == nil { + return nil + } + return r.InputRequests +} +func (r *CallToolResult) hasContent() bool { + return len(r.Content) > 0 || r.StructuredContent != nil +} + +// NeedsInput reports whether this result requires further client input. +// This is true when the server returned ResultType "input_required". +// When NeedsInput returns true, check InputRequests for the set of +// requests the server needs fulfilled before retrying the call. +// An empty InputRequests with NeedsInput true indicates load-shedding. +func (r *CallToolResult) NeedsInput() bool { return r.resultType == resultTypeInputRequired } + +func (x *CallToolResult) MarshalJSON() ([]byte, error) { + type res CallToolResult // avoid recursion + type wire struct { + res + ResultType resultType `json:"resultType,omitempty"` + InputRequests json.RawMessage `json:"inputRequests,omitempty"` // shadows res.InputRequests + } + w := wire{res: res(*x), ResultType: x.resultType} + if x.InputRequests != nil { + ir, err := json.Marshal(x.InputRequests) + if err != nil { + return nil, err + } + w.InputRequests = ir + } + return json.Marshal(w) +} + func (x *CallToolResult) UnmarshalJSON(data []byte) error { type res CallToolResult // avoid recursion var wire struct { res - Content []*wireContent `json:"content"` + Content []*wireContent `json:"content"` + ResultType resultType `json:"resultType"` } if err := internaljson.Unmarshal(data, &wire); err != nil { return err @@ -160,6 +402,7 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } + wire.res.resultType = wire.ResultType *x = CallToolResult(wire.res) return nil } @@ -426,6 +669,7 @@ type CreateMessageParams struct { } func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) isInputRequest() {} func (x *CreateMessageParams) isNil() bool { return x == nil } func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -453,6 +697,7 @@ type CreateMessageWithToolsParams struct { } func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) isInputRequest() {} func (x *CreateMessageWithToolsParams) isNil() bool { return x == nil } func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -553,7 +798,8 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } -func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isInputResponse() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -598,7 +844,8 @@ var createMessageWithToolsResultAllow = map[string]bool{ "tool_use": true, } -func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isInputResponse() {} // MarshalJSON marshals the result. When Content has a single element, it is // marshaled as a single object for compatibility with pre-2025-11-25 @@ -657,6 +904,13 @@ type GetPromptParams struct { Arguments map[string]string `json:"arguments,omitempty"` // The name of the prompt or prompt template. Name string `json:"name"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *GetPromptParams) isParams() {} @@ -672,10 +926,67 @@ type GetPromptResult struct { // An optional description for the prompt. Description string `json:"description,omitempty"` Messages []*PromptMessage `json:"messages"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for multi-round-trip retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType resultType } func (*GetPromptResult) isResult() {} +func (r *GetPromptResult) setResultType(rt resultType) { r.resultType = rt } +func (r *GetPromptResult) requestState() string { return r.RequestState } +func (r *GetPromptResult) inputRequests() map[string]InputRequest { + if r == nil { + return nil + } + return r.InputRequests +} +func (r *GetPromptResult) hasContent() bool { return len(r.Messages) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *GetPromptResult) NeedsInput() bool { return r.resultType == resultTypeInputRequired } + +func (x *GetPromptResult) MarshalJSON() ([]byte, error) { + type res GetPromptResult + type wire struct { + res + ResultType resultType `json:"resultType,omitempty"` + InputRequests json.RawMessage `json:"inputRequests,omitempty"` // shadows res.InputRequests + } + w := wire{res: res(*x), ResultType: x.resultType} + if x.InputRequests != nil { + ir, err := json.Marshal(x.InputRequests) + if err != nil { + return nil, err + } + w.InputRequests = ir + } + return json.Marshal(w) +} + +func (x *GetPromptResult) UnmarshalJSON(data []byte) error { + type res GetPromptResult + var wire struct { + res + ResultType resultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = GetPromptResult(wire.res) + return nil +} + // InitializeParams is sent by the client to initialize the session. type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -845,6 +1156,7 @@ type ListRootsParams struct { } func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) isInputRequest() {} func (x *ListRootsParams) isNil() bool { return x == nil } func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -859,7 +1171,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } -func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isInputResponse() {} type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -1104,6 +1417,13 @@ type ReadResourceParams struct { // The URI of the resource to read. The URI can use any protocol; it is up to // the server how to interpret it. URI string `json:"uri"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *ReadResourceParams) isParams() {} @@ -1117,10 +1437,67 @@ type ReadResourceResult struct { // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Contents []*ResourceContents `json:"contents"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for multi-round-trip retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType resultType } func (*ReadResourceResult) isResult() {} +func (r *ReadResourceResult) setResultType(rt resultType) { r.resultType = rt } +func (r *ReadResourceResult) requestState() string { return r.RequestState } +func (r *ReadResourceResult) inputRequests() map[string]InputRequest { + if r == nil { + return nil + } + return r.InputRequests +} +func (r *ReadResourceResult) hasContent() bool { return len(r.Contents) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *ReadResourceResult) NeedsInput() bool { return r.resultType == resultTypeInputRequired } + +func (x *ReadResourceResult) MarshalJSON() ([]byte, error) { + type res ReadResourceResult + type wire struct { + res + ResultType resultType `json:"resultType,omitempty"` + InputRequests json.RawMessage `json:"inputRequests,omitempty"` // shadows res.InputRequests + } + w := wire{res: res(*x), ResultType: x.resultType} + if x.InputRequests != nil { + ir, err := json.Marshal(x.InputRequests) + if err != nil { + return nil, err + } + w.InputRequests = ir + } + return json.Marshal(w) +} + +func (x *ReadResourceResult) UnmarshalJSON(data []byte) error { + type res ReadResourceResult + var wire struct { + res + ResultType resultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = ReadResourceResult(wire.res) + return nil +} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -1520,8 +1897,9 @@ type ElicitParams struct { ElicitationID string `json:"elicitationId,omitempty"` } -func (x *ElicitParams) isParams() {} -func (x *ElicitParams) isNil() bool { return x == nil } +func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isInputRequest() {} +func (x *ElicitParams) isNil() bool { return x == nil } func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1541,7 +1919,8 @@ type ElicitResult struct { Content map[string]any `json:"content,omitempty"` } -func (*ElicitResult) isResult() {} +func (*ElicitResult) isResult() {} +func (*ElicitResult) isInputResponse() {} // ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. type ElicitationCompleteParams struct { diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index 2bab644e..edf3623d 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -1151,6 +1151,71 @@ func TestToWithTools_Conversion(t *testing.T) { } } +func TestInputRequestMapJSON(t *testing.T) { + t.Run("nil is omitted from JSON", func(t *testing.T) { + result := CallToolResult{Content: []Content{&TextContent{Text: "ok"}}} + data, err := json.Marshal(&result) + if err != nil { + t.Fatal(err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if _, ok := raw["inputRequests"]; ok { + t.Errorf("nil InputRequests should be omitted, got %s", raw["inputRequests"]) + } + }) + + t.Run("non-nil empty round-trips", func(t *testing.T) { + result := CallToolResult{ + Content: []Content{&TextContent{Text: "ok"}}, + InputRequests: InputRequestMap{}, + } + data, err := json.Marshal(&result) + if err != nil { + t.Fatal(err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatal(err) + } + if string(raw["inputRequests"]) != "{}" { + t.Errorf("empty InputRequests should marshal to {}, got %s", raw["inputRequests"]) + } + var got CallToolResult + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.InputRequests == nil { + t.Error("empty InputRequests should round-trip as non-nil") + } + }) + + t.Run("populated round-trips", func(t *testing.T) { + result := CallToolResult{ + Content: []Content{&TextContent{Text: "ok"}}, + InputRequests: InputRequestMap{ + "r1": &ElicitParams{Message: "confirm?"}, + }, + } + data, err := json.Marshal(&result) + if err != nil { + t.Fatal(err) + } + var got CallToolResult + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if got.InputRequests == nil { + t.Fatal("InputRequests should not be nil after round-trip") + } + if _, ok := got.InputRequests["r1"]; !ok { + t.Error("expected key r1 in InputRequests") + } + }) +} + func TestContentUnmarshal(t *testing.T) { // Verify that types with a Content field round-trip properly. roundtrip := func(in, out any) { diff --git a/mcp/server.go b/mcp/server.go index 8d24147e..b86b3e6c 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -187,7 +187,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { opts.Logger = ensureLogger(nil) } - return &Server{ + s := &Server{ impl: impl, opts: opts, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), @@ -199,6 +199,8 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), } + s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) + return s } // AddPrompt adds a [Prompt] to the server, or replaces one with the same name. @@ -370,8 +372,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa } // Marshal the output and put the RawMessage in the StructuredContent field. + // Skip when the handler returned input requests (multi round-trip): content and + // inputRequests are mutually exclusive on the wire. var outval any = out - if elemZero != nil { + if res.InputRequests != nil { + outval = nil + } else if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. // Instead, use the zero value of the unpointered type. var z Out @@ -742,7 +748,13 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } - return prompt.handler(ctx, req) + res, err := prompt.handler(ctx, req) + if err == nil && res != nil { + if err := handleMultiRoundTripResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + } + return res, err } func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { @@ -775,10 +787,15 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR } } res, err := st.handler(ctx, req) - if err == nil && res != nil && res.Content == nil { - res2 := *res - res2.Content = []Content{} // avoid "null" - res = &res2 + if err == nil && res != nil { + if err := handleMultiRoundTripResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.Content == nil && res.resultType != resultTypeInputRequired { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } } return res, err } @@ -826,6 +843,12 @@ func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*R if err != nil { return nil, err } + if err := handleMultiRoundTripResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.resultType == resultTypeInputRequired { + return res, nil + } if res == nil || res.Contents == nil { return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e566bf25..53806da9 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -769,6 +769,12 @@ func req(id int64, method string, params any) *jsonrpc.Request { return r } +func completeCallToolResult() *CallToolResult { + r := &CallToolResult{Content: []Content{}} + r.resultType = resultTypeComplete + return r +} + func resp(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), @@ -1947,7 +1953,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(2, completeCallToolResult(), nil)}, }, { method: "POST", @@ -1991,7 +1997,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Meta: testMeta, Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(6, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2007,7 +2013,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, })}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(7, completeCallToolResult(), nil)}, }, { method: "POST", From dfb45f1e119ea40e63e0de563cbb06834a58ace1 Mon Sep 17 00:00:00 2001 From: Guglielmo Colombo Date: Fri, 29 May 2026 14:46:34 +0200 Subject: [PATCH 07/10] mcp: add optional issuer validator for pre-registered client validation (SEP-2352) (#946) Fixes #978 --- auth/authorization_code.go | 3 + auth/authorization_code_test.go | 82 +++++++++++++++++++++++++ auth/extauth/client_credentials.go | 6 +- auth/extauth/client_credentials_test.go | 45 ++++++++++++++ internal/authutil/util.go | 13 ++++ internal/authutil/util_test.go | 29 +++++++++ oauthex/auth_meta.go | 6 +- oauthex/client.go | 9 +++ oauthex/client_test.go | 2 +- 9 files changed, 190 insertions(+), 5 deletions(-) create mode 100644 internal/authutil/util.go create mode 100644 internal/authutil/util_test.go diff --git a/auth/authorization_code.go b/auth/authorization_code.go index 144f0ad5..a3daeecb 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -507,6 +507,9 @@ func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm * // 2. Attempt to use pre-registered client configuration. preCfg := h.config.PreregisteredClient if preCfg != nil { + if preCfg.Issuer != "" && !authutil.IssuersEqual(preCfg.Issuer, asm.Issuer) { + return nil, fmt.Errorf("authorization server issuer %q does not match pre-registered credentials issuer %q", asm.Issuer, preCfg.Issuer) + } authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported) clientSecret := "" if preCfg.ClientSecretAuth != nil { diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index d4de06ce..c84b0032 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -609,6 +609,8 @@ func TestHandleRegistration(t *testing.T) { asm *oauthex.AuthServerMeta want *resolvedClientConfig wantError bool + issuerMatch bool + issuerSuffix string }{ { name: "ClientIDMetadataDocument", @@ -647,6 +649,79 @@ func TestHandleRegistration(t *testing.T) { authStyle: oauth2.AuthStyleInParams, }, }, + { + name: "Preregistered_IssuerMatch", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "pre_client_secret", + }, + Issuer: "", // set dynamically in the test + }, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: "pre_client_id", + clientSecret: "pre_client_secret", + authStyle: oauth2.AuthStyleInParams, + }, + issuerMatch: true, + }, + { + name: "Preregistered_IssuerMismatch", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "pre_client_secret", + }, + Issuer: "https://other-issuer.example.com", + }, + }, + wantError: true, + }, + { + name: "Preregistered_IssuerMatchTrailingSlash", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "pre_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ + ClientSecret: "pre_client_secret", + }, + Issuer: "", // set dynamically in the test (with trailing slash) + }, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: "pre_client_id", + clientSecret: "pre_client_secret", + authStyle: oauth2.AuthStyleInParams, + }, + issuerMatch: true, + issuerSuffix: "/", + }, { name: "NoneSupported", handlerConfig: &AuthorizationCodeHandlerConfig{ @@ -660,6 +735,10 @@ func TestHandleRegistration(t *testing.T) { t.Run(tt.name, func(t *testing.T) { s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{RegistrationConfig: tt.serverConfig}) s.Start(t) + // Set the Issuer dynamically if requested by the test case. + if tt.issuerMatch { + tt.handlerConfig.PreregisteredClient.Issuer = s.URL() + tt.issuerSuffix + } tt.handlerConfig.AuthorizationCodeFetcher = func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { return nil, nil } @@ -679,6 +758,9 @@ func TestHandleRegistration(t *testing.T) { } return } + if tt.wantError { + t.Fatal("handleRegistration() expected error, got nil") + } if got.registrationType != tt.want.registrationType { t.Errorf("handleRegistration() registrationType = %v, want %v", got.registrationType, tt.want.registrationType) } diff --git a/auth/extauth/client_credentials.go b/auth/extauth/client_credentials.go index b95fcaee..c65d8b06 100644 --- a/auth/extauth/client_credentials.go +++ b/auth/extauth/client_credentials.go @@ -128,6 +128,11 @@ func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Requ } } + creds := h.config.Credentials + if creds.Issuer != "" && !authutil.IssuersEqual(creds.Issuer, asm.Issuer) { + return fmt.Errorf("authorization server issuer %q does not match pre-registered credentials issuer %q", asm.Issuer, creds.Issuer) + } + // Determine requestedScopes: use PRM's scopes_supported if available. requestedScopes := scopesFromChallenges(wwwChallenges) if len(requestedScopes) == 0 && len(prm.ScopesSupported) > 0 { @@ -140,7 +145,6 @@ func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Requ requestedScopes = authutil.UnionScopes(h.grantedScopes[asm.Issuer], requestedScopes) // Step 3: Exchange client credentials for an access token. - creds := h.config.Credentials cfg := &clientcredentials.Config{ ClientID: creds.ClientID, ClientSecret: creds.ClientSecretAuth.ClientSecret, diff --git a/auth/extauth/client_credentials_test.go b/auth/extauth/client_credentials_test.go index 550e56af..b2973de7 100644 --- a/auth/extauth/client_credentials_test.go +++ b/auth/extauth/client_credentials_test.go @@ -221,6 +221,51 @@ func TestClientCredentialsHandler_Authorize(t *testing.T) { } }) + t.Run("issuer mismatch", func(t *testing.T) { + config := validClientCredentialsConfig() + config.Credentials.Issuer = "https://other-issuer.example.com" + handler, err := NewClientCredentialsHandler(config) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{}, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + err = handler.Authorize(t.Context(), req, resp) + if err == nil { + t.Fatal("expected Authorize to fail with issuer mismatch") + } + if !strings.Contains(err.Error(), "does not match") { + t.Errorf("error %q does not mention issuer mismatch", err.Error()) + } + }) + + t.Run("issuer match ignoring trailing slash", func(t *testing.T) { + config := validClientCredentialsConfig() + // authServer.URL() has no trailing slash; configure with one to + // verify the comparison tolerates the difference (per RFC 8414 §3.3 + // normalization applied in oauthex.IssuersEqual). + config.Credentials.Issuer = authServer.URL() + "/" + handler, err := NewClientCredentialsHandler(config) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{}, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err != nil { + t.Fatalf("Authorize() unexpected error = %v", err) + } + }) + t.Run("PRM via resource_metadata in challenge", func(t *testing.T) { prmMux := http.NewServeMux() prmMux.Handle("/custom-prm", auth.ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ diff --git a/internal/authutil/util.go b/internal/authutil/util.go new file mode 100644 index 00000000..880d1bfa --- /dev/null +++ b/internal/authutil/util.go @@ -0,0 +1,13 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package authutil + +import "strings" + +// IssuersEqual reports whether two OAuth 2.0 authorization server issuer +// identifiers refer to the same server comparing them without the final trailing slash. +func IssuersEqual(a, b string) bool { + return strings.TrimSuffix(a, "/") == strings.TrimSuffix(b, "/") +} diff --git a/internal/authutil/util_test.go b/internal/authutil/util_test.go new file mode 100644 index 00000000..9d3ff7a4 --- /dev/null +++ b/internal/authutil/util_test.go @@ -0,0 +1,29 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package authutil + +import "testing" + +func TestIssuersEqual(t *testing.T) { + tests := []struct { + a, b string + want bool + }{ + {"https://issuer.example.com", "https://issuer.example.com", true}, + {"https://issuer.example.com/", "https://issuer.example.com", true}, + {"https://issuer.example.com", "https://issuer.example.com/", true}, + {"https://issuer.example.com/", "https://issuer.example.com/", true}, + {"https://issuer.example.com/tenant", "https://issuer.example.com/tenant", true}, + {"https://issuer.example.com/tenant/", "https://issuer.example.com/tenant", true}, + {"https://issuer.example.com", "https://other.example.com", false}, + {"https://issuer.example.com/a", "https://issuer.example.com/b", false}, + {"", "", true}, + } + for _, tt := range tests { + if got := IssuersEqual(tt.a, tt.b); got != tt.want { + t.Errorf("IssuersEqual(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want) + } + } +} diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 7ebf6e22..711b17a4 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -12,7 +12,8 @@ import ( "errors" "fmt" "net/http" - "strings" + + "github.com/modelcontextprotocol/go-sdk/internal/authutil" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -153,8 +154,7 @@ func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http. } return nil, fmt.Errorf("%v", err) // Do not expose error types. } - if strings.TrimRight(asm.Issuer, "/") != strings.TrimRight(issuer, "/") { - // Validate the Issuer field (see RFC 8414, section 3.3). + if !authutil.IssuersEqual(asm.Issuer, issuer) { return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuer) } diff --git a/oauthex/client.go b/oauthex/client.go index e8f99182..1b19eed5 100644 --- a/oauthex/client.go +++ b/oauthex/client.go @@ -18,6 +18,15 @@ type ClientCredentials struct { // This is the most common authentication method for confidential clients. // OPTIONAL. If not provided, the client is treated as a public client. ClientSecretAuth *ClientSecretAuth + + // Issuer is the issuer identifier of the authorization server these + // credentials are registered with. Pre-registered credentials are bound + // to a specific authorization server; when set, an error is returned if + // the discovered authorization server does not match, per SEP-2352. + // The comparison ignores a single trailing slash, matching the + // tolerance applied during RFC 8414 Section 3.3 metadata validation. + // OPTIONAL. + Issuer string } // ClientSecretAuth holds client secret authentication credentials. diff --git a/oauthex/client_test.go b/oauthex/client_test.go index b78e9c8b..34d8188c 100644 --- a/oauthex/client_test.go +++ b/oauthex/client_test.go @@ -73,7 +73,7 @@ func TestClientCredentials_ValidateCoversAllAuthFields(t *testing.T) { var pointerFields int for i := range typ.NumField() { f := typ.Field(i) - if f.Name == "ClientID" { + if f.Name == "ClientID" || f.Name == "Issuer" { continue } if f.Type.Kind() != reflect.Ptr { From 8805aa85f7531bf478dd2ddf5406aaf222ff5e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomas=20Daba=C5=A1inskas?= Date: Sun, 31 May 2026 11:25:49 +0300 Subject: [PATCH 08/10] mcp: add configurable keepalive failure threshold (#982) mcp: add configurable keepalive failure threshold Introduce `KeepAliveFailureThreshold` option in both `ClientOptions` and `ServerOptions` to control how many consecutive keepalive ping failures are tolerated before closing the session. This aligns with the MCP spec's guidance that "multiple failed pings MAY trigger a connection reset," allowing operators to tune resilience against transient network hiccups without immediately tearing down otherwise healthy sessions. A threshold of 0 or 1 (the default) closes on the first failure, preserving existing behavior. Higher values let isolated misses pass while still closing the session once consecutive failures reach the threshold. A successful ping resets the counter. Tolerated failures are logged at WARN level; the final failure that closes the session is logged at ERROR level. This is rework of #979. --- mcp/client.go | 9 +++++- mcp/mcp_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 9 +++++- mcp/shared.go | 51 +++++++++++++++++++++++++-------- 4 files changed, 131 insertions(+), 14 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 9f6f2955..979172ba 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -166,6 +166,13 @@ type ClientOptions struct { // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // KeepAliveFailureThreshold is the number of consecutive keepalive ping + // failures tolerated before the session is closed. A value of 0 or 1 + // closes the session on the first failure (the default). Higher values + // align with the spec's "multiple failed pings MAY trigger a connection + // reset" guidance, letting a transient miss pass without tearing down an + // otherwise live session. Has no effect unless KeepAlive is non-zero. + KeepAliveFailureThreshold int } // toolContextKeyType is the context key type for passing tool definitions @@ -441,7 +448,7 @@ func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await // startKeepalive starts the keepalive mechanism for this client session. func (cs *ClientSession) startKeepalive(interval time.Duration) { - startKeepalive(cs, interval, &cs.keepaliveCancel, cs.client.opts.Logger) + startKeepalive(cs, interval, cs.client.opts.KeepAliveFailureThreshold, &cs.keepaliveCancel, cs.client.opts.Logger) } // AddRoots adds the given roots to the client, diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index bad35086..5c8e7d12 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1920,6 +1920,82 @@ func TestKeepAliveFailure_Logged(t *testing.T) { }) } +// scriptedKeepaliveSession is a keepaliveSession test double whose Ping +// returns errors from a script (one entry consumed per call; the last entry +// repeats once exhausted), and records how many times Close was called. Ping +// returns immediately so the keepalive loop's pace is driven purely by the +// ticker, making the test deterministic under synctest. +type scriptedKeepaliveSession struct { + pingErrs []error + pingCalls atomic.Int64 + closeCalls atomic.Int64 +} + +func (s *scriptedKeepaliveSession) Ping(context.Context, *PingParams) error { + n := int(s.pingCalls.Add(1)) - 1 + if n >= len(s.pingErrs) { + n = len(s.pingErrs) - 1 + } + return s.pingErrs[n] +} + +func (s *scriptedKeepaliveSession) Close() error { + s.closeCalls.Add(1) + return nil +} + +// TestStartKeepalive_FailureThreshold verifies that the session is kept alive +// across consecutive ping failures below the threshold and only closed once the +// threshold is reached. +func TestStartKeepalive_FailureThreshold(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const interval = 100 * time.Millisecond + sess := &scriptedKeepaliveSession{pingErrs: []error{errors.New("boom")}} + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + var cancel context.CancelFunc + startKeepalive(sess, interval, 3, &cancel, logger) + defer cancel() + + // After two ticks → two failures, still below threshold 3: not closed. + time.Sleep(2*interval + interval/2) + synctest.Wait() + if got := sess.closeCalls.Load(); got != 0 { + t.Fatalf("session closed below threshold: closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load()) + } + + // Third tick → third failure reaches threshold: session closed. + time.Sleep(interval) + synctest.Wait() + if got := sess.closeCalls.Load(); got != 1 { + t.Fatalf("expected one Close at threshold, got closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load()) + } + }) +} + +// TestStartKeepalive_SuccessResetsFailures verifies that a successful ping +// resets the consecutive-failure counter, so an isolated failure between +// successes never accumulates toward the threshold. +func TestStartKeepalive_SuccessResetsFailures(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const interval = 100 * time.Millisecond + // fail, success, fail, fail, then success (the tail repeats): the run + // never has 3 consecutive failures, so the session is never closed. + sess := &scriptedKeepaliveSession{pingErrs: []error{ + errors.New("boom"), nil, errors.New("boom"), errors.New("boom"), nil, + }} + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + var cancel context.CancelFunc + startKeepalive(sess, interval, 3, &cancel, logger) + defer cancel() + + time.Sleep(6 * interval) + synctest.Wait() + if got := sess.closeCalls.Load(); got != 0 { + t.Fatalf("session closed despite a success resetting the counter: closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load()) + } + }) +} + func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Adding the same tool pointer twice should not panic and should not // produce duplicates in the server's tool list. diff --git a/mcp/server.go b/mcp/server.go index b86b3e6c..693bd391 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -79,6 +79,13 @@ type ServerOptions struct { // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // KeepAliveFailureThreshold is the number of consecutive keepalive ping + // failures tolerated before the session is closed. A value of 0 or 1 + // closes the session on the first failure (the default). Higher values + // align with the spec's "multiple failed pings MAY trigger a connection + // reset" guidance, letting a transient miss pass without tearing down an + // otherwise live session. Has no effect unless KeepAlive is non-zero. + KeepAliveFailureThreshold int // Function called when a client session subscribes to a resource. SubscribeHandler func(context.Context, *SubscribeRequest) error // Function called when a client session unsubscribes from a resource. @@ -1605,7 +1612,7 @@ func (ss *ServerSession) Wait() error { // startKeepalive starts the keepalive mechanism for this server session. func (ss *ServerSession) startKeepalive(interval time.Duration) { - startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.Logger) + startKeepalive(ss, interval, ss.server.opts.KeepAliveFailureThreshold, &ss.keepaliveCancel, ss.server.opts.Logger) } // pageToken is the internal structure for the opaque pagination cursor. diff --git a/mcp/shared.go b/mcp/shared.go index 1caacac3..afa566aa 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -751,9 +751,20 @@ type keepaliveSession interface { // It assigns the cancel function to the provided cancelPtr and starts a goroutine // that sends ping messages at the specified interval. // -// logger must be non-nil; ping failures (which terminate the keepalive loop and -// close the session) are reported via logger so they are not silently dropped. -func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, logger *slog.Logger) { +// failureThreshold is the number of consecutive ping failures tolerated before +// the session is closed; a value below 1 is treated as 1 (close on the first +// failure). A successful ping resets the counter. This mirrors the spec's +// "multiple failed pings MAY trigger a connection reset" language, letting a +// transient miss pass without tearing down an otherwise live session. +// +// logger must be non-nil; ping failures (both the tolerated ones and the final +// one that closes the session) are reported via logger so they are not silently +// dropped. +func startKeepalive(session keepaliveSession, interval time.Duration, failureThreshold int, cancelPtr *context.CancelFunc, logger *slog.Logger) { + if failureThreshold < 1 { + failureThreshold = 1 + } + ctx, cancel := context.WithCancel(context.Background()) // Assign cancel function before starting goroutine to avoid race condition. // We cannot return it because the caller may need to cancel during the @@ -764,6 +775,7 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr ticker := time.NewTicker(interval) defer ticker.Stop() + consecutiveFailures := 0 for { select { case <-ctx.Done(): @@ -772,17 +784,32 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) err := session.Ping(pingCtx, nil) pingCancel() - if err != nil { - if errors.Is(err, jsonrpc2.ErrMethodNotFound) { - // Peer doesn't support ping, stop the keepalive process. - return - } - // Ping failed; log it before closing the session so the - // failure is observable to operators. See #218. - logger.Error("keepalive ping failed; closing session", "error", err) - _ = session.Close() + if err == nil { + consecutiveFailures = 0 + continue + } + if errors.Is(err, jsonrpc2.ErrMethodNotFound) { + // Peer doesn't support ping, stop the keepalive process. return } + consecutiveFailures++ + if consecutiveFailures < failureThreshold { + // Tolerate transient failures below the threshold; log so + // the misses are still observable to operators. See #218. + logger.Warn("keepalive ping failed; tolerating below threshold", + "error", err, + "consecutiveFailures", consecutiveFailures, + "failureThreshold", failureThreshold) + continue + } + // Threshold reached; log before closing the session so the + // failure is observable to operators. See #218. + logger.Error("keepalive ping failed; closing session", + "error", err, + "consecutiveFailures", consecutiveFailures, + "failureThreshold", failureThreshold) + _ = session.Close() + return } } }() From 6d2bbff7d8534cec198ec4563f1e5dc6fa4cc78e Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sun, 31 May 2026 16:35:46 +0800 Subject: [PATCH 09/10] fix: add implementation description metadata (#981) ## Summary - add the optional `description` field to `mcp.Implementation` - keep the field omitted when empty - cover JSON marshal/unmarshal behavior with a focused regression test Fixes #977. ## To verify - `go test ./mcp -run TestImplementationDescriptionJSON -count=1` - `go test ./mcp -run TestServerRequest_PerRequestAccessors -count=1` - `go test ./mcp -count=1` - `go test ./...` - `git diff --check` Co-authored-by: Guglielmo Colombo --- mcp/protocol.go | 10 ++++++---- mcp/shared_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/mcp/protocol.go b/mcp/protocol.go index bebdc196..f401d220 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -1935,16 +1935,18 @@ type ElicitationCompleteParams struct { func (x *ElicitationCompleteParams) isParams() {} func (x *ElicitationCompleteParams) isNil() bool { return x == nil } -// An Implementation describes the name and version of an MCP implementation, with an optional -// title for UI representation. +// An Implementation describes the name and version of an MCP implementation, with +// optional display metadata. type Implementation struct { // Intended for programmatic or logical use, but used as a display name in past // specs or fallback (if title isn't present). Name string `json:"name"` // Intended for UI and end-user contexts — optimized to be human-readable and // easily understood, even by those unfamiliar with domain-specific terminology. - Title string `json:"title,omitempty"` - Version string `json:"version"` + Title string `json:"title,omitempty"` + // A human-readable description of the implementation. + Description string `json:"description,omitempty"` + Version string `json:"version"` // WebsiteURL for the server, if any. WebsiteURL string `json:"websiteUrl,omitempty"` // Icons for the Server, if any. diff --git a/mcp/shared_test.go b/mcp/shared_test.go index e4f563d1..065d00b0 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -245,6 +245,39 @@ func TestServerRequest_PerRequestAccessors_Empty(t *testing.T) { } } +func TestImplementationDescriptionJSON(t *testing.T) { + impl := &Implementation{ + Name: "greeter", + Title: "Greeter", + Description: "Example server for greeting tools", + Version: "v1.0.0", + } + got, err := json.Marshal(impl) + if err != nil { + t.Fatal(err) + } + want := `{"name":"greeter","title":"Greeter","description":"Example server for greeting tools","version":"v1.0.0"}` + if string(got) != want { + t.Fatalf("Implementation JSON = %s, want %s", got, want) + } + + var roundTrip Implementation + if err := json.Unmarshal(got, &roundTrip); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(impl, &roundTrip); diff != "" { + t.Fatalf("Implementation round trip mismatch (-want +got):\n%s", diff) + } + + got, err = json.Marshal(&Implementation{Name: "greeter", Version: "v1.0.0"}) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(got), "description") { + t.Fatalf("empty description should be omitted, got %s", got) + } +} + // TODO(v0.3.0): rewrite this test. // func TestToolValidate(t *testing.T) { // // Check that the tool returned from NewServerTool properly validates its input schema. From ad054d3c5e3edb77f3034230455d8d75eb57e105 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:48:53 +0200 Subject: [PATCH 10/10] build(deps): bump github/codeql-action from 4.35.2 to 4.36.0 (#986) Bumps [github/codeql-action](https://github.com/github/codeql-action) from 4.35.2 to 4.36.0.
Release notes

Sourced from github/codeql-action's releases.

v4.36.0

  • Breaking change: Bump the minimum required CodeQL bundle version to 2.19.4. #3894
  • Add support for SHA-256 Git object IDs. #3893
  • Update default CodeQL bundle version to 2.25.5. #3926

v4.35.5

  • We have improved how the JavaScript bundles for the CodeQL Action are generated to avoid duplication across bundles and reduce the size of the repository by around 70%. This should have no effect on the runtime behaviour of the CodeQL Action. #3899
  • For performance and accuracy reasons, improved incremental analysis will now only be enabled on a pull request when diff-informed analysis is also enabled for that run. If diff-informed analysis is unavailable (for example, because the PR diff ranges could not be computed), the action will fall back to a full analysis. #3791
  • If multiple inputs are provided for the GitHub-internal analysis-kinds input, only code-scanning will be enabled. The analysis-kinds input is experimental, for GitHub-internal use only, and may change without notice at any time. #3892
  • Added an experimental change which, when running a Code Scanning analysis for a PR with improved incremental analysis enabled, prefers CodeQL CLI versions that have a cached overlay-base database for the configured languages. This speeds up analysis for a repository when there is not yet a cached overlay-base database for the latest CLI version. We expect to roll this change out to everyone in May. #3880

v4.35.4

  • Update default CodeQL bundle version to 2.25.4. #3881

v4.35.3

  • Upcoming breaking change: Add a deprecation warning for customers using CodeQL version 2.19.3 and earlier. These versions of CodeQL were discontinued on 9 April 2026 alongside GitHub Enterprise Server 3.15, and will be unsupported by the next minor release of the CodeQL Action. #3837
  • Configurations for private registries that use Cloudsmith or GCP OIDC are now accepted. #3850
  • Best-effort connection tests for private registries now use GET requests instead of HEAD for better compatibility with various registry implementations. For NuGet feeds, the test is now always performed against the service index. #3853
  • Fixed a bug where two diagnostics produced within the same millisecond could overwrite each other on disk, causing one of them to be lost. #3852
  • Update default CodeQL bundle version to 2.25.3. #3865
Changelog

Sourced from github/codeql-action's changelog.

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

[UNRELEASED]

No user facing changes.

4.36.0 - 22 May 2026

  • Breaking change: Bump the minimum required CodeQL bundle version to 2.19.4. #3894
  • Add support for SHA-256 Git object IDs. #3893
  • Update default CodeQL bundle version to 2.25.5. #3926

4.35.5 - 15 May 2026

  • We have improved how the JavaScript bundles for the CodeQL Action are generated to avoid duplication across bundles and reduce the size of the repository by around 70%. This should have no effect on the runtime behaviour of the CodeQL Action. #3899
  • For performance and accuracy reasons, improved incremental analysis will now only be enabled on a pull request when diff-informed analysis is also enabled for that run. If diff-informed analysis is unavailable (for example, because the PR diff ranges could not be computed), the action will fall back to a full analysis. #3791
  • If multiple inputs are provided for the GitHub-internal analysis-kinds input, only code-scanning will be enabled. The analysis-kinds input is experimental, for GitHub-internal use only, and may change without notice at any time. #3892
  • Added an experimental change which, when running a Code Scanning analysis for a PR with improved incremental analysis enabled, prefers CodeQL CLI versions that have a cached overlay-base database for the configured languages. This speeds up analysis for a repository when there is not yet a cached overlay-base database for the latest CLI version. We expect to roll this change out to everyone in May. #3880

4.35.4 - 07 May 2026

  • Update default CodeQL bundle version to 2.25.4. #3881

4.35.3 - 01 May 2026

  • Upcoming breaking change: Add a deprecation warning for customers using CodeQL version 2.19.3 and earlier. These versions of CodeQL were discontinued on 9 April 2026 alongside GitHub Enterprise Server 3.15, and will be unsupported by the next minor release of the CodeQL Action. #3837
  • Configurations for private registries that use Cloudsmith or GCP OIDC are now accepted. #3850
  • Best-effort connection tests for private registries now use GET requests instead of HEAD for better compatibility with various registry implementations. For NuGet feeds, the test is now always performed against the service index. #3853
  • Fixed a bug where two diagnostics produced within the same millisecond could overwrite each other on disk, causing one of them to be lost. #3852
  • Update default CodeQL bundle version to 2.25.3. #3865

4.35.2 - 15 Apr 2026

  • The undocumented TRAP cache cleanup feature that could be enabled using the CODEQL_ACTION_CLEANUP_TRAP_CACHES environment variable is deprecated and will be removed in May 2026. If you are affected by this, we recommend disabling TRAP caching by passing the trap-caching: false input to the init Action. #3795
  • The Git version 2.36.0 requirement for improved incremental analysis now only applies to repositories that contain submodules. #3789
  • Python analysis on GHES no longer extracts the standard library, relying instead on models of the standard library. This should result in significantly faster extraction and analysis times, while the effect on alerts should be minimal. #3794
  • Fixed a bug in the validation of OIDC configurations for private registries that was added in CodeQL Action 4.33.0 / 3.33.0. #3807
  • Update default CodeQL bundle version to 2.25.2. #3823

4.35.1 - 27 Mar 2026

4.35.0 - 27 Mar 2026

... (truncated)

Commits
  • 7211b7c Merge pull request #3927 from github/update-v4.36.0-ebc2d9e2b
  • 7740f2f Update changelog for v4.36.0
  • ebc2d9e Merge pull request #3926 from github/update-bundle/codeql-bundle-v2.25.5
  • d1f74b7 Add changelog note
  • 2dc40ce Update default bundle to codeql-bundle-v2.25.5
  • 8449852 Merge pull request #3910 from github/henrymercer/repo-size-diff-check
  • 72ac23c Update excluded required check list
  • c5297a2 Merge pull request #3919 from github/henrymercer/workflow-concurrency
  • 8ffeae7 CI: Automatically cancel non-generated workflows
  • f3f52bf Revert getErrorMessage import
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github/codeql-action&package-manager=github_actions&previous-version=4.35.2&new-version=4.36.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codeql.yml | 4 ++-- .github/workflows/scorecard.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 4aa91793..ca0d7323 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,11 +38,11 @@ jobs: - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Initialize CodeQL - uses: github/codeql-action/init@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/init@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: languages: ${{ matrix.language }} build-mode: ${{ matrix.build-mode }} - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/analyze@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index bb8b659d..c8035cb0 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -73,6 +73,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/upload-sarif@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: sarif_file: results.sarif