Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,8 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
return
}

protocolVersion := protocolVersionFromContext(req.Context())
headerProtocolVersion := protocolVersionFromContext(req.Context())
protocolVersion := headerProtocolVersion
if protocolVersion == "" {
protocolVersion = protocolVersion20250326
}
Expand All @@ -1291,6 +1292,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
calls := make(map[jsonrpc.ID]struct{})
tokenInfo := auth.TokenInfoFromContext(req.Context())
isInitialize := false
var initializeID jsonrpc.ID
var initializeProtocolVersion string
headerVersion := protocolVersionFromContext(req.Context())
for _, msg := range incoming {
Expand All @@ -1304,6 +1306,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
}
if jreq.Method == methodInitialize {
isInitialize = true
initializeID = jreq.ID
// Extract the protocol version from InitializeParams.
var params InitializeParams
if err := internaljson.Unmarshal(jreq.Params, &params); err == nil {
Expand All @@ -1321,6 +1324,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
if meta := extractRequestMeta(jreq.Params); meta != nil {
metaVersion, _ = meta[MetaKeyProtocolVersion].(string)
}
if jreq.Method == methodInitialize && metaVersion == "" && headerVersion >= protocolVersion20260630 {
metaVersion = initializeProtocolVersion
}
if protocolVersion >= protocolVersion20260630 || metaVersion != "" {
if !c.stateless {
http.Error(w, fmt.Sprintf(
Expand Down Expand Up @@ -1371,6 +1377,22 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
}
}

if headerProtocolVersion != "" && initializeProtocolVersion != "" && headerProtocolVersion != initializeProtocolVersion {
resp := &jsonrpc.Response{
ID: initializeID,
Error: jsonrpc2.NewError(
CodeHeaderMismatch,
fmt.Sprintf("header mismatch: %s header value %q does not match body protocolVersion %q", protocolVersionHeader, headerProtocolVersion, initializeProtocolVersion),
),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if data, err := jsonrpc2.EncodeMessage(resp); err == nil {
w.Write(data)
}
return
}

// Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*)
if !isBatch && len(incoming) == 1 {
if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil {
Expand Down
95 changes: 95 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,19 @@
},
wantSessions: 1,
},
{
name: "initialize protocol version header mismatch",
requests: []streamableRequest{
{
method: "POST",
headers: http.Header{protocolVersionHeader: {protocolVersion20251125}},
messages: []jsonrpc.Message{req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20250618})},
wantStatusCode: http.StatusBadRequest,
wantBodyContaining: "header mismatch",
},
},
wantSessions: 0,
},
{
name: "batch rejected on 2025-06-18",
requests: []streamableRequest{
Expand Down Expand Up @@ -2049,6 +2062,88 @@
})
}

// TODO: Remove this once client operations will automatically inject metadata in the requests
func injectMetaToRequest(req *http.Request) error {

Check failure on line 2066 in mcp/streamable_test.go

View workflow job for this annotation

GitHub Actions / lint

func injectMetaToRequest is unused (U1000)
if req.Body == nil {
return nil
}
body, err := io.ReadAll(req.Body)
if err != nil {
return err
}
req.Body.Close()

var val any
if err := json.Unmarshal(body, &val); err == nil {
var method string
if m, ok := val.(map[string]any); ok {
method, _ = m["method"].(string)
} else if list, ok := val.([]any); ok && len(list) > 0 {
if m, ok := list[0].(map[string]any); ok {
method, _ = m["method"].(string)
}
}

if method == "initialize" {
var protocolVersion string
if m, ok := val.(map[string]any); ok {
if params, _ := m["params"].(map[string]any); params != nil {
protocolVersion, _ = params["protocolVersion"].(string)
}
} else if list, ok := val.([]any); ok && len(list) > 0 {
if m, ok := list[0].(map[string]any); ok {
if params, _ := m["params"].(map[string]any); params != nil {
protocolVersion, _ = params["protocolVersion"].(string)
}
}
}
if protocolVersion == "" {
protocolVersion = protocolVersion20251125
}
req.Header.Set(protocolVersionHeader, protocolVersion)
if protocolVersion >= minVersionForStandardHeaders {
req.Header.Set(methodHeader, methodInitialize)
}
} else if method == "notifications/initialized" || strings.HasPrefix(method, "notifications/") {
req.Header.Set(protocolVersionHeader, protocolVersion20251125)
} else {
req.Header.Set(protocolVersionHeader, minVersionForStandardHeaders)

var msgs []map[string]any
if m, ok := val.(map[string]any); ok {
msgs = []map[string]any{m}
} else if list, ok := val.([]any); ok {
for _, item := range list {
if m, ok := item.(map[string]any); ok {
msgs = append(msgs, m)
}
}
}

for _, m := range msgs {
params, _ := m["params"].(map[string]any)
if params == nil {
params = make(map[string]any)
m["params"] = params
}
meta, _ := params["_meta"].(map[string]any)
if meta == nil {
meta = make(map[string]any)
params["_meta"] = meta
}
meta[MetaKeyProtocolVersion] = minVersionForStandardHeaders
meta[MetaKeyClientInfo] = map[string]any{"name": "testClient", "version": "v1.0.0"}
meta[MetaKeyClientCapabilities] = map[string]any{}
}
body, _ = json.Marshal(val)
}
}

req.Body = io.NopCloser(bytes.NewReader(body))
req.ContentLength = int64(len(body))
return nil
}

// TestStreamableMcpHeaderValidationErrorFormat verifies that header
// validation errors return a JSON-RPC error with code -32001 and
// Content-Type application/json, per SEP-2243.
Expand Down
Loading