Skip to content

Commit c549da7

Browse files
committed
fix: drain completed streamable HTTP SSE responses
1 parent 616476f commit c549da7

2 files changed

Lines changed: 132 additions & 9 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,18 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
240240
event_source.response.raise_for_status()
241241
logger.debug("Resumption GET SSE connection established")
242242

243+
response_complete = False
243244
async for sse in event_source.aiter_sse(): # pragma: no branch
245+
if response_complete:
246+
continue
244247
is_complete = await self._handle_sse_event(
245248
sse,
246249
ctx.read_stream_writer,
247250
original_request_id,
248251
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249252
)
250253
if is_complete:
251-
await event_source.response.aclose()
252-
break
254+
response_complete = True
253255

254256
async def _handle_post_request(self, ctx: RequestContext) -> None:
255257
"""Handle a POST request with response processing."""
@@ -342,6 +344,7 @@ async def _handle_sse_response(
342344

343345
try:
344346
event_source = EventSource(response)
347+
response_complete = False
345348
async for sse in event_source.aiter_sse(): # pragma: no branch
346349
# Track last event ID for potential reconnection
347350
if sse.id:
@@ -351,6 +354,9 @@ async def _handle_sse_response(
351354
if sse.retry is not None:
352355
retry_interval_ms = sse.retry
353356

357+
if response_complete:
358+
continue
359+
354360
is_complete = await self._handle_sse_event(
355361
sse,
356362
ctx.read_stream_writer,
@@ -359,10 +365,11 @@ async def _handle_sse_response(
359365
is_initialization=is_initialization,
360366
)
361367
# If the SSE event indicates completion, like returning response/error
362-
# break the loop
368+
# keep draining the stream so the underlying HTTP connection remains reusable.
363369
if is_complete:
364-
await response.aclose()
365-
return # Normal completion, no reconnect needed
370+
response_complete = True
371+
if response_complete:
372+
return # Normal completion, no reconnect needed
366373
except Exception:
367374
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
368375

@@ -404,22 +411,27 @@ async def _handle_reconnection(
404411
# Track for potential further reconnection
405412
reconnect_last_event_id: str = last_event_id
406413
reconnect_retry_ms = retry_interval_ms
414+
response_complete = False
407415

408416
async for sse in event_source.aiter_sse():
409417
if sse.id: # pragma: no branch
410418
reconnect_last_event_id = sse.id
411419
if sse.retry is not None:
412420
reconnect_retry_ms = sse.retry
413421

422+
if response_complete:
423+
continue
424+
414425
is_complete = await self._handle_sse_event(
415426
sse,
416427
ctx.read_stream_writer,
417428
original_request_id,
418429
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
419430
)
420431
if is_complete:
421-
await event_source.response.aclose()
422-
return
432+
response_complete = True
433+
if response_complete:
434+
return
423435

424436
# Stream ended again without response - reconnect again (reset attempt counter)
425437
logger.info("SSE stream disconnected, reconnecting...")

tests/shared/test_streamable_http.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,16 @@
2727
from starlette.requests import Request
2828
from starlette.routing import Mount
2929

30+
import mcp.client.streamable_http as streamable_http_module
3031
from mcp import MCPError, types
3132
from mcp.client.session import ClientSession
32-
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
33+
from mcp.client.streamable_http import (
34+
RequestContext as StreamableHTTPClientRequestContext,
35+
)
36+
from mcp.client.streamable_http import (
37+
StreamableHTTPTransport,
38+
streamable_http_client,
39+
)
3340
from mcp.server import Server, ServerRequestContext
3441
from mcp.server.streamable_http import (
3542
MCP_PROTOCOL_VERSION_HEADER,
@@ -45,7 +52,7 @@
4552
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4653
from mcp.server.transport_security import TransportSecuritySettings
4754
from mcp.shared._context import RequestContext
48-
from mcp.shared._context_streams import create_context_streams
55+
from mcp.shared._context_streams import ContextSendStream, create_context_streams
4956
from mcp.shared._httpx_utils import (
5057
MCP_DEFAULT_SSE_READ_TIMEOUT,
5158
MCP_DEFAULT_TIMEOUT,
@@ -1803,6 +1810,110 @@ async def test_handle_sse_event_skips_empty_data():
18031810
await read_stream.aclose()
18041811

18051812

1813+
class _FakeStreamResponse(httpx.Response):
1814+
def __init__(self) -> None:
1815+
super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp"))
1816+
self.closed_by_transport = False
1817+
1818+
async def aclose(self) -> None:
1819+
self.closed_by_transport = True
1820+
await super().aclose()
1821+
1822+
1823+
def _response_sse(request_id: int | str) -> ServerSentEvent:
1824+
return ServerSentEvent(
1825+
event="message",
1826+
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
1827+
id="response-event",
1828+
)
1829+
1830+
1831+
def _make_streamable_http_request_context(
1832+
request_id: int | str,
1833+
client: httpx.AsyncClient,
1834+
write_stream: ContextSendStream[SessionMessage | Exception],
1835+
) -> StreamableHTTPClientRequestContext:
1836+
return StreamableHTTPClientRequestContext(
1837+
client=client,
1838+
session_id=None,
1839+
session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")),
1840+
metadata=None,
1841+
read_stream_writer=write_stream,
1842+
)
1843+
1844+
1845+
@pytest.mark.anyio
1846+
async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
1847+
"""Terminal POST SSE responses are drained instead of force-closed."""
1848+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1849+
response = _FakeStreamResponse()
1850+
1851+
class FakeEventSource:
1852+
def __init__(self, response: _FakeStreamResponse) -> None:
1853+
self.response = response
1854+
1855+
async def aiter_sse(self):
1856+
yield _response_sse(1)
1857+
yield ServerSentEvent(event="message", data="", id="drained-event")
1858+
1859+
async def fail_reconnect(*args: Any, **kwargs: Any) -> None:
1860+
raise AssertionError("terminal responses should not reconnect after draining")
1861+
1862+
monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource)
1863+
monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect)
1864+
1865+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1866+
async with httpx.AsyncClient() as client:
1867+
try:
1868+
ctx = _make_streamable_http_request_context(1, client, write_stream)
1869+
await transport._handle_sse_response(response, ctx)
1870+
1871+
assert response.closed_by_transport is False
1872+
message = await read_stream.receive()
1873+
assert isinstance(message, SessionMessage)
1874+
assert isinstance(message.message, types.JSONRPCResponse)
1875+
assert message.message.id == 1
1876+
finally:
1877+
await write_stream.aclose()
1878+
await read_stream.aclose()
1879+
1880+
1881+
@pytest.mark.anyio
1882+
async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
1883+
"""Resumed GET responses use EOF draining instead of response.aclose()."""
1884+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1885+
response = _FakeStreamResponse()
1886+
1887+
class FakeReconnectionEventSource:
1888+
def __init__(self, response: _FakeStreamResponse) -> None:
1889+
self.response = response
1890+
1891+
async def aiter_sse(self):
1892+
yield _response_sse("abc")
1893+
yield ServerSentEvent(event="message", data="", id="drained-event")
1894+
1895+
@asynccontextmanager
1896+
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
1897+
yield FakeReconnectionEventSource(response)
1898+
1899+
monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)
1900+
1901+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1902+
async with httpx.AsyncClient() as client:
1903+
try:
1904+
ctx = _make_streamable_http_request_context("abc", client, write_stream)
1905+
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)
1906+
1907+
assert response.closed_by_transport is False
1908+
message = await read_stream.receive()
1909+
assert isinstance(message, SessionMessage)
1910+
assert isinstance(message.message, types.JSONRPCResponse)
1911+
assert message.message.id == "abc"
1912+
finally:
1913+
await write_stream.aclose()
1914+
await read_stream.aclose()
1915+
1916+
18061917
@pytest.mark.anyio
18071918
async def test_priming_event_not_sent_for_old_protocol_version():
18081919
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)