diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 9cb7f05fe3..288daaabea 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -110,6 +110,7 @@ from .run import ( ReasoningItemIdPolicy, RunConfig, + RunInterruptSignal, Runner, ToolErrorFormatter, ToolErrorFormatterArgs, @@ -447,6 +448,7 @@ def enable_verbose_stdout_logging(): "RunResultStreaming", "ResponsesWebSocketSession", "RunConfig", + "RunInterruptSignal", "ReasoningItemIdPolicy", "ToolExecutionConfig", "ToolErrorFormatter", diff --git a/src/agents/result.py b/src/agents/result.py index 8ae407003a..2ab733e934 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -367,6 +367,11 @@ class RunResult(RunResultBase): interruptions: list[ToolApprovalItem] = field(default_factory=list) """Pending tool approval requests (interruptions) for this run.""" + interrupted: bool = False + """Whether the run was interrupted via ``RunInterruptSignal`` before reaching a natural + final output. When True, ``final_output`` may be ``None`` and the result contains the + partial state accumulated up to the point of interruption.""" + def __post_init__(self) -> None: self._last_agent_ref = weakref.ref(self._last_agent) @@ -498,6 +503,9 @@ class RunResultStreaming(RunResultBase): """The last processed model response. This is needed for resuming from interruptions.""" interruptions: list[ToolApprovalItem] = field(default_factory=list) """Pending tool approval requests (interruptions) for this run.""" + interrupted: bool = False + """Whether the run was interrupted via ``RunInterruptSignal`` before reaching a natural + final output. When True, ``final_output`` may be ``None`` and ``is_complete`` is True.""" _waiting_on_event_queue: bool = field(default=False, repr=False) _current_turn_persisted_item_count: int = 0 @@ -673,6 +681,7 @@ def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None """ # Store the cancel mode for the background task to check self._cancel_mode = mode + self.interrupted = True if mode == "immediate": # Existing behavior - immediate shutdown diff --git a/src/agents/run.py b/src/agents/run.py index 014271a5ea..c375d02e2b 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -37,6 +37,7 @@ ModelInputData, ReasoningItemIdPolicy, RunConfig, + RunInterruptSignal, RunOptions, ToolErrorFormatter, ToolErrorFormatterArgs, @@ -131,6 +132,7 @@ "AgentRunner", "Runner", "RunConfig", + "RunInterruptSignal", "RunOptions", "RunState", "RunContextWrapper", @@ -766,6 +768,37 @@ def _finalize_result(result: RunResult) -> RunResult: try: while True: + # Check for external interrupt request before starting a new turn. + if run_config.interrupt_signal and run_config.interrupt_signal.is_interrupted: + logger.debug( + "Run interrupted via RunInterruptSignal at turn %s", current_turn + ) + output_guardrail_results: list[OutputGuardrailResult] = [] + result = RunResult( + input=copy_input_items(original_input), + new_items=list(session_items), + raw_responses=list(model_responses), + final_output=None, + _last_agent=current_agent, + input_guardrail_results=list(input_guardrail_results), + output_guardrail_results=output_guardrail_results, + tool_input_guardrail_results=list(tool_input_guardrail_results), + tool_output_guardrail_results=list(tool_output_guardrail_results), + context_wrapper=context_wrapper, + interruptions=[], + interrupted=True, + max_turns=max_turns, + ) + result._current_turn = current_turn + result._model_input_items = list(generated_items) + result._replay_from_model_input_items = list(generated_items) != list( + session_items + ) + if run_state is not None: + result._trace_state = run_state._trace_state + result._original_input = copy_input_items(original_input) + return _finalize_result(result) + resuming_turn = is_resumed_state all_input_guardrails = ( starting_agent.input_guardrails + (run_config.input_guardrails or []) diff --git a/src/agents/run_config.py b/src/agents/run_config.py index fcc9b01315..7ee963cf64 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import os from collections.abc import Callable from dataclasses import dataclass, field @@ -200,6 +201,46 @@ class SandboxRunConfig: """ +@dataclass +class RunInterruptSignal: + """Signal that can be used to request graceful interruption of an active agent run. + + Pass an instance of this class as ``RunConfig.interrupt_signal``, then call + ``interrupt()`` from any thread or task to request that the run complete its + current turn and return partial results instead of continuing to the next turn. + + Example usage:: + + interrupt = RunInterruptSignal() + + async def main(): + task = asyncio.create_task(Runner.run(agent, "input", run_config=RunConfig(interrupt_signal=interrupt))) + await asyncio.sleep(2) + interrupt.interrupt() # Signal the run to wrap up gracefully + result = await task + """ + + _event: asyncio.Event = field(default_factory=asyncio.Event) + """Internal asyncio.Event used to signal interruption.""" + + def interrupt(self) -> None: + """Request the agent run to gracefully interrupt after the current turn. + + Thread-safe: can be called from any thread or asyncio task. + """ + # asyncio.Event.set() is thread-safe in Python 3.8+ + self._event.set() + + @property + def is_interrupted(self) -> bool: + """Return True if an interruption has been requested.""" + return self._event.is_set() + + def clear(self) -> None: + """Reset the interrupt signal so the same instance can be reused.""" + self._event.clear() + + @dataclass class RunConfig: """Configures settings for the entire agent run.""" @@ -209,6 +250,11 @@ class RunConfig: agent. The model_provider passed in below must be able to resolve this model name. """ + interrupt_signal: RunInterruptSignal | None = None + """Optional signal that can be set from outside the run to request graceful interruption. + When set, the agent loop will complete the current turn and return partial results. + """ + model_provider: ModelProvider = field(default_factory=MultiProvider) """The model provider to use when looking up string model names. Defaults to OpenAI.""" @@ -368,6 +414,7 @@ class RunOptions(TypedDict, Generic[TContext]): "ModelInputData", "ReasoningItemIdPolicy", "RunConfig", + "RunInterruptSignal", "RunOptions", "SandboxArchiveLimits", "SandboxConcurrencyLimits", diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..953d81b91c 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -669,6 +669,15 @@ async def _save_stream_items_without_count( try: while True: + # Check for external interrupt request before starting a new streaming turn. + if run_config.interrupt_signal and run_config.interrupt_signal.is_interrupted: + logger.debug( + "Streaming run interrupted via RunInterruptSignal at turn %s", current_turn + ) + streamed_result.interrupted = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + return + all_input_guardrails = ( starting_agent.input_guardrails + (run_config.input_guardrails or []) if current_turn == 0 and not is_resumed_state