diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407cea..e34a3c0477 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -290,6 +290,7 @@ def _create_error_response( status_code: HTTPStatus, error_code: int = INVALID_REQUEST, headers: dict[str, str] | None = None, + request_id: RequestId | None = None, ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} @@ -302,7 +303,7 @@ def _create_error_response( # Return a properly formatted JSON error response error_response = JSONRPCError( jsonrpc="2.0", - id=None, + id=request_id, error=ErrorData(code=error_code, message=error_message), ) @@ -482,19 +483,16 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" if is_initialization_request: - # Check if the server already has an established session - if self.mcp_session_id: - # Check if request has a session ID - request_session_id = self._get_session_id(request) - - # If request has a session ID but doesn't match, return 404 - if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover - response = self._create_error_response( - "Not Found: Invalid or expired session ID", - HTTPStatus.NOT_FOUND, - ) - await response(scope, receive, send) - return + request_session_id = self._get_session_id(request) + if ( + self.mcp_session_id and request_session_id and request_session_id != self.mcp_session_id + ): # pragma: no cover + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return elif not await self._validate_request_headers(request, send): return @@ -525,6 +523,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Extract the request ID outside the try block for proper scope request_id = str(message.id) + if request_id in self._request_streams: + response = self._create_error_response( + f"Conflict: Request ID {request_id!r} is already in flight on this session", + HTTPStatus.CONFLICT, + request_id=message.id, + ) + await response(scope, receive, send) + return # Register this stream for the request ID self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) request_stream_reader = self._request_streams[request_id][1] diff --git a/tests/issues/test_2655_streamable_http_duplicate_request_id.py b/tests/issues/test_2655_streamable_http_duplicate_request_id.py new file mode 100644 index 0000000000..92702abef7 --- /dev/null +++ b/tests/issues/test_2655_streamable_http_duplicate_request_id.py @@ -0,0 +1,105 @@ +import anyio +import httpx +import pytest + +from mcp.server import Server, ServerRequestContext +from mcp.server.streamable_http import MCP_SESSION_ID_HEADER +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + CallToolRequestParams, + CallToolResult, + TextContent, +) + + +@pytest.mark.anyio +async def test_streamable_http_duplicate_request_id_returns_409_and_preserves_in_flight_request() -> None: + started = anyio.Event() + release = anyio.Event() + + async def handle_call_tool( + ctx: ServerRequestContext[object], + params: CallToolRequestParams, + ) -> CallToolResult: + started.set() + await release.wait() + return CallToolResult(content=[TextContent(type="text", text="ok")]) + + server = Server("test-duplicate-request-id", on_call_tool=handle_call_tool) + mcp_app = server.streamable_http_app(json_response=True, host="testserver") + + async with ( + mcp_app.router.lifespan_context(mcp_app), + httpx.ASGITransport(mcp_app) as transport, + httpx.AsyncClient(transport=transport, base_url="http://testserver", timeout=5.0) as client, + ): + base_headers = {"Accept": "application/json", "Content-Type": "application/json"} + + init_response = await client.post( + "/mcp", + headers=base_headers, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "id": "init-1", + "params": { + "clientInfo": {"name": "test-client", "version": "0"}, + "protocolVersion": LATEST_PROTOCOL_VERSION, + "capabilities": {}, + }, + }, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + session_headers = {**base_headers, MCP_SESSION_ID_HEADER: session_id} + + initialized = await client.post( + "/mcp", + headers=session_headers, + json={"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}, + ) + assert initialized.status_code == 202 + + request_id = "dup-id-1" + slow_response: httpx.Response | None = None + + async def run_slow_request() -> None: + nonlocal slow_response + slow_response = await client.post( + "/mcp", + headers=session_headers, + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "id": request_id, + "params": {"name": "slow_tool", "arguments": {}}, + }, + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(run_slow_request) + with anyio.fail_after(5): + await started.wait() + + duplicate = await client.post( + "/mcp", + headers=session_headers, + json={"jsonrpc": "2.0", "method": "ping", "id": request_id, "params": {}}, + ) + assert duplicate.status_code == 409 + duplicate_body = duplicate.json() + assert duplicate_body["jsonrpc"] == "2.0" + assert duplicate_body["id"] == request_id + assert duplicate_body["error"]["code"] == INVALID_REQUEST + + release.set() + + assert slow_response is not None + assert slow_response.status_code == 200 + slow_body = slow_response.json() + assert slow_body["jsonrpc"] == "2.0" + assert slow_body["id"] == request_id + assert slow_body["result"]["content"][0]["text"] == "ok"