diff --git a/mcp/streamable.go b/mcp/streamable.go index e6a9bfe5..0daac38b 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1271,7 +1271,8 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } - protocolVersion := protocolVersionFromContext(req.Context()) + headerProtocolVersion := protocolVersionFromContext(req.Context()) + protocolVersion := headerProtocolVersion if protocolVersion == "" { protocolVersion = protocolVersion20250326 } @@ -1291,6 +1292,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false + var initializeID jsonrpc.ID var initializeProtocolVersion string headerVersion := protocolVersionFromContext(req.Context()) for _, msg := range incoming { @@ -1304,6 +1306,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } if jreq.Method == methodInitialize { isInitialize = true + initializeID = jreq.ID // Extract the protocol version from InitializeParams. var params InitializeParams if err := internaljson.Unmarshal(jreq.Params, ¶ms); err == nil { @@ -1321,6 +1324,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques if meta := extractRequestMeta(jreq.Params); meta != nil { metaVersion, _ = meta[MetaKeyProtocolVersion].(string) } + if jreq.Method == methodInitialize && metaVersion == "" && headerVersion >= protocolVersion20260630 { + metaVersion = initializeProtocolVersion + } if protocolVersion >= protocolVersion20260630 || metaVersion != "" { if !c.stateless { http.Error(w, fmt.Sprintf( @@ -1371,6 +1377,22 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } + if headerProtocolVersion != "" && initializeProtocolVersion != "" && headerProtocolVersion != initializeProtocolVersion { + resp := &jsonrpc.Response{ + ID: initializeID, + Error: jsonrpc2.NewError( + CodeHeaderMismatch, + fmt.Sprintf("header mismatch: %s header value %q does not match body protocolVersion %q", protocolVersionHeader, headerProtocolVersion, initializeProtocolVersion), + ), + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if data, err := jsonrpc2.EncodeMessage(resp); err == nil { + w.Write(data) + } + return + } + // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) if !isBatch && len(incoming) == 1 { if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a27696a1..5769d5ae 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -972,6 +972,19 @@ func TestStreamableServerTransport(t *testing.T) { }, wantSessions: 1, }, + { + name: "initialize protocol version header mismatch", + requests: []streamableRequest{ + { + method: "POST", + headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + messages: []jsonrpc.Message{req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20250618})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "header mismatch", + }, + }, + wantSessions: 0, + }, { name: "batch rejected on 2025-06-18", requests: []streamableRequest{ @@ -2049,6 +2062,88 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }) } +// TODO: Remove this once client operations will automatically inject metadata in the requests +func injectMetaToRequest(req *http.Request) error { + if req.Body == nil { + return nil + } + body, err := io.ReadAll(req.Body) + if err != nil { + return err + } + req.Body.Close() + + var val any + if err := json.Unmarshal(body, &val); err == nil { + var method string + if m, ok := val.(map[string]any); ok { + method, _ = m["method"].(string) + } else if list, ok := val.([]any); ok && len(list) > 0 { + if m, ok := list[0].(map[string]any); ok { + method, _ = m["method"].(string) + } + } + + if method == "initialize" { + var protocolVersion string + if m, ok := val.(map[string]any); ok { + if params, _ := m["params"].(map[string]any); params != nil { + protocolVersion, _ = params["protocolVersion"].(string) + } + } else if list, ok := val.([]any); ok && len(list) > 0 { + if m, ok := list[0].(map[string]any); ok { + if params, _ := m["params"].(map[string]any); params != nil { + protocolVersion, _ = params["protocolVersion"].(string) + } + } + } + if protocolVersion == "" { + protocolVersion = protocolVersion20251125 + } + req.Header.Set(protocolVersionHeader, protocolVersion) + if protocolVersion >= minVersionForStandardHeaders { + req.Header.Set(methodHeader, methodInitialize) + } + } else if method == "notifications/initialized" || strings.HasPrefix(method, "notifications/") { + req.Header.Set(protocolVersionHeader, protocolVersion20251125) + } else { + req.Header.Set(protocolVersionHeader, minVersionForStandardHeaders) + + var msgs []map[string]any + if m, ok := val.(map[string]any); ok { + msgs = []map[string]any{m} + } else if list, ok := val.([]any); ok { + for _, item := range list { + if m, ok := item.(map[string]any); ok { + msgs = append(msgs, m) + } + } + } + + for _, m := range msgs { + params, _ := m["params"].(map[string]any) + if params == nil { + params = make(map[string]any) + m["params"] = params + } + meta, _ := params["_meta"].(map[string]any) + if meta == nil { + meta = make(map[string]any) + params["_meta"] = meta + } + meta[MetaKeyProtocolVersion] = minVersionForStandardHeaders + meta[MetaKeyClientInfo] = map[string]any{"name": "testClient", "version": "v1.0.0"} + meta[MetaKeyClientCapabilities] = map[string]any{} + } + body, _ = json.Marshal(val) + } + } + + req.Body = io.NopCloser(bytes.NewReader(body)) + req.ContentLength = int64(len(body)) + return nil +} + // TestStreamableMcpHeaderValidationErrorFormat verifies that header // validation errors return a JSON-RPC error with code -32001 and // Content-Type application/json, per SEP-2243.