diff --git a/.changeset/interrupt-on-disconnect-fastapi.md b/.changeset/interrupt-on-disconnect-fastapi.md new file mode 100644 index 00000000..a965fc89 --- /dev/null +++ b/.changeset/interrupt-on-disconnect-fastapi.md @@ -0,0 +1,5 @@ +--- +'@e2b/code-interpreter-template': patch +--- + +Interrupt the kernel when the HTTP client disconnects mid-execution so the per-context lock is released and subsequent executions aren't blocked (#213). On the latest FastAPI (0.136.3) / Starlette (1.2.1), `StreamingResponse` no longer cancels the response body iterator on `http.disconnect` (ASGI spec 2.4+), so the server now detects the disconnect itself by polling `request.is_disconnected()` while streaming and interrupts the kernel. diff --git a/template/server/main.py b/template/server/main.py index 59215f57..62d263e3 100644 --- a/template/server/main.py +++ b/template/server/main.py @@ -123,6 +123,7 @@ async def post_execute(request: Request, exec_request: ExecutionRequest): exec_request.code, env_vars=exec_request.env_vars, access_token=request.headers.get("X-Access-Token", None), + request=request, ) ) diff --git a/template/server/messaging.py b/template/server/messaging.py index c51f8b21..970a07f7 100644 --- a/template/server/messaging.py +++ b/template/server/messaging.py @@ -13,6 +13,7 @@ Union, ) from pydantic import StrictStr +from starlette.requests import Request from websockets.client import WebSocketClientProtocol, connect from websockets.exceptions import ( ConnectionClosedError, @@ -39,6 +40,10 @@ KEEPALIVE_INTERVAL = 5 # seconds between keepalive pings during streaming +class _ClientDisconnected(Exception): + """Raised internally when the HTTP client disconnects mid-execution (#213).""" + + class Execution: def __init__(self, in_background: bool = False): self.queue = Queue[ @@ -251,27 +256,29 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): finally: del self._executions[message_id] - async def _wait_for_result(self, message_id: str): + async def _wait_for_result( + self, message_id: str, request: Optional[Request] = None + ): queue = self._executions[message_id].queue - # Use a timeout on queue.get() to periodically send keepalives. - # Without keepalives, the generator blocks indefinitely waiting for - # kernel output. If the client silently disappears (e.g. network - # failure), uvicorn can only detect the broken connection when it - # tries to write — so we force a write every KEEPALIVE_INTERVAL - # seconds. This ensures timely disconnect detection and kernel - # interrupt for abandoned executions (see #213). + # Wait with a timeout so that, even when the kernel emits no output, we + # periodically poll for client disconnects and write a keepalive. The + # latest Starlette no longer cancels this generator on disconnect, so + # an orphaned execution would otherwise keep holding self._lock (#213). while True: try: output = await asyncio.wait_for(queue.get(), timeout=KEEPALIVE_INTERVAL) except asyncio.TimeoutError: - # Yield a keepalive so Starlette writes to the socket. - # If the client has disconnected, the write fails and - # uvicorn delivers http.disconnect, which cancels this - # generator via CancelledError. + if request is not None and await request.is_disconnected(): + raise _ClientDisconnected() yield {"type": "keepalive"} continue + # Also check before forwarding output, in case the client left + # while the kernel was actively streaming. + if request is not None and await request.is_disconnected(): + raise _ClientDisconnected() + if output.type == OutputType.END_OF_EXECUTION: break @@ -320,6 +327,7 @@ async def execute( code: Union[str, StrictStr], env_vars: Dict[StrictStr, str], access_token: str, + request: Optional[Request] = None, ): if self._ws is None: raise Exception("WebSocket not connected") @@ -368,10 +376,12 @@ async def execute( logger.info( f"Sending code for the execution ({message_id}): {complete_code}" ) - request = self._get_execute_request( + # Don't rebind `request`: it holds the Starlette Request + # we poll for disconnects below (#213). + execute_request = self._get_execute_request( message_id, complete_code, False ) - await self._ws.send(request) + await self._ws.send(execute_request) break except (ConnectionClosedError, WebSocketException) as e: # Keep the last result, even if error @@ -392,22 +402,27 @@ async def execute( ) await execution.queue.put(UnexpectedEndOfExecution()) - # Stream the results. - # If the client disconnects (Starlette cancels the task), we - # interrupt the kernel so the next execution isn't blocked (#213). + # Stream the results. On client disconnect we interrupt the kernel + # so the lock is released and the next execution isn't blocked + # (#213). The disconnect surfaces either as _ClientDisconnected + # (latest Starlette, raised by _wait_for_result) or as + # CancelledError/GeneratorExit (older Starlette / generator teardown). try: - async for item in self._wait_for_result(message_id): + async for item in self._wait_for_result(message_id, request=request): yield item - except (asyncio.CancelledError, GeneratorExit): + except (asyncio.CancelledError, GeneratorExit, _ClientDisconnected) as e: logger.warning( f"Client disconnected during execution ({message_id}), interrupting kernel" ) - # Shield the interrupt from the ongoing cancellation so - # the HTTP request to the kernel actually completes. + # Shield so the interrupt completes even if we're being cancelled. try: await asyncio.shield(self.interrupt()) except asyncio.CancelledError: pass + # We detected the disconnect ourselves: unwind cleanly so the + # lock releases. A real cancellation/teardown must propagate. + if isinstance(e, _ClientDisconnected): + return raise finally: if message_id in self._executions: diff --git a/template/server/requirements.txt b/template/server/requirements.txt index b3e44b71..bc2b813f 100644 --- a/template/server/requirements.txt +++ b/template/server/requirements.txt @@ -1,4 +1,4 @@ -fastapi==0.111.0 +fastapi==0.136.3 httpx==0.28.1 websockets==12.0 uvicorn[standard]==0.30.1