@@ -185,17 +185,6 @@ async def _handle_list_tools(
185185 description = "Tool that sends notification1, closes stream, sends notification2, notification3" ,
186186 input_schema = {"type" : "object" , "properties" : {}},
187187 ),
188- Tool (
189- name = "tool_with_multiple_stream_closes" ,
190- description = "Tool that closes SSE stream multiple times during execution" ,
191- input_schema = {
192- "type" : "object" ,
193- "properties" : {
194- "checkpoints" : {"type" : "integer" , "default" : 3 },
195- "sleep_time" : {"type" : "number" , "default" : 0.2 },
196- },
197- },
198- ),
199188 Tool (
200189 name = "tool_with_standalone_stream_close" ,
201190 description = "Tool that closes standalone GET stream mid-operation" ,
@@ -207,7 +196,6 @@ async def _handle_list_tools(
207196
208197async def _handle_call_tool (ctx : ServerRequestContext [ServerState ], params : CallToolRequestParams ) -> CallToolResult :
209198 name = params .name
210- args = params .arguments or {}
211199
212200 # When the tool is called, send a notification to test GET stream
213201 if name == "test_tool_with_standalone_notification" :
@@ -301,25 +289,6 @@ async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: Call
301289 )
302290 return CallToolResult (content = [TextContent (type = "text" , text = "All notifications sent" )])
303291
304- elif name == "tool_with_multiple_stream_closes" :
305- num_checkpoints = args .get ("checkpoints" , 3 )
306- sleep_time = args .get ("sleep_time" , 0.2 )
307-
308- for i in range (num_checkpoints ):
309- await ctx .session .send_log_message (
310- level = "info" ,
311- data = f"checkpoint_{ i } " ,
312- logger = "multi_close_tool" ,
313- related_request_id = ctx .request_id ,
314- )
315-
316- assert ctx .close_sse_stream is not None
317- await ctx .close_sse_stream ()
318-
319- await anyio .sleep (sleep_time )
320-
321- return CallToolResult (content = [TextContent (type = "text" , text = f"Completed { num_checkpoints } checkpoints" )])
322-
323292 elif name == "tool_with_standalone_stream_close" :
324293 await ctx .session .send_resource_updated (uri = "http://notification_1" )
325294 await anyio .sleep (0.1 )
@@ -350,19 +319,21 @@ async def running_app(
350319 is_json_response_enabled : bool = False ,
351320 event_store : EventStore | None = None ,
352321 retry_interval : int | None = None ,
322+ server : Server [Any ] | None = None ,
353323) -> AsyncIterator [Starlette ]:
354324 """Serve the test server's streamable HTTP app in process for the duration.
355325
356326 Args:
357327 is_json_response_enabled: If True, use JSON responses instead of SSE streams.
358328 event_store: Optional event store for testing resumability.
359329 retry_interval: Retry interval in milliseconds for SSE polling.
330+ server: Server to mount; defaults to the file's shared test server.
360331 """
361332 # DNS-rebinding protection validates Host/Origin headers against a network attack that cannot
362333 # exist for an in-process app; the protection itself is pinned by
363334 # tests/server/test_streamable_http_security.py.
364335 session_manager = StreamableHTTPSessionManager (
365- app = _create_server (),
336+ app = server if server is not None else _create_server (),
366337 event_store = event_store ,
367338 json_response = is_json_response_enabled ,
368339 security_settings = TransportSecuritySettings (enable_dns_rebinding_protection = False ),
@@ -912,7 +883,7 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session
912883 """A tool call reaches the handler and returns its content."""
913884 # First list tools
914885 tools = await initialized_client_session .list_tools ()
915- assert len (tools .tools ) == 9
886+ assert len (tools .tools ) == 8
916887 assert tools .tools [0 ].name == "test_tool"
917888
918889 # Call the tool
@@ -945,7 +916,7 @@ async def test_streamable_http_client_session_persistence(basic_app: Starlette)
945916
946917 # Make multiple requests to verify session persistence
947918 tools = await session .list_tools ()
948- assert len (tools .tools ) == 9
919+ assert len (tools .tools ) == 8
949920
950921 # Read a resource
951922 resource = await session .read_resource (uri = "foobar://test-persist" )
@@ -970,7 +941,7 @@ async def test_streamable_http_client_json_response(json_app: Starlette) -> None
970941
971942 # Check tool listing
972943 tools = await session .list_tools ()
973- assert len (tools .tools ) == 9
944+ assert len (tools .tools ) == 8
974945
975946 # Call a tool and verify JSON response handling
976947 result = await session .call_tool ("test_tool" , {})
@@ -1056,7 +1027,7 @@ async def test_streamable_http_client_session_termination(basic_app: Starlette)
10561027
10571028 # Make a request to confirm session is working
10581029 tools = await session .list_tools ()
1059- assert len (tools .tools ) == 9
1030+ assert len (tools .tools ) == 8
10601031
10611032 async with make_client (basic_app , headers = headers ) as httpx_client2 :
10621033 async with streamable_http_client (f"{ BASE_URL } /mcp" , http_client = httpx_client2 ) as (
@@ -1117,7 +1088,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11171088
11181089 # Make a request to confirm session is working
11191090 tools = await session .list_tools ()
1120- assert len (tools .tools ) == 9
1091+ assert len (tools .tools ) == 8
11211092
11221093 async with make_client (basic_app , headers = headers ) as httpx_client2 :
11231094 async with streamable_http_client (f"{ BASE_URL } /mcp" , http_client = httpx_client2 ) as (
@@ -1943,27 +1914,58 @@ async def message_handler(
19431914
19441915
19451916@pytest .mark .anyio
1946- async def test_streamable_http_multiple_reconnections (
1947- event_app : tuple [SimpleEventStore , Starlette ],
1948- ) -> None :
1949- """Verify multiple close_sse_stream() calls each trigger a client reconnect.
1917+ async def test_streamable_http_multiple_reconnections () -> None :
1918+ """Every close_sse_stream() severs a live connection and triggers its own client reconnect.
19501919
1951- Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure
1952- client has time to reconnect before the next checkpoint.
1920+ The tool closes its SSE stream three times; before each next cycle it waits until the
1921+ client has observed the previous cycle's two new resumption tokens (the checkpoint and the
1922+ new connection's priming event). The priming event is sent only after the server has
1923+ re-registered the resumed stream, so once the client holds its token the next close is
1924+ guaranteed to sever a live connection rather than silently no-op — making the exact token
1925+ count below a consequence of causality, not timing margins. This pins reconnect-per-close
1926+ accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval.
19531927
19541928 With 3 checkpoints, we expect 8 resumption tokens:
19551929 - 1 priming (initial POST connection)
19561930 - 3 notifications (checkpoint_0, checkpoint_1, checkpoint_2)
19571931 - 3 priming (one per reconnect after each close)
19581932 - 1 response
19591933 """
1960- _ , app = event_app
19611934 resumption_tokens : list [str ] = []
1935+ # milestones[n] fires when the client has observed n tokens. After the initial priming
1936+ # (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the
1937+ # reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens.
1938+ milestones = {3 : anyio .Event (), 5 : anyio .Event (), 7 : anyio .Event ()}
19621939
19631940 async def on_resumption_token (token : str ) -> None :
19641941 resumption_tokens .append (token )
1942+ milestone = milestones .get (len (resumption_tokens ))
1943+ if milestone is not None :
1944+ milestone .set ()
1945+
1946+ async def handle_call_tool (ctx : ServerRequestContext , params : CallToolRequestParams ) -> CallToolResult :
1947+ assert params .name == "multi_close_tool"
1948+ for i , milestone in enumerate (milestones .values ()):
1949+ await ctx .session .send_log_message (
1950+ level = "info" ,
1951+ data = f"checkpoint_{ i } " ,
1952+ logger = "multi_close_tool" ,
1953+ related_request_id = ctx .request_id ,
1954+ )
1955+ assert ctx .close_sse_stream is not None
1956+ await ctx .close_sse_stream ()
1957+ # Client and server share one event loop, so the tool can wait directly on the
1958+ # client-side callback observing the reconnect.
1959+ with anyio .fail_after (5 ):
1960+ await milestone .wait ()
1961+ return CallToolResult (content = [TextContent (type = "text" , text = "Completed 3 checkpoints" )])
1962+
1963+ server = Server ("multi_reconnect_server" , on_call_tool = handle_call_tool )
19651964
19661965 async with (
1966+ # retry_interval is small to keep the test fast, but nonzero so each dying connection
1967+ # finishes unwinding before its replacement registers.
1968+ running_app (event_store = SimpleEventStore (), retry_interval = 50 , server = server ) as app ,
19671969 make_client (app ) as http_client ,
19681970 streamable_http_client (f"{ BASE_URL } /mcp" , http_client = http_client ) as (read_stream , write_stream ),
19691971 ClientSession (read_stream , write_stream ) as session ,
@@ -1975,11 +1977,7 @@ async def on_resumption_token(token: str) -> None:
19751977 result = await session .send_request (
19761978 types .CallToolRequest (
19771979 method = "tools/call" ,
1978- params = types .CallToolRequestParams (
1979- name = "tool_with_multiple_stream_closes" ,
1980- # retry_interval=500ms, so sleep 600ms to ensure reconnect completes
1981- arguments = {"checkpoints" : 3 , "sleep_time" : 0.6 },
1982- ),
1980+ params = types .CallToolRequestParams (name = "multi_close_tool" , arguments = {}),
19831981 ),
19841982 types .CallToolResult ,
19851983 metadata = metadata ,
0 commit comments