Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,11 @@ async def arun(self) -> None:
self._state.id, e, persistence_dir=self._state.persistence_dir
) from e
finally:
self._cancel_token = None
# A cancelled token must stay observable: interrupted tool calls run
# in worker threads that can outlive arun() and still poll it. A
# fresh token is created on the next run().
if self._cancel_token is not None and not self._cancel_token.is_cancelled:
self._cancel_token = None
Comment on lines +1073 to +1077
Copy link
Copy Markdown
Member

@malhotra5 malhotra5 May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should pass a cancel token to each tool call. If arun can be called without a tool call finishing, interrupts should stop everything that's currently in flight

wdyt?

self._arun_task = None

def set_confirmation_policy(self, policy: ConfirmationPolicyBase) -> None:
Expand Down
33 changes: 31 additions & 2 deletions tests/sdk/conversation/test_interrupt.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,37 @@ async def test_interrupt_sets_cancel_token(tmp_path):
conv.interrupt()
await asyncio.wait_for(task, timeout=2.0)

# After arun finishes, token is cleared
assert conv._cancel_token is None
# After an interrupt the cancelled token is retained (not cleared) so tool
# threads that outlive arun() can still observe it.
assert conv._cancel_token is not None
assert conv._cancel_token.is_cancelled


@pytest.mark.asyncio
async def test_cancel_token_stays_observable_after_interrupt(tmp_path):
"""A tool polling conversation.cancel_token from a worker thread that
outlives arun() must still see the cancellation, not the None the finally
used to clear. A fresh token is swapped in on the next run."""
conv = _make_conversation(SlowLLM(sleep_seconds=60.0), tmp_path)

task = asyncio.create_task(conv.arun())
await asyncio.sleep(0.05)
conv.interrupt()
await asyncio.wait_for(task, timeout=2.0)

# arun() has run its finally; a late poll via the public property (what
# tools use) must still observe the cancellation.
assert conv.cancel_token is not None
assert conv.cancel_token.is_cancelled

# The next run replaces it with a fresh, uncancelled token.
conv.send_message("again")
resumed = asyncio.create_task(conv.arun())
await asyncio.sleep(0.05)
assert conv.cancel_token is not None
assert not conv.cancel_token.is_cancelled
conv.interrupt()
await asyncio.wait_for(resumed, timeout=2.0)


# ── ParallelToolExecutor cancellation tests ───────────────────────────
Expand Down
Loading