Skip to content

Commit f18d65b

Browse files
committed
Gate the multiple-reconnections test on observed reconnects, not sleeps
The tool slept a fixed duration after each close_sse_stream() so the client's retry-interval reconnect could land before the next close. That is a timing margin, not synchronization: a reconnect delayed past the margin made the close a silent no-op, two cycles merged, and the exact resumption-token count failed. The tool now waits, bounded by fail_after(5), for the client-side resumption-token callback to observe each cycle's two new tokens (the checkpoint and the new connection's priming event). The priming event is sent only after the server has re-registered the resumed stream, so once the client holds its token the next close is guaranteed to sever a live connection. The token count becomes a consequence of causality rather than margins, no sleep remains, the retry interval drops from 500ms to 50ms, and a genuinely failed reconnect now fails loudly at the timeout instead of silently merging cycles. The test now defines its own server with the gated tool inline, since the gate closes over per-test state; the shared server's tool_with_multiple_stream_closes had no other users and is removed, which moves the tools/list count assertions from 9 to 8.
1 parent de72011 commit f18d65b

1 file changed

Lines changed: 47 additions & 49 deletions

File tree

tests/shared/test_streamable_http.py

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,6 @@ async def _handle_list_tools(
185185
description="Tool that sends notification1, closes stream, sends notification2, notification3",
186186
input_schema={"type": "object", "properties": {}},
187187
),
188-
Tool(
189-
name="tool_with_multiple_stream_closes",
190-
description="Tool that closes SSE stream multiple times during execution",
191-
input_schema={
192-
"type": "object",
193-
"properties": {
194-
"checkpoints": {"type": "integer", "default": 3},
195-
"sleep_time": {"type": "number", "default": 0.2},
196-
},
197-
},
198-
),
199188
Tool(
200189
name="tool_with_standalone_stream_close",
201190
description="Tool that closes standalone GET stream mid-operation",
@@ -207,7 +196,6 @@ async def _handle_list_tools(
207196

208197
async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult:
209198
name = params.name
210-
args = params.arguments or {}
211199

212200
# When the tool is called, send a notification to test GET stream
213201
if name == "test_tool_with_standalone_notification":
@@ -301,25 +289,6 @@ async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: Call
301289
)
302290
return CallToolResult(content=[TextContent(type="text", text="All notifications sent")])
303291

304-
elif name == "tool_with_multiple_stream_closes":
305-
num_checkpoints = args.get("checkpoints", 3)
306-
sleep_time = args.get("sleep_time", 0.2)
307-
308-
for i in range(num_checkpoints):
309-
await ctx.session.send_log_message(
310-
level="info",
311-
data=f"checkpoint_{i}",
312-
logger="multi_close_tool",
313-
related_request_id=ctx.request_id,
314-
)
315-
316-
assert ctx.close_sse_stream is not None
317-
await ctx.close_sse_stream()
318-
319-
await anyio.sleep(sleep_time)
320-
321-
return CallToolResult(content=[TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")])
322-
323292
elif name == "tool_with_standalone_stream_close":
324293
await ctx.session.send_resource_updated(uri="http://notification_1")
325294
await anyio.sleep(0.1)
@@ -350,19 +319,21 @@ async def running_app(
350319
is_json_response_enabled: bool = False,
351320
event_store: EventStore | None = None,
352321
retry_interval: int | None = None,
322+
server: Server[Any] | None = None,
353323
) -> AsyncIterator[Starlette]:
354324
"""Serve the test server's streamable HTTP app in process for the duration.
355325
356326
Args:
357327
is_json_response_enabled: If True, use JSON responses instead of SSE streams.
358328
event_store: Optional event store for testing resumability.
359329
retry_interval: Retry interval in milliseconds for SSE polling.
330+
server: Server to mount; defaults to the file's shared test server.
360331
"""
361332
# DNS-rebinding protection validates Host/Origin headers against a network attack that cannot
362333
# exist for an in-process app; the protection itself is pinned by
363334
# tests/server/test_streamable_http_security.py.
364335
session_manager = StreamableHTTPSessionManager(
365-
app=_create_server(),
336+
app=server if server is not None else _create_server(),
366337
event_store=event_store,
367338
json_response=is_json_response_enabled,
368339
security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False),
@@ -912,7 +883,7 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session
912883
"""A tool call reaches the handler and returns its content."""
913884
# First list tools
914885
tools = await initialized_client_session.list_tools()
915-
assert len(tools.tools) == 9
886+
assert len(tools.tools) == 8
916887
assert tools.tools[0].name == "test_tool"
917888

918889
# Call the tool
@@ -945,7 +916,7 @@ async def test_streamable_http_client_session_persistence(basic_app: Starlette)
945916

946917
# Make multiple requests to verify session persistence
947918
tools = await session.list_tools()
948-
assert len(tools.tools) == 9
919+
assert len(tools.tools) == 8
949920

950921
# Read a resource
951922
resource = await session.read_resource(uri="foobar://test-persist")
@@ -970,7 +941,7 @@ async def test_streamable_http_client_json_response(json_app: Starlette) -> None
970941

971942
# Check tool listing
972943
tools = await session.list_tools()
973-
assert len(tools.tools) == 9
944+
assert len(tools.tools) == 8
974945

975946
# Call a tool and verify JSON response handling
976947
result = await session.call_tool("test_tool", {})
@@ -1056,7 +1027,7 @@ async def test_streamable_http_client_session_termination(basic_app: Starlette)
10561027

10571028
# Make a request to confirm session is working
10581029
tools = await session.list_tools()
1059-
assert len(tools.tools) == 9
1030+
assert len(tools.tools) == 8
10601031

10611032
async with make_client(basic_app, headers=headers) as httpx_client2:
10621033
async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as (
@@ -1117,7 +1088,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11171088

11181089
# Make a request to confirm session is working
11191090
tools = await session.list_tools()
1120-
assert len(tools.tools) == 9
1091+
assert len(tools.tools) == 8
11211092

11221093
async with make_client(basic_app, headers=headers) as httpx_client2:
11231094
async with streamable_http_client(f"{BASE_URL}/mcp", http_client=httpx_client2) as (
@@ -1943,27 +1914,58 @@ async def message_handler(
19431914

19441915

19451916
@pytest.mark.anyio
1946-
async def test_streamable_http_multiple_reconnections(
1947-
event_app: tuple[SimpleEventStore, Starlette],
1948-
) -> None:
1949-
"""Verify multiple close_sse_stream() calls each trigger a client reconnect.
1917+
async def test_streamable_http_multiple_reconnections() -> None:
1918+
"""Every close_sse_stream() severs a live connection and triggers its own client reconnect.
19501919
1951-
Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure
1952-
client has time to reconnect before the next checkpoint.
1920+
The tool closes its SSE stream three times; before each next cycle it waits until the
1921+
client has observed the previous cycle's two new resumption tokens (the checkpoint and the
1922+
new connection's priming event). The priming event is sent only after the server has
1923+
re-registered the resumed stream, so once the client holds its token the next close is
1924+
guaranteed to sever a live connection rather than silently no-op — making the exact token
1925+
count below a consequence of causality, not timing margins. This pins reconnect-per-close
1926+
accounting; reconnect *latency* is pinned by test_streamable_http_client_respects_retry_interval.
19531927
19541928
With 3 checkpoints, we expect 8 resumption tokens:
19551929
- 1 priming (initial POST connection)
19561930
- 3 notifications (checkpoint_0, checkpoint_1, checkpoint_2)
19571931
- 3 priming (one per reconnect after each close)
19581932
- 1 response
19591933
"""
1960-
_, app = event_app
19611934
resumption_tokens: list[str] = []
1935+
# milestones[n] fires when the client has observed n tokens. After the initial priming
1936+
# (token 1), each completed cycle i contributes exactly two tokens — checkpoint_i and the
1937+
# reconnect's priming, in either order — so cycle i is complete at 3 + 2i tokens.
1938+
milestones = {3: anyio.Event(), 5: anyio.Event(), 7: anyio.Event()}
19621939

19631940
async def on_resumption_token(token: str) -> None:
19641941
resumption_tokens.append(token)
1942+
milestone = milestones.get(len(resumption_tokens))
1943+
if milestone is not None:
1944+
milestone.set()
1945+
1946+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
1947+
assert params.name == "multi_close_tool"
1948+
for i, milestone in enumerate(milestones.values()):
1949+
await ctx.session.send_log_message(
1950+
level="info",
1951+
data=f"checkpoint_{i}",
1952+
logger="multi_close_tool",
1953+
related_request_id=ctx.request_id,
1954+
)
1955+
assert ctx.close_sse_stream is not None
1956+
await ctx.close_sse_stream()
1957+
# Client and server share one event loop, so the tool can wait directly on the
1958+
# client-side callback observing the reconnect.
1959+
with anyio.fail_after(5):
1960+
await milestone.wait()
1961+
return CallToolResult(content=[TextContent(type="text", text="Completed 3 checkpoints")])
1962+
1963+
server = Server("multi_reconnect_server", on_call_tool=handle_call_tool)
19651964

19661965
async with (
1966+
# retry_interval is small to keep the test fast, but nonzero so each dying connection
1967+
# finishes unwinding before its replacement registers.
1968+
running_app(event_store=SimpleEventStore(), retry_interval=50, server=server) as app,
19671969
make_client(app) as http_client,
19681970
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream),
19691971
ClientSession(read_stream, write_stream) as session,
@@ -1975,11 +1977,7 @@ async def on_resumption_token(token: str) -> None:
19751977
result = await session.send_request(
19761978
types.CallToolRequest(
19771979
method="tools/call",
1978-
params=types.CallToolRequestParams(
1979-
name="tool_with_multiple_stream_closes",
1980-
# retry_interval=500ms, so sleep 600ms to ensure reconnect completes
1981-
arguments={"checkpoints": 3, "sleep_time": 0.6},
1982-
),
1980+
params=types.CallToolRequestParams(name="multi_close_tool", arguments={}),
19831981
),
19841982
types.CallToolResult,
19851983
metadata=metadata,

0 commit comments

Comments
 (0)