From a77ebf37d9f4d4a30700f21a6c056a6a32486dd6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 2 Jun 2026 19:56:29 +0000 Subject: [PATCH] Run SSE and Unicode transport tests in process instead of over sockets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same class of fix as the security-test migration: these files spawned uvicorn subprocesses on bind-then-close ports and polled for readiness, which races under pytest-xdist when two workers pick the same port. tests/shared/test_sse.py now drives the same Starlette apps in process: sse_client connects through an httpx_client_factory backed by the in-process streaming bridge, raw-httpx assertions use the same bridge, and the mounted-app and request-context servers are built as plain app factories instead of subprocess targets. Assertions are unchanged, with two deliberate exceptions now that the server handlers run traced in-process: - test_sse_client_timeout is deleted. It has been permanently skipped ("highlights a possible bug in SSE read timeout exception handling"), and its premise — a real network read timeout — cannot exist for an in-process app, so it could never be unskipped here. The slow:// resource branch and sse_read_timeout plumbing existed only for it. - Handlers that no test ever invoked (the main server's tools handlers, the context server's unknown-tool fallthrough) are removed or replaced with dispatch asserts, since unreachable branches now fail branch coverage instead of hiding in an untraced subprocess. tests/client/test_http_unicode.py gets the same treatment: the Unicode echo server now runs in process and streamable_http_client speaks to it through the bridge with follow_redirects enabled, matching the SDK's own client factory (Starlette's Mount 307-redirects the bare /mcp path). --- tests/client/test_http_unicode.py | 288 ++++++++------------- tests/shared/test_sse.py | 409 ++++++++++-------------------- 2 files changed, 242 insertions(+), 455 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e469..585a142617 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,11 +4,10 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager +import httpx import pytest from starlette.applications import Starlette from starlette.routing import Mount @@ -19,7 +18,10 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,74 +43,62 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - import uvicorn - - # Need to recreate the server setup in this process - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, - }, - "required": ["text"], +async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": { + "text": {"type": "string", "description": "Text to echo back"}, }, - ), - ] - ) - - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - if params.name == "echo_unicode": - text = params.arguments.get("text", "") if params.arguments else "" - return types.CallToolResult( - content=[ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] + "required": ["text"], + }, + ), + ] + ) + + +async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + assert params.name == "echo_unicode" + assert params.arguments is not None + return types.CallToolResult(content=[TextContent(type="text", text=f"Echo: {params.arguments['text']}")]) + + +async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], ) - else: - raise ValueError(f"Unknown tool: {params.name}") - - async def handle_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - return types.ListPromptsResult( - prompts=[ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], - ) - ] - ) - - async def handle_get_prompt( - ctx: ServerRequestContext, params: types.GetPromptRequestParams - ) -> types.GetPromptResult: - if params.name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] + ] + ) + + +async def handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + assert params.name == "unicode_prompt" + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), ) - raise ValueError(f"Unknown prompt: {params.name}") + ] + ) + +@asynccontextmanager +async def unicode_session() -> AsyncIterator[ClientSession]: + """Yield an initialized ClientSession speaking streamable HTTP (SSE responses) to the + Unicode test server, entirely in process.""" server = Server( name="unicode_test_server", on_list_tools=handle_list_tools, @@ -116,122 +106,68 @@ async def handle_get_prompt( on_list_prompts=handle_list_prompts, on_get_prompt=handle_get_prompt, ) - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing - ) - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield - - # Create an ASGI application - app = Starlette( - debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], - lifespan=lifespan, - ) - - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) + # SSE response mode, so Unicode rides the SSE event encoding rather than a plain JSON body. + session_manager = StreamableHTTPSessionManager(app=server, json_response=False) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + + async with ( + session_manager.run(), + # follow_redirects matches the SDK's own client factory; Starlette's Mount 307-redirects + # the bare /mcp path to /mcp/. + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, follow_redirects=True + ) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call() -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 + async with unicode_session() as session: + # Test 1: List tools (server→client Unicode in descriptions) + tools = await session.list_tools() + assert len(tools.tools) == 1 - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Check Unicode in tool descriptions + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Test 2: Send Unicode text in tool call (client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await session.call_tool("echo_unicode", arguments={"text": test_string}) - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Verify server correctly received and echoed back Unicode + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts() -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" + async with unicode_session() as session: + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707b..675a4acb16 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,7 @@ +"""Tests for the SSE client and server transports, driven entirely in process.""" + import json -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse @@ -9,7 +9,6 @@ import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from starlette.applications import Starlette @@ -24,6 +23,7 @@ from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._httpx_utils import McpHttpClientFactory from mcp.shared.exceptions import MCPError from mcp.types import ( CallToolRequestParams, @@ -41,171 +41,114 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.interaction.transports import StreamingASGITransport SERVER_NAME = "test_server_for_SSE" +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +BASE_URL = "http://127.0.0.1:8000" + + +def in_process_client_factory(app: Starlette) -> McpHttpClientFactory: + """An httpx_client_factory for sse_client whose clients are served in process by `app`.""" + + def factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE GET runs until it observes a disconnect, so the bridge must let the + # application drain on close rather than cancelling it. follow_redirects matches + # create_mcp_http_client, the factory this one stands in for. + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - + return factory -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - -async def _handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: +async def _handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": - text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - text = f"Slow response from {parsed.netloc}" - else: - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) - - -async def _handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - - -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=_handle_read_resource, - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool, - ) + return ReadResourceResult( + contents=[TextResourceContents(uri=uri, text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + raise MCPError(code=404, message="OOPS! no resource with that URI was found") -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] +def make_app(server: Server) -> Starlette: + """Mount `server` on a Starlette app exposing the SSE transport at /sse and /messages/.""" + # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot + # exist for an in-process app; the transport security behaviour itself is pinned by + # tests/server/test_sse_security.py. + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - server = _create_server() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) + async with sse.connect_sse(request.scope, request.receive, request._send) as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +def make_server_app() -> Starlette: + return make_app(Server(SERVER_NAME, on_read_resource=_handle_read_resource)) - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: - """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: - yield client +@pytest.mark.anyio +async def test_raw_sse_connection() -> None: + """The SSE GET responds 200 with an event-stream content type, announcing the session + endpoint as its first event.""" + http_client = httpx.AsyncClient( + transport=StreamingASGITransport(make_server_app(), cancel_on_close=False), base_url=BASE_URL + ) + with anyio.fail_after(5): + async with http_client, http_client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" -# Tests -@pytest.mark.anyio -async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: - """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group(): - - async def connection_test() -> None: - async with http_client.stream("GET", "/sse") as response: - assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - - line_number = 0 - async for line in response.aiter_lines(): # pragma: no branch - if line_number == 0: - assert line == "event: endpoint" - elif line_number == 1: - assert line.startswith("data: /messages/?session_id=") - else: - return - line_number += 1 - - # Add timeout to prevent test from hanging if it fails - with anyio.fail_after(3): - await connection_test() + lines = response.aiter_lines() + assert await anext(lines) == "event: endpoint" + assert (await anext(lines)).startswith("data: /messages/?session_id=") @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection() -> None: + """A client initializes against, and pings, a server over the SSE transport.""" + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: +async def test_sse_client_on_session_created() -> None: + """The session-created callback receives the new session ID before sse_client yields.""" + factory = in_process_client_factory(make_server_app()) captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -226,13 +169,14 @@ async def test_sse_client_on_session_created(server: None, server_url: str) -> N ], ) def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: + """The session ID is read from the endpoint URL's sessionId/session_id query parameters.""" assert _extract_session_id_from_endpoint(endpoint_url) == expected @pytest.mark.anyio -async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch -) -> None: +async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + """No session-created callback fires when the endpoint URL carries no session ID.""" + factory = in_process_client_factory(make_server_app()) callback_mock = Mock() def mock_extract(url: str) -> None: @@ -240,7 +184,7 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -250,8 +194,9 @@ def mock_extract(url: str) -> None: @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session() -> AsyncGenerator[ClientSession, None]: + factory = in_process_client_factory(make_server_app()) + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -261,6 +206,7 @@ async def initialized_sse_client_session(server: None, server_url: str) -> Async async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, ) -> None: + """A resource read round-trips its arguments and the handler's content over SSE.""" session = initialized_sse_client_session response = await session.read_resource(uri="foobar://should-work") assert len(response.contents) == 1 @@ -272,93 +218,45 @@ async def test_sse_client_happy_request_and_response( async def test_sse_client_exception_handling( initialized_sse_client_session: ClientSession, ) -> None: + """A server-side MCPError reaches the client with its message intact.""" session = initialized_sse_client_session with pytest.raises(MCPError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri="xxx://will-not-work") @pytest.mark.anyio -@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") -async def test_sse_client_timeout( # pragma: no cover - initialized_sse_client_session: ClientSession, -) -> None: - session = initialized_sse_client_session - - # sanity check that normal, fast responses are working - response = await session.read_resource(uri="foobar://1") - assert isinstance(response, ReadResourceResult) - - with anyio.move_on_after(3): - with pytest.raises(MCPError, match="Read timed out"): - response = await session.read_resource(uri="slow://2") - # we should receive an error here - return - - pytest.fail("the client should have timed out and returned an error already") - - -def run_mounted_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) +async def test_sse_client_basic_connection_mounted_app() -> None: + """The SSE transport works unchanged when its app is mounted under a sub-path.""" + main_app = Starlette(routes=[Mount("/mounted_app", app=make_server_app())]) + factory = in_process_client_factory(main_app) - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: + async with sse_client(f"{BASE_URL}/mounted_app/sse", httpx_client_factory=factory) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - headers_info: dict[str, Any] = {} - if ctx.request: - headers_info = dict(ctx.request.headers) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name in ("echo_headers", "echo_context") + assert ctx.request is not None + headers_info = dict(ctx.request.headers) if params.name == "echo_headers": return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - elif params.name == "echo_context": - context_data = { - "request_id": (params.arguments or {}).get("request_id"), - "headers": headers_info, - } - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + assert params.arguments is not None + context_data = { + "request_id": params.arguments.get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -381,115 +279,65 @@ async def _handle_context_list_tools( # pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing - security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] - ) - sse = SseServerTransport("/messages/", security_settings=security_settings) - context_server = Server( - "request_context_server", - on_call_tool=_handle_context_call_tool, - on_list_tools=_handle_context_list_tools, - ) - - async def handle_sse(request: Request) -> Response: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ] +def make_context_server_app() -> Starlette: + return make_app( + Server( + "request_context_server", + on_call_tool=_handle_context_call_tool, + on_list_tools=_handle_context_list_tools, + ) ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - - -@pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") - @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: - """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers +async def test_request_context_propagation() -> None: + """Custom HTTP headers on the SSE connection are visible to server handlers via ctx.request.""" + factory = in_process_client_factory(make_context_server_app()) + custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: - # Initialize the session + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=custom_headers) as streams: + async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 - headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") + content = tool_result.content[0] + assert isinstance(content, TextContent) + headers_data = json.loads(content.text) - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: - """Test that request contexts are isolated between different SSE clients.""" +async def test_request_context_isolation() -> None: + """Each SSE connection's handlers see only that connection's request headers.""" + factory = in_process_client_factory(make_context_server_app()) contexts: list[dict[str, Any]] = [] - # Create multiple clients with different headers + # Connect three clients in turn, each with its own headers. for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: + async with sse_client(f"{BASE_URL}/sse", httpx_client_factory=factory, headers=headers) as streams: + async with ClientSession(*streams) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 - context_data = json.loads( - tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" - ) - contexts.append(context_data) + content = tool_result.content[0] + assert isinstance(content, TextContent) + contexts.append(json.loads(content.text)) - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -497,7 +345,7 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["headers"].get("x-custom-value") == f"value-{i}" -def test_sse_message_id_coercion(): +def test_sse_message_id_coercion() -> None: """Previously, the `RequestId` would coerce a string that looked like an integer into an integer. See for more details. @@ -531,7 +379,7 @@ def test_sse_message_id_coercion(): ("/messages/#fragment", ValueError), ], ) -def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): +def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]) -> None: """Test that SseServerTransport properly validates and normalizes endpoints.""" if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception @@ -605,7 +453,7 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: @pytest.mark.anyio -async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: +async def test_sse_session_cleanup_on_disconnect() -> None: """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 When a client disconnects, the server should remove the session from @@ -613,18 +461,21 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) POST requests to disconnected sessions return 202 Accepted followed by a ClosedResourceError when the server tries to write to the dead stream. """ + factory = in_process_client_factory(make_server_app()) captured: list[str] = [] # Connect a client session, then disconnect - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client( + f"{BASE_URL}/sse", httpx_client_factory=factory, on_session_created=captured.append + ) as streams: async with ClientSession(*streams) as session: await session.initialize() # After disconnect, POST to the stale session should return 404 # (not 202 as it did before the fix) - async with httpx.AsyncClient() as client: + async with factory() as client: response = await client.post( - f"{server_url}/messages/?session_id={captured[0]}", + f"/messages/?session_id={captured[0]}", json={"jsonrpc": "2.0", "method": "ping", "id": 99}, headers={"Content-Type": "application/json"}, )