|
27 | 27 | from starlette.requests import Request |
28 | 28 | from starlette.routing import Mount |
29 | 29 |
|
| 30 | +import mcp.client.streamable_http as streamable_http_module |
30 | 31 | from mcp import MCPError, types |
31 | 32 | from mcp.client.session import ClientSession |
32 | | -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
| 33 | +from mcp.client.streamable_http import ( |
| 34 | + RequestContext as StreamableHTTPClientRequestContext, |
| 35 | +) |
| 36 | +from mcp.client.streamable_http import ( |
| 37 | + StreamableHTTPTransport, |
| 38 | + streamable_http_client, |
| 39 | +) |
33 | 40 | from mcp.server import Server, ServerRequestContext |
34 | 41 | from mcp.server.streamable_http import ( |
35 | 42 | MCP_PROTOCOL_VERSION_HEADER, |
|
45 | 52 | from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
46 | 53 | from mcp.server.transport_security import TransportSecuritySettings |
47 | 54 | from mcp.shared._context import RequestContext |
48 | | -from mcp.shared._context_streams import create_context_streams |
| 55 | +from mcp.shared._context_streams import ContextSendStream, create_context_streams |
49 | 56 | from mcp.shared._httpx_utils import ( |
50 | 57 | MCP_DEFAULT_SSE_READ_TIMEOUT, |
51 | 58 | MCP_DEFAULT_TIMEOUT, |
@@ -1803,6 +1810,110 @@ async def test_handle_sse_event_skips_empty_data(): |
1803 | 1810 | await read_stream.aclose() |
1804 | 1811 |
|
1805 | 1812 |
|
| 1813 | +class _FakeStreamResponse(httpx.Response): |
| 1814 | + def __init__(self) -> None: |
| 1815 | + super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp")) |
| 1816 | + self.closed_by_transport = False |
| 1817 | + |
| 1818 | + async def aclose(self) -> None: |
| 1819 | + self.closed_by_transport = True |
| 1820 | + await super().aclose() |
| 1821 | + |
| 1822 | + |
| 1823 | +def _response_sse(request_id: int | str) -> ServerSentEvent: |
| 1824 | + return ServerSentEvent( |
| 1825 | + event="message", |
| 1826 | + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), |
| 1827 | + id="response-event", |
| 1828 | + ) |
| 1829 | + |
| 1830 | + |
| 1831 | +def _make_streamable_http_request_context( |
| 1832 | + request_id: int | str, |
| 1833 | + client: httpx.AsyncClient, |
| 1834 | + write_stream: ContextSendStream[SessionMessage | Exception], |
| 1835 | +) -> StreamableHTTPClientRequestContext: |
| 1836 | + return StreamableHTTPClientRequestContext( |
| 1837 | + client=client, |
| 1838 | + session_id=None, |
| 1839 | + session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")), |
| 1840 | + metadata=None, |
| 1841 | + read_stream_writer=write_stream, |
| 1842 | + ) |
| 1843 | + |
| 1844 | + |
| 1845 | +@pytest.mark.anyio |
| 1846 | +async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): |
| 1847 | + """Terminal POST SSE responses are drained instead of force-closed.""" |
| 1848 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1849 | + response = _FakeStreamResponse() |
| 1850 | + |
| 1851 | + class FakeEventSource: |
| 1852 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1853 | + self.response = response |
| 1854 | + |
| 1855 | + async def aiter_sse(self): |
| 1856 | + yield _response_sse(1) |
| 1857 | + yield ServerSentEvent(event="message", data="", id="drained-event") |
| 1858 | + |
| 1859 | + async def fail_reconnect(*args: Any, **kwargs: Any) -> None: |
| 1860 | + raise AssertionError("terminal responses should not reconnect after draining") |
| 1861 | + |
| 1862 | + monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource) |
| 1863 | + monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect) |
| 1864 | + |
| 1865 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1866 | + async with httpx.AsyncClient() as client: |
| 1867 | + try: |
| 1868 | + ctx = _make_streamable_http_request_context(1, client, write_stream) |
| 1869 | + await transport._handle_sse_response(response, ctx) |
| 1870 | + |
| 1871 | + assert response.closed_by_transport is False |
| 1872 | + message = await read_stream.receive() |
| 1873 | + assert isinstance(message, SessionMessage) |
| 1874 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1875 | + assert message.message.id == 1 |
| 1876 | + finally: |
| 1877 | + await write_stream.aclose() |
| 1878 | + await read_stream.aclose() |
| 1879 | + |
| 1880 | + |
| 1881 | +@pytest.mark.anyio |
| 1882 | +async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): |
| 1883 | + """Resumed GET responses use EOF draining instead of response.aclose().""" |
| 1884 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1885 | + response = _FakeStreamResponse() |
| 1886 | + |
| 1887 | + class FakeReconnectionEventSource: |
| 1888 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1889 | + self.response = response |
| 1890 | + |
| 1891 | + async def aiter_sse(self): |
| 1892 | + yield _response_sse("abc") |
| 1893 | + yield ServerSentEvent(event="message", data="", id="drained-event") |
| 1894 | + |
| 1895 | + @asynccontextmanager |
| 1896 | + async def fake_aconnect_sse(*args: Any, **kwargs: Any): |
| 1897 | + yield FakeReconnectionEventSource(response) |
| 1898 | + |
| 1899 | + monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse) |
| 1900 | + |
| 1901 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1902 | + async with httpx.AsyncClient() as client: |
| 1903 | + try: |
| 1904 | + ctx = _make_streamable_http_request_context("abc", client, write_stream) |
| 1905 | + await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0) |
| 1906 | + |
| 1907 | + assert response.closed_by_transport is False |
| 1908 | + message = await read_stream.receive() |
| 1909 | + assert isinstance(message, SessionMessage) |
| 1910 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1911 | + assert message.message.id == "abc" |
| 1912 | + finally: |
| 1913 | + await write_stream.aclose() |
| 1914 | + await read_stream.aclose() |
| 1915 | + |
| 1916 | + |
1806 | 1917 | @pytest.mark.anyio |
1807 | 1918 | async def test_priming_event_not_sent_for_old_protocol_version(): |
1808 | 1919 | """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" |
|
0 commit comments