From 9485f03a53b4c3f8cb6681e4b52c83e407aead1d Mon Sep 17 00:00:00 2001 From: Qiyao Qin Date: Thu, 4 Jun 2026 17:39:23 -0700 Subject: [PATCH 1/2] Expose sandbox error retryability --- .../extensions/sandbox/blaxel/sandbox.py | 94 ++++++++++- .../extensions/sandbox/cloudflare/sandbox.py | 43 +++-- .../extensions/sandbox/daytona/sandbox.py | 152 +++++++++++++---- src/agents/extensions/sandbox/e2b/sandbox.py | 131 ++++++++++----- .../extensions/sandbox/modal/sandbox.py | 91 +++++++++- .../extensions/sandbox/runloop/sandbox.py | 82 +++++++++ .../extensions/sandbox/vercel/sandbox.py | 104 +++++++++--- src/agents/sandbox/errors.py | 46 ++++++ src/agents/sandbox/sandboxes/docker.py | 14 +- src/agents/sandbox/session/events.py | 1 + src/agents/sandbox/session/sandbox_session.py | 4 + tests/extensions/sandbox/test_blaxel.py | 107 ++++++++++++ tests/extensions/sandbox/test_cloudflare.py | 94 +++++++++++ tests/extensions/sandbox/test_daytona.py | 140 +++++++++++++++- tests/extensions/sandbox/test_e2b.py | 119 +++++++++++--- tests/extensions/sandbox/test_modal.py | 123 ++++++++++++++ tests/extensions/sandbox/test_runloop.py | 140 ++++++++++++++++ tests/extensions/sandbox/test_vercel.py | 155 ++++++++++++++++++ tests/sandbox/test_errors.py | 62 +++++++ tests/sandbox/test_session_sinks.py | 42 ++++- 20 files changed, 1613 insertions(+), 131 deletions(-) create mode 100644 tests/sandbox/test_errors.py diff --git a/src/agents/extensions/sandbox/blaxel/sandbox.py b/src/agents/extensions/sandbox/blaxel/sandbox.py index 89197af895..cc28a182a2 100644 --- a/src/agents/extensions/sandbox/blaxel/sandbox.py +++ b/src/agents/extensions/sandbox/blaxel/sandbox.py @@ -63,6 +63,7 @@ TRANSIENT_HTTP_STATUS_CODES, exception_chain_contains_type, exception_chain_has_status_code, + iter_exception_chain, retry_async, ) from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tar_bytes @@ -72,6 +73,85 @@ logger = logging.getLogger(__name__) +# Blaxel documents structured API error codes and retryability at: +# https://docs.blaxel.ai/troubleshooting/error-codes +_BLAXEL_ERROR_CODE_RETRYABLE: dict[str, bool] = { + "ROUTE_NOT_FOUND": False, # 404 + "WORKLOAD_NOT_FOUND": False, # 404 + "WORKSPACE_NOT_FOUND": False, # 404 + "WORKLOAD_UNAVAILABLE": True, # 404 + "AUTHENTICATION_REQUIRED": False, # 401 + "AUTHENTICATION_FAILED": False, # 401 + "FORBIDDEN": False, # 403 + "BAD_REQUEST": False, # 400 + "USAGE_LIMIT_EXCEEDED": False, # 402 + "POLICY_VIOLATION": False, # varies +} + + +def _coerce_mapping(value: object) -> dict[str, object] | None: + if isinstance(value, dict): + return {str(key): item for key, item in value.items()} + if isinstance(value, str): + try: + decoded = json.loads(value) + except json.JSONDecodeError: + return None + if isinstance(decoded, dict): + return {str(key): item for key, item in decoded.items()} + return None + + +def _blaxel_error_payload(error: BaseException) -> dict[str, object] | None: + for candidate in iter_exception_chain(error): + for attr in ("body", "payload"): + payload = _coerce_mapping(getattr(candidate, attr, None)) + if payload is not None: + return payload + + response = getattr(candidate, "response", None) + response_json = getattr(response, "json", None) + if callable(response_json): + try: + payload = _coerce_mapping(response_json()) + except Exception: + payload = None + if payload is not None: + return payload + + response_text = getattr(response, "text", None) + payload = _coerce_mapping(response_text) + if payload is not None: + return payload + + return None + + +def _blaxel_structured_error(error: BaseException) -> dict[str, object] | None: + payload = _blaxel_error_payload(error) + if payload is None: + return None + nested = payload.get("error") + if isinstance(nested, dict): + return {str(key): value for key, value in nested.items()} + return payload + + +def _blaxel_provider_retryability(error: BaseException) -> tuple[bool | None, str | None]: + structured_error = _blaxel_structured_error(error) + if structured_error is not None: + retryable = structured_error.get("retryable") + if isinstance(retryable, bool): + code = structured_error.get("code") + return retryable, str(code) if isinstance(code, str) and code else None + + code = structured_error.get("code") + if isinstance(code, str): + return _BLAXEL_ERROR_CODE_RETRYABLE.get(code), code + + return None, None + + def _blaxel_provider_error_detail(error: BaseException) -> str | None: message = str(error) status = getattr(error, "status_code", None) or getattr(error, "status", None) @@ -91,15 +171,26 @@ def _blaxel_exec_transport_error( ) -> ExecTransportError: detail = _blaxel_provider_error_detail(cause) context: dict[str, object] = {"backend": "blaxel"} + retryable, provider_error_code = _blaxel_provider_retryability(cause) + if provider_error_code is not None: + context["provider_error_code"] = provider_error_code if detail: context["provider_error"] = detail status = getattr(cause, "status_code", None) or getattr(cause, "status", None) if isinstance(status, int): context["http_status"] = status + if retryable is None and status in TRANSIENT_HTTP_STATUS_CODES: + retryable = True message = "Blaxel exec failed" if detail: message = f"{message}: {detail}" - return ExecTransportError(command=command, context=context, cause=cause, message=message) + return ExecTransportError( + command=command, + context=context, + cause=cause, + message=message, + retryable=retryable, + ) def _import_blaxel_sdk() -> Any: @@ -583,6 +674,7 @@ async def persist_workspace(self) -> io.IOBase: "reason": "tar_failed", "output": result.stderr.decode("utf-8", errors="replace"), }, + retryable=False, ) raw_data: Any = await self._sandbox.fs.read_binary(tar_path) if isinstance(raw_data, str): diff --git a/src/agents/extensions/sandbox/cloudflare/sandbox.py b/src/agents/extensions/sandbox/cloudflare/sandbox.py index d34d698294..a3f94ec591 100644 --- a/src/agents/extensions/sandbox/cloudflare/sandbox.py +++ b/src/agents/extensions/sandbox/cloudflare/sandbox.py @@ -62,17 +62,20 @@ from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot from ....sandbox.types import ExecResult, ExposedPortEndpoint, User -from ....sandbox.util.retry import ( - TRANSIENT_HTTP_STATUS_CODES, - exception_chain_has_status_code, - retry_async, -) +from ....sandbox.util.retry import retry_async from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tar_bytes from ....sandbox.workspace_paths import coerce_posix_path, posix_path_as_path, sandbox_path_str _DEFAULT_EXEC_TIMEOUT_S = 30.0 _DEFAULT_REQUEST_TIMEOUT_S = 120.0 _MAX_ERROR_BODY_CHARS = 2000 +# Cloudflare documents sandbox HTTP status retry semantics at: +# https://cloudflare-sandbox-sdk.mintlify.app/advanced/error-handling#http-status-code-semantics +_CLOUDFLARE_HTTP_STATUS_RETRYABLE: dict[int, bool] = { + 400: False, + 500: False, + 503: True, +} logger = logging.getLogger(__name__) @@ -141,6 +144,12 @@ def _cloudflare_error_context( return context +def _cloudflare_retryability_for_status(status: int | None) -> bool | None: + if status is None: + return None + return _CLOUDFLARE_HTTP_STATUS_RETRYABLE.get(status) + + def _cloudflare_exec_error_detail(error: ExecTransportError) -> str | None: detail = error.context.get("provider_error") if isinstance(detail, str) and detail: @@ -164,15 +173,17 @@ def _cloudflare_transport_error( ) -> ExecTransportError: detail = str(cause) provider_error = f"{type(cause).__name__}: {detail}" if detail else type(cause).__name__ + context: dict[str, object] = { + "backend": "cloudflare", + "operation": operation, + "provider_error": provider_error, + } return ExecTransportError( command=command, - context={ - "backend": "cloudflare", - "operation": operation, - "provider_error": provider_error, - }, + context=context, cause=cause, message=f"Cloudflare {operation} transport failed: {provider_error}", + retryable=None, ) @@ -181,7 +192,7 @@ def _is_transient_workspace_error(exc: BaseException) -> bool: if not isinstance(exc, WorkspaceArchiveReadError | WorkspaceArchiveWriteError): return False status = exc.context.get("http_status") - return isinstance(status, int) and status in TRANSIENT_HTTP_STATUS_CODES + return isinstance(status, int) and _cloudflare_retryability_for_status(status) is True @dataclass @@ -731,6 +742,7 @@ async def _exec_internal( context=_cloudflare_error_context(status=resp.status, detail=detail), cause=Exception(message), message=message, + retryable=_cloudflare_retryability_for_status(resp.status), ) stdout_parts: list[bytes] = [] @@ -802,6 +814,7 @@ async def _exec_internal( ), cause=Exception(message), message=message, + retryable=_cloudflare_retryability_for_status(resp.status), ) except asyncio.TimeoutError as e: @@ -1152,6 +1165,7 @@ async def read(self, path: Path | str, *, user: str | User | None = None) -> io. "http_status": resp.status, "message": body.get("error", "path escapes /workspace"), }, + retryable=False, ) if resp.status != 200: body = {} @@ -1166,6 +1180,7 @@ async def read(self, path: Path | str, *, user: str | User | None = None) -> io. "http_status": resp.status, "message": body.get("error", f"HTTP {resp.status}"), }, + retryable=_cloudflare_retryability_for_status(resp.status), ) return io.BytesIO(self._decode_streamed_payload(await resp.read())) except (WorkspaceReadNotFoundError, WorkspaceArchiveReadError): @@ -1219,6 +1234,7 @@ async def write( "http_status": resp.status, "message": body.get("error", "path escapes /workspace"), }, + retryable=False, ) if resp.status != 200: body = {} @@ -1233,6 +1249,7 @@ async def write( "http_status": resp.status, "message": body.get("error", f"HTTP {resp.status}"), }, + retryable=_cloudflare_retryability_for_status(resp.status), ) except WorkspaceArchiveWriteError: raise @@ -1255,7 +1272,6 @@ async def running(self) -> bool: @retry_async( retry_if=lambda exc, self: isinstance(exc, aiohttp.ClientError) - or exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) or _is_transient_workspace_error(exc) ) async def _persist_workspace_via_http(self) -> io.IOBase: @@ -1286,6 +1302,7 @@ async def _persist_workspace_via_http(self) -> io.IOBase: "http_status": resp.status, "message": body.get("error", f"HTTP {resp.status}"), }, + retryable=_cloudflare_retryability_for_status(resp.status), ) return io.BytesIO(self._decode_streamed_payload(await resp.read())) except WorkspaceArchiveReadError: @@ -1297,7 +1314,6 @@ async def _persist_workspace_via_http(self) -> io.IOBase: @retry_async( retry_if=lambda exc, self, data: isinstance(exc, aiohttp.ClientError) - or exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) or _is_transient_workspace_error(exc) ) async def _hydrate_workspace_via_http(self, data: io.IOBase) -> None: @@ -1346,6 +1362,7 @@ async def _hydrate_workspace_via_http(self, data: io.IOBase) -> None: "http_status": resp.status, "message": body.get("error", f"HTTP {resp.status}"), }, + retryable=_cloudflare_retryability_for_status(resp.status), ) except WorkspaceArchiveWriteError: raise diff --git a/src/agents/extensions/sandbox/daytona/sandbox.py b/src/agents/extensions/sandbox/daytona/sandbox.py index 6df0e451c9..36bd195031 100644 --- a/src/agents/extensions/sandbox/daytona/sandbox.py +++ b/src/agents/extensions/sandbox/daytona/sandbox.py @@ -62,6 +62,7 @@ TRANSIENT_HTTP_STATUS_CODES, exception_chain_contains_type, exception_chain_has_status_code, + iter_exception_chain, retry_async, ) from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tar_bytes @@ -76,6 +77,22 @@ logger = logging.getLogger(__name__) +# Daytona documents SDK error subclasses plus `status_code` and `error_code` fields at: +# https://www.daytona.io/docs/en/python-sdk/common/errors/ +_DAYTONA_HTTP_STATUS_RETRYABLE: dict[int, bool] = { + 400: False, + 401: False, + 403: False, + 404: False, + 409: False, + 429: True, + 500: True, + 502: True, + 503: True, + 504: True, +} + + def _daytona_provider_error_detail(error: BaseException) -> str | None: message = str(error) status = getattr(error, "status_code", None) or getattr(error, "status", None) @@ -88,6 +105,35 @@ def _daytona_provider_error_detail(error: BaseException) -> str | None: return type(error).__name__ +def _daytona_provider_retryability(error: BaseException) -> tuple[bool | None, str | None]: + non_retryable_types = _daytona_non_retryable_error_types() + retryable_types = _daytona_retryable_error_types() + + for candidate in iter_exception_chain(error): + provider_error_code = getattr(candidate, "error_code", None) + reason = str(provider_error_code) if isinstance(provider_error_code, str) else None + + if non_retryable_types and isinstance(candidate, non_retryable_types): + return False, reason + + if retryable_types and isinstance(candidate, retryable_types): + return True, reason + + status = getattr(candidate, "status_code", None) or getattr(candidate, "status", None) + if isinstance(status, int): + retryable = _DAYTONA_HTTP_STATUS_RETRYABLE.get(status) + if retryable is not None: + return retryable, reason or f"http_{status}" + + message = str(candidate).lower() + if "is the sandbox started" in message or "no ip address found" in message: + return False, "sandbox_not_running" + + if exception_chain_contains_type(error, _retryable_persist_workspace_error_types()): + return True, "provider_timeout" + return None, None + + def _daytona_exec_transport_error( *, command: tuple[str | Path, ...], @@ -95,15 +141,27 @@ def _daytona_exec_transport_error( ) -> ExecTransportError: detail = _daytona_provider_error_detail(cause) context: dict[str, object] = {"backend": "daytona"} + retryable, reason = _daytona_provider_retryability(cause) + if reason is not None: + context["reason"] = reason if detail: context["provider_error"] = detail + provider_error_code = getattr(cause, "error_code", None) + if isinstance(provider_error_code, str) and provider_error_code: + context["provider_error_code"] = provider_error_code status = getattr(cause, "status_code", None) or getattr(cause, "status", None) if isinstance(status, int): context["http_status"] = status message = "Daytona exec failed" if detail: message = f"{message}: {detail}" - return ExecTransportError(command=command, context=context, cause=cause, message=message) + return ExecTransportError( + command=command, + context=context, + cause=cause, + message=message, + retryable=retryable, + ) def _import_daytona_sdk() -> tuple[Any, Any, Any, Any]: @@ -178,32 +236,49 @@ def _import_session_execute_request() -> Any: ) from e -def _import_daytona_exceptions() -> dict[str, type[BaseException]]: - """Best-effort import Daytona exception classes for fine-grained error mapping.""" +def _daytona_exception_types(*names: str) -> tuple[type[BaseException], ...]: + """Best-effort import of Daytona exception classes by name.""" try: - from daytona import ( - DaytonaError, - DaytonaNotFoundError, - DaytonaRateLimitError, - DaytonaTimeoutError, - ) + daytona_module = __import__("daytona") except Exception: - return {} - return { - "base": DaytonaError, - "timeout": DaytonaTimeoutError, - "not_found": DaytonaNotFoundError, - "rate_limit": DaytonaRateLimitError, - } + return () + + exceptions: list[type[BaseException]] = [] + for name in names: + value = getattr(daytona_module, name, None) + if isinstance(value, type) and issubclass(value, BaseException): + exceptions.append(value) + return tuple(exceptions) + + +def _daytona_retryable_error_types() -> tuple[type[BaseException], ...]: + return _daytona_exception_types( + "DaytonaRateLimitError", + "DaytonaTimeoutError", + "DaytonaConnectionError", + ) + + +def _daytona_timeout_error_types() -> tuple[type[BaseException], ...]: + return _daytona_exception_types("DaytonaTimeoutError") + + +def _daytona_non_retryable_error_types() -> tuple[type[BaseException], ...]: + return _daytona_exception_types( + "DaytonaNotFoundError", + "DaytonaAuthenticationError", + "DaytonaAuthorizationError", + "DaytonaValidationError", + "DaytonaConflictError", + ) + + +def _daytona_not_found_error_types() -> tuple[type[BaseException], ...]: + return _daytona_exception_types("DaytonaNotFoundError") def _retryable_persist_workspace_error_types() -> tuple[type[BaseException], ...]: - excs = _import_daytona_exceptions() - retryable: list[type[BaseException]] = [asyncio.TimeoutError] - timeout_exc = excs.get("timeout") - if timeout_exc is not None: - retryable.append(timeout_exc) - return tuple(retryable) + return (asyncio.TimeoutError, *_daytona_timeout_error_types()) class DaytonaSandboxResources(BaseModel): @@ -495,8 +570,7 @@ async def _exec_internal( caller_timeout = self._coerce_exec_timeout(timeout) deadline = time.monotonic() + caller_timeout SessionExecuteRequest = _import_session_execute_request() - daytona_exc = _import_daytona_exceptions() - timeout_exc = daytona_exc.get("timeout") + timeout_error_types = _daytona_timeout_error_types() def _remaining_timeout() -> float: return max(0.0, deadline - time.monotonic()) @@ -535,7 +609,7 @@ def _remaining_timeout() -> float: except asyncio.TimeoutError as e: raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e except Exception as e: - if timeout_exc is not None and isinstance(e, timeout_exc): + if timeout_error_types and isinstance(e, timeout_error_types): raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e raise _daytona_exec_transport_error(command=command, cause=e) from e finally: @@ -566,8 +640,7 @@ async def pty_exec_start( envs = await self._resolved_envs() cwd = sandbox_path_str(self.state.manifest.root) exec_timeout = self._coerce_exec_timeout(timeout) - daytona_exc = _import_daytona_exceptions() - timeout_exc = daytona_exc.get("timeout") + timeout_error_types = _daytona_timeout_error_types() daytona_session_id = f"sandbox-{uuid.uuid4().hex[:12]}" entry = _DaytonaPtySessionEntry( @@ -657,7 +730,7 @@ async def _on_data(chunk: bytes | str) -> None: await asyncio.shield(cleanup_task) except BaseException: await asyncio.shield(cleanup_task) - if timeout_exc is not None and isinstance(e, timeout_exc): + if timeout_error_types and isinstance(e, timeout_error_types): raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e raise _daytona_exec_transport_error(command=command, cause=e) from e except BaseException: @@ -871,8 +944,7 @@ async def read(self, path: Path | str, *, user: str | User | None = None) -> io. else: workspace_path = await self._validate_path_access(path) - daytona_exc = _import_daytona_exceptions() - not_found_exc = daytona_exc.get("not_found") + not_found_error_types = _daytona_not_found_error_types() try: data: bytes = await self._sandbox.fs.download_file( @@ -881,7 +953,7 @@ async def read(self, path: Path | str, *, user: str | User | None = None) -> io. ) return io.BytesIO(data) except Exception as e: - if not_found_exc is not None and isinstance(e, not_found_exc): + if not_found_error_types and isinstance(e, not_found_error_types): raise WorkspaceReadNotFoundError(path=error_path, cause=e) from e raise WorkspaceArchiveReadError(path=error_path, cause=e) from e @@ -946,6 +1018,7 @@ async def _run_persist_workspace_command(self, tar_cmd: str, tar_path: str) -> b raise WorkspaceArchiveReadError( path=self._workspace_root_path(), context={"reason": "tar_failed", "output": result.result or ""}, + retryable=False, ) return cast( bytes, @@ -957,7 +1030,22 @@ async def _run_persist_workspace_command(self, tar_cmd: str, tar_path: str) -> b except WorkspaceArchiveReadError: raise except Exception as e: - raise WorkspaceArchiveReadError(path=self._workspace_root_path(), cause=e) from e + detail = _daytona_provider_error_detail(e) + retryable, reason = _daytona_provider_retryability(e) + context: dict[str, object] = {"backend": "daytona"} + if reason is not None: + context["reason"] = reason + if detail: + context["provider_error"] = detail + provider_error_code = getattr(e, "error_code", None) + if isinstance(provider_error_code, str) and provider_error_code: + context["provider_error_code"] = provider_error_code + raise WorkspaceArchiveReadError( + path=self._workspace_root_path(), + context=context, + cause=e, + retryable=retryable, + ) from e async def persist_workspace(self) -> io.IOBase: def _error_context_summary(error: WorkspaceArchiveReadError) -> dict[str, str]: diff --git a/src/agents/extensions/sandbox/e2b/sandbox.py b/src/agents/extensions/sandbox/e2b/sandbox.py index 6586ee1550..425436f3e0 100644 --- a/src/agents/extensions/sandbox/e2b/sandbox.py +++ b/src/agents/extensions/sandbox/e2b/sandbox.py @@ -87,12 +87,34 @@ logger = logging.getLogger(__name__) +# E2B documents SDK exception classes at: +# https://e2b.dev/docs/sdk-reference/python-sdk/v1.0.0/exceptions +def _e2b_provider_retryability(error: BaseException) -> tuple[bool | None, str | None]: + non_retryable_types = _e2b_non_retryable_error_types() + retryable_types = _e2b_retryable_error_types() + + for candidate in iter_exception_chain(error): + if non_retryable_types and isinstance(candidate, non_retryable_types): + return False, type(candidate).__name__ + + if retryable_types and isinstance(candidate, retryable_types): + return True, type(candidate).__name__ + + status = getattr(candidate, "status_code", None) or getattr(candidate, "status", None) + if isinstance(status, int) and status in TRANSIENT_HTTP_STATUS_CODES: + return True, "transient_http_status" + + if exception_chain_contains_type(error, _retryable_persist_workspace_error_types()): + return True, "provider_timeout" + return None, None + + def _raise_e2b_exec_error( exc: BaseException, *, command: Sequence[str | Path], timeout: float | None, - timeout_exc: type[BaseException] | None, + timeout_error_types: tuple[type[BaseException], ...], ) -> NoReturn: """Classify an E2B exception and raise the appropriate ExecFailureError.""" # Build context from the exception chain. @@ -113,13 +135,21 @@ def _raise_e2b_exec_error( chain = list(iter_exception_chain(exc)) - # Sandbox gone — always a transport error. - if any("sandbox" in str(c).lower() and "not found" in str(c).lower() for c in chain): - ctx.setdefault("reason", "sandbox_not_found") - raise ExecTransportError(command=command, context=ctx, cause=exc) from exc + retryable, reason = _e2b_provider_retryability(exc) + if reason is not None: + ctx.setdefault("reason", reason) + + # Terminal provider errors are transport failures, not command timeouts. + if retryable is False: + raise ExecTransportError( + command=command, + context=ctx, + cause=exc, + retryable=False, + ) from exc # E2B timeout or httpcore read timeout. - is_timeout = timeout_exc is not None and exception_chain_contains_type(exc, (timeout_exc,)) + is_timeout = exception_chain_contains_type(exc, timeout_error_types) if not is_timeout and any( type(c).__name__ == "ReadTimeout" and type(c).__module__.startswith("httpcore") for c in chain @@ -135,7 +165,7 @@ def _raise_e2b_exec_error( cause=exc, ) from exc - raise ExecTransportError(command=command, context=ctx, cause=exc) from exc + raise ExecTransportError(command=command, context=ctx, cause=exc, retryable=retryable) from exc def _encode_e2b_snapshot_ref(*, snapshot_id: str) -> bytes: @@ -491,23 +521,48 @@ async def _sandbox_connect( return await sandbox_class._cls_connect(sandbox_id=sandbox_id, timeout=timeout) -def _import_e2b_exceptions() -> Mapping[str, type[BaseException]]: - """Best-effort import of E2B exception classes for classification.""" - +def _e2b_exception_types(*names: str) -> tuple[type[BaseException], ...]: + """Best-effort import of E2B exception classes by name.""" try: - from e2b.exceptions import ( - NotFoundException, - SandboxException, - TimeoutException, - ) + from e2b import exceptions as e2b_exceptions except Exception: # pragma: no cover - handled by fallbacks - return {} + return () - return { - "not_found": cast(type[BaseException], NotFoundException), - "sandbox": cast(type[BaseException], SandboxException), - "timeout": cast(type[BaseException], TimeoutException), - } + exceptions: list[type[BaseException]] = [] + for name in names: + value = getattr(e2b_exceptions, name, None) + if isinstance(value, type) and issubclass(value, BaseException): + exceptions.append(value) + return tuple(exceptions) + + +def _e2b_retryable_error_types() -> tuple[type[BaseException], ...]: + return _e2b_exception_types( + "RateLimitException", + "TimeoutException", + ) + + +def _e2b_timeout_error_types() -> tuple[type[BaseException], ...]: + return _e2b_exception_types("TimeoutException") + + +def _e2b_non_retryable_error_types() -> tuple[type[BaseException], ...]: + return _e2b_exception_types( + "AuthenticationException", + "FileNotFoundException", + "GitAuthException", + "GitUpstreamException", + "InvalidArgumentException", + "NotEnoughSpaceException", + "NotFoundException", + "SandboxNotFoundException", + "TemplateException", + ) + + +def _e2b_not_found_error_types() -> tuple[type[BaseException], ...]: + return _e2b_exception_types("NotFoundException") def _import_command_exit_exception() -> type[BaseException] | None: @@ -521,12 +576,7 @@ def _import_command_exit_exception() -> type[BaseException] | None: def _retryable_persist_workspace_error_types() -> tuple[type[BaseException], ...]: - excs = _import_e2b_exceptions() - retryable: list[type[BaseException]] = [] - timeout_exc = excs.get("timeout") - if timeout_exc is not None: - retryable.append(timeout_exc) - return tuple(retryable) + return _e2b_timeout_error_types() class E2BSandboxTimeouts(BaseModel): @@ -847,8 +897,7 @@ async def _exec_internal( cmd_str = shlex.join(command_list) exec_timeout = self._coerce_exec_timeout(timeout) - e2b_exc = _import_e2b_exceptions() - timeout_exc = e2b_exc.get("timeout") + timeout_error_types = _e2b_timeout_error_types() command_exit_exc = _import_command_exit_exception() try: @@ -880,7 +929,7 @@ async def _exec_internal( e, command=command, timeout=timeout, - timeout_exc=timeout_exc, + timeout_error_types=timeout_error_types, ) def supports_pty(self) -> bool: @@ -901,8 +950,7 @@ async def pty_exec_start( envs = await self._resolved_envs() cwd = self.state.manifest.root if self._workspace_root_ready else None exec_timeout = self._coerce_exec_timeout(timeout) - e2b_exc = _import_e2b_exceptions() - timeout_exc = e2b_exc.get("timeout") + timeout_error_types = _e2b_timeout_error_types() entry = _E2BPtyProcessEntry(handle=None, tty=tty) @@ -971,7 +1019,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None: e, command=command, timeout=timeout, - timeout_exc=timeout_exc, + timeout_error_types=timeout_error_types, ) if pruned_entry is not None: @@ -1051,8 +1099,7 @@ async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase workspace_path = await self._validate_path_access(path) - e2b_exc = _import_e2b_exceptions() - not_found_exc = e2b_exc.get("not_found") + not_found_error_types = _e2b_not_found_error_types() try: content = await _sandbox_read_file( @@ -1066,7 +1113,7 @@ async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase data = str(content).encode("utf-8", errors="replace") return io.BytesIO(data) except Exception as e: # pragma: no cover - exercised via unit tests with fakes - if not_found_exc is not None and isinstance(e, not_found_exc): + if not_found_error_types and isinstance(e, not_found_error_types): raise WorkspaceReadNotFoundError(path=path, cause=e) from e raise WorkspaceArchiveReadError(path=path, cause=e) from e @@ -1255,12 +1302,22 @@ async def _run_persist_workspace_command(self, tar_cmd: str) -> str: "exit_code": exit_code, "stderr": str(getattr(result, "stderr", "") or ""), }, + retryable=False, ) return str(getattr(result, "stdout", "") or "") except WorkspaceArchiveReadError: raise except Exception as e: # pragma: no cover - exercised via unit tests with fakes - raise WorkspaceArchiveReadError(path=error_root, cause=e) from e + retryable, reason = _e2b_provider_retryability(e) + context: dict[str, object] = {"backend": "e2b"} + if reason is not None: + context["reason"] = reason + raise WorkspaceArchiveReadError( + path=error_root, + context=context, + cause=e, + retryable=retryable, + ) from e async def persist_workspace(self) -> io.IOBase: if self.state.workspace_persistence == _WORKSPACE_PERSISTENCE_SNAPSHOT: diff --git a/src/agents/extensions/sandbox/modal/sandbox.py b/src/agents/extensions/sandbox/modal/sandbox.py index bad940e2c9..3468ae4b00 100644 --- a/src/agents/extensions/sandbox/modal/sandbox.py +++ b/src/agents/extensions/sandbox/modal/sandbox.py @@ -70,6 +70,7 @@ TRANSIENT_HTTP_STATUS_CODES, exception_chain_contains_type, exception_chain_has_status_code, + iter_exception_chain, retry_async, ) from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tar_bytes @@ -117,6 +118,75 @@ def _modal_provider_error_detail(error: BaseException) -> str | None: return type(error).__name__ +def _modal_exception_types(*names: str) -> tuple[type[BaseException], ...]: + exception_module = getattr(modal, "exception", None) + if exception_module is None: + try: + from modal import exception as exception_module + except Exception: + return () + + exceptions: list[type[BaseException]] = [] + for name in names: + value = getattr(exception_module, name, None) + if isinstance(value, type) and issubclass(value, BaseException): + exceptions.append(value) + return tuple(exceptions) + + +def _modal_retryable_error_types() -> tuple[type[BaseException], ...]: + return _modal_exception_types( + "ConnectionError", + "InternalError", + "InternalFailure", + "ServiceError", + ) + + +def _modal_non_retryable_error_types() -> tuple[type[BaseException], ...]: + return _modal_exception_types( + "AlreadyExistsError", + "AuthError", + "ConflictError", + "InvalidError", + "LogsFetchError", + "NotFoundError", + "PermissionDeniedError", + "RequestSizeError", + "SandboxFilesystemDirectoryNotEmptyError", + "SandboxFilesystemFileTooLargeError", + "SandboxFilesystemIsADirectoryError", + "SandboxFilesystemNotADirectoryError", + "SandboxFilesystemNotFoundError", + "SandboxFilesystemPathAlreadyExistsError", + "SandboxFilesystemPermissionError", + "UnimplementedError", + "VersionError", + ) + + +def _modal_exec_timeout_error_types() -> tuple[type[BaseException], ...]: + return _modal_exception_types("ExecTimeoutError") + + +def _modal_provider_retryability(error: BaseException) -> tuple[bool | None, str | None]: + non_retryable_types = _modal_non_retryable_error_types() + retryable_types = _modal_retryable_error_types() + + for candidate in iter_exception_chain(error): + if non_retryable_types and isinstance(candidate, non_retryable_types): + return False, type(candidate).__name__ + + if retryable_types and isinstance(candidate, retryable_types): + return True, type(candidate).__name__ + + status = getattr(candidate, "status_code", None) or getattr(candidate, "status", None) + if isinstance(status, int) and status in TRANSIENT_HTTP_STATUS_CODES: + return True, "transient_http_status" + + return None, None + + def _modal_exec_transport_error( *, command: tuple[str | Path, ...], @@ -124,15 +194,26 @@ def _modal_exec_transport_error( ) -> ExecTransportError: detail = _modal_provider_error_detail(cause) context: dict[str, object] = {"backend": "modal"} + retryable, reason = _modal_provider_retryability(cause) + if reason is not None: + context["reason"] = reason if detail: context["provider_error"] = detail status = getattr(cause, "status_code", None) or getattr(cause, "status", None) if isinstance(status, int): context["http_status"] = status + if retryable is None and status in TRANSIENT_HTTP_STATUS_CODES: + retryable = True message = "Modal exec failed" if detail: message = f"{message}: {detail}" - return ExecTransportError(command=command, context=context, cause=cause, message=message) + return ExecTransportError( + command=command, + context=context, + cause=cause, + message=message, + retryable=retryable, + ) @asynccontextmanager @@ -707,6 +788,8 @@ async def _run_async() -> ExecResult: except ExecTimeoutError: raise except Exception as e: + if exception_chain_contains_type(e, _modal_exec_timeout_error_types()): + raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e raise _modal_exec_transport_error(command=command, cause=e) from e def supports_pty(self) -> bool: @@ -767,6 +850,8 @@ async def pty_exec_start( except Exception as e: if entry is not None and not registered: await self._terminate_pty_entry(entry) + if exception_chain_contains_type(e, _modal_exec_timeout_error_types()): + raise ExecTimeoutError(command=command, timeout_s=timeout, cause=e) from e raise _modal_exec_transport_error(command=command, cause=e) from e if pruned_entry is not None: @@ -1546,7 +1631,8 @@ def _modal_tar_skip_relpaths(self, root: Path) -> set[Path]: @retry_async( retry_if=lambda exc, self: ( - exception_chain_contains_type(exc, (ExecTransportError,)) + isinstance(exc, ExecTransportError) + and exc.retryable is not False or exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) ) ) @@ -1580,6 +1666,7 @@ async def _persist_workspace_via_tar(self) -> io.IOBase: "exit_code": out.exit_code, "stderr": out.stderr.decode("utf-8", "replace"), }, + retryable=False, ) return io.BytesIO(out.stdout) except WorkspaceArchiveReadError: diff --git a/src/agents/extensions/sandbox/runloop/sandbox.py b/src/agents/extensions/sandbox/runloop/sandbox.py index 31004017f2..c8d5a660b2 100644 --- a/src/agents/extensions/sandbox/runloop/sandbox.py +++ b/src/agents/extensions/sandbox/runloop/sandbox.py @@ -53,6 +53,7 @@ from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot from ....sandbox.types import ExecResult, ExposedPortEndpoint, User +from ....sandbox.util.retry import iter_exception_chain from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tar_bytes from ....sandbox.workspace_paths import coerce_posix_path, posix_path_as_path, sandbox_path_str @@ -83,10 +84,16 @@ class _RunloopSdkImports: api_response_validation_error: type[BaseException] api_status_error: type[BaseException] api_timeout_error: type[BaseException] + authentication_error: type[BaseException] + bad_request_error: type[BaseException] + internal_server_error: type[BaseException] not_found_error: type[BaseException] + permission_denied_error: type[BaseException] polling_config: type[Any] | None polling_timeout: type[BaseException] | None + rate_limit_error: type[BaseException] runloop_error: type[BaseException] + unprocessable_entity_error: type[BaseException] _RUNLOOP_SDK_IMPORTS: _RunloopSdkImports | None = None @@ -103,8 +110,14 @@ def _import_runloop_sdk() -> _RunloopSdkImports: APIResponseValidationError, APIStatusError, APITimeoutError, + AuthenticationError, + BadRequestError, + InternalServerError, NotFoundError, + PermissionDeniedError, + RateLimitError, RunloopError, + UnprocessableEntityError, ) from runloop_api_client.sdk import AsyncRunloopSDK except ImportError as e: @@ -132,10 +145,16 @@ def _import_runloop_sdk() -> _RunloopSdkImports: api_response_validation_error=APIResponseValidationError, api_status_error=APIStatusError, api_timeout_error=APITimeoutError, + authentication_error=AuthenticationError, + bad_request_error=BadRequestError, + internal_server_error=InternalServerError, not_found_error=NotFoundError, + permission_denied_error=PermissionDeniedError, polling_config=polling_config, polling_timeout=polling_timeout, + rate_limit_error=RateLimitError, runloop_error=RunloopError, + unprocessable_entity_error=UnprocessableEntityError, ) return _RUNLOOP_SDK_IMPORTS @@ -262,6 +281,56 @@ def _runloop_error_message(exc: BaseException) -> str | None: return None +_RUNLOOP_HTTP_STATUS_RETRYABLE: dict[int, bool] = { + 400: False, + 401: False, + 403: False, + 404: False, + 408: True, + 422: False, + 429: True, + 500: True, + 502: True, + 503: True, + 504: True, +} + + +def _runloop_retryable_error_types() -> tuple[type[BaseException], ...]: + sdk_imports = _import_runloop_sdk() + return ( + sdk_imports.api_connection_error, + sdk_imports.api_timeout_error, + sdk_imports.internal_server_error, + sdk_imports.rate_limit_error, + ) + + +def _runloop_non_retryable_error_types() -> tuple[type[BaseException], ...]: + sdk_imports = _import_runloop_sdk() + return ( + sdk_imports.authentication_error, + sdk_imports.bad_request_error, + sdk_imports.not_found_error, + sdk_imports.permission_denied_error, + sdk_imports.unprocessable_entity_error, + ) + + +def _runloop_provider_retryability(exc: BaseException) -> bool | None: + retryable_error_types = _runloop_retryable_error_types() + non_retryable_error_types = _runloop_non_retryable_error_types() + for candidate in iter_exception_chain(exc): + if isinstance(candidate, retryable_error_types): + return True + if isinstance(candidate, non_retryable_error_types): + return False + status_code = _runloop_status_code(candidate) + if status_code in _RUNLOOP_HTTP_STATUS_RETRYABLE: + return _RUNLOOP_HTTP_STATUS_RETRYABLE[status_code] + return None + + def _runloop_provider_error_types() -> tuple[type[BaseException], ...]: sdk_imports = _import_runloop_sdk() return ( @@ -781,6 +850,7 @@ async def _run_exec_command( command=command, context=_runloop_error_context(e, backend_detail="exec_failed"), cause=e, + retryable=_runloop_provider_retryability(e), ) from e raise ExecTransportError(command=command, cause=e) from e @@ -795,6 +865,7 @@ async def _ensure_tunnel_url(self, port: int) -> str: reason="backend_unavailable", context=_runloop_error_context(e, backend_detail="get_tunnel_url_failed"), cause=e, + retryable=_runloop_provider_retryability(e), ) from e raise if isinstance(url, str) and url: @@ -815,6 +886,7 @@ async def _ensure_tunnel_url(self, port: int) -> str: reason="backend_unavailable", context=_runloop_error_context(e, backend_detail="enable_tunnel_failed"), cause=e, + retryable=_runloop_provider_retryability(e), ) from e raise try: @@ -829,6 +901,7 @@ async def _ensure_tunnel_url(self, port: int) -> str: reason="backend_unavailable", context=context, cause=e, + retryable=_runloop_provider_retryability(e), ) from e raise if not isinstance(url, str) or not url: @@ -894,6 +967,7 @@ async def read(self, path: Path | str, *, user: str | User | None = None) -> io. path=error_path, context=_runloop_error_context(e, backend_detail="file_download_failed"), cause=e, + retryable=_runloop_provider_retryability(e), ) from e raise WorkspaceArchiveReadError(path=error_path, cause=e) from e @@ -929,6 +1003,7 @@ async def write( path=workspace_path, context=_runloop_error_context(e, backend_detail="file_upload_failed"), cause=e, + retryable=_runloop_provider_retryability(e), ) from e raise WorkspaceArchiveWriteError(path=workspace_path, cause=e) from e @@ -1134,10 +1209,14 @@ async def persist_workspace(self) -> io.IOBase: except WorkspaceArchiveReadError as e: snapshot_error = e except Exception as e: + retryable = None + if _is_runloop_provider_error(e): + retryable = _runloop_provider_retryability(e) snapshot_error = WorkspaceArchiveReadError( path=root, context={"reason": "snapshot_failed"}, cause=e, + retryable=retryable, ) finally: remount_error: WorkspaceArchiveReadError | None = None @@ -1249,6 +1328,9 @@ async def hydrate_workspace(self, data: io.IOBase) -> None: path=root, context=context, cause=e, + retryable=_runloop_provider_retryability(e) + if _is_runloop_provider_error(e) + else None, ) from e async def _restore_snapshot_into_workspace_on_resume(self) -> None: diff --git a/src/agents/extensions/sandbox/vercel/sandbox.py b/src/agents/extensions/sandbox/vercel/sandbox.py index 44812c1788..ab25bc3398 100644 --- a/src/agents/extensions/sandbox/vercel/sandbox.py +++ b/src/agents/extensions/sandbox/vercel/sandbox.py @@ -23,13 +23,7 @@ import httpx from pydantic import TypeAdapter, field_serializer, field_validator -from vercel.sandbox import ( - AsyncSandbox, - NetworkPolicy, - Resources, - SandboxStatus, - SnapshotSource, -) +from vercel import sandbox as vercel_sandbox from ....sandbox.errors import ( ConfigurationError, @@ -62,6 +56,12 @@ from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tarfile from ....sandbox.workspace_paths import coerce_posix_path, posix_path_as_path, sandbox_path_str +AsyncSandbox = vercel_sandbox.AsyncSandbox +NetworkPolicy = vercel_sandbox.NetworkPolicy +Resources = vercel_sandbox.Resources +SandboxStatus = vercel_sandbox.SandboxStatus +SnapshotSource = vercel_sandbox.SnapshotSource + WorkspacePersistenceMode = Literal["tar", "snapshot"] _WORKSPACE_PERSISTENCE_TAR: WorkspacePersistenceMode = "tar" @@ -78,6 +78,30 @@ httpx.NetworkError, httpx.ProtocolError, ) +_VERCEL_RETRYABLE_PROVIDER_ERRORS: tuple[type[BaseException], ...] = ( + vercel_sandbox.SandboxRateLimitError, + vercel_sandbox.SandboxServerError, +) +_VERCEL_NON_RETRYABLE_PROVIDER_ERRORS: tuple[type[BaseException], ...] = ( + vercel_sandbox.SandboxAuthError, + vercel_sandbox.SandboxNotFoundError, + vercel_sandbox.SandboxPermissionError, + vercel_sandbox.SandboxValidationError, +) +_VERCEL_HTTP_STATUS_RETRYABLE: dict[int, bool] = { + 400: False, + 401: False, + 403: False, + 404: False, + 408: True, + 425: True, + 422: False, + 429: True, + 500: True, + 502: True, + 503: True, + 504: True, +} # Sandbox status values from which the sandbox can still transition to RUNNING. # Only "pending" qualifies: a freshly created sandbox transitions PENDING -> RUNNING. @@ -86,18 +110,25 @@ _VERCEL_TRANSIENT_SANDBOX_STATUSES: frozenset[str] = frozenset({"pending"}) -def _is_transient_create_error(exc: BaseException) -> bool: - if exception_chain_has_status_code(exc, {408, 425, 429, 500, 502, 503, 504}): +def _vercel_provider_retryability(exc: BaseException) -> bool | None: + if exception_chain_contains_type(exc, _VERCEL_RETRYABLE_PROVIDER_ERRORS): + return True + if exception_chain_contains_type(exc, _VERCEL_NON_RETRYABLE_PROVIDER_ERRORS): + return False + if exception_chain_contains_type(exc, _VERCEL_TRANSIENT_TRANSPORT_ERRORS): return True + for status_code, retryable in _VERCEL_HTTP_STATUS_RETRYABLE.items(): + if exception_chain_has_status_code(exc, {status_code}): + return retryable + return None - return exception_chain_contains_type(exc, _VERCEL_TRANSIENT_TRANSPORT_ERRORS) +def _is_transient_create_error(exc: BaseException) -> bool: + return _vercel_provider_retryability(exc) is True -def _is_transient_write_error(exc: BaseException) -> bool: - if exception_chain_has_status_code(exc, {408, 425, 429, 500, 502, 503, 504}): - return True - return exception_chain_contains_type(exc, _VERCEL_TRANSIENT_TRANSPORT_ERRORS) +def _is_transient_write_error(exc: BaseException) -> bool: + return _vercel_provider_retryability(exc) is True @retry_async(retry_if=lambda exc, **_kwargs: _is_transient_create_error(exc)) @@ -314,7 +345,11 @@ async def _prepare_backend_workspace(self) -> None: sandbox = await self._ensure_sandbox() finished = await sandbox.run_command("mkdir", ["-p", "--", root.as_posix()]) except Exception as exc: - raise WorkspaceStartError(path=posix_path_as_path(root), cause=exc) from exc + raise WorkspaceStartError( + path=posix_path_as_path(root), + cause=exc, + retryable=_vercel_provider_retryability(exc), + ) from exc if finished.exit_code != 0: raise WorkspaceStartError( @@ -445,10 +480,15 @@ async def _exec_internal( except ExecTimeoutError: raise except Exception as exc: + context: dict[str, object] = { + "backend": "vercel", + "sandbox_id": self.state.sandbox_id, + } raise ExecTransportError( command=normalized, - context={"backend": "vercel", "sandbox_id": self.state.sandbox_id}, + context=context, cause=exc, + retryable=_vercel_provider_retryability(exc), ) from exc async def _resolve_exposed_port(self, port: int) -> ExposedPortEndpoint: @@ -462,6 +502,7 @@ async def _resolve_exposed_port(self, port: int) -> ExposedPortEndpoint: reason="backend_unavailable", context={"backend": "vercel", "sandbox_id": self.state.sandbox_id}, cause=exc, + retryable=_vercel_provider_retryability(exc), ) from exc parsed = urlsplit(domain) @@ -489,7 +530,11 @@ async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase try: payload = await sandbox.read_file(sandbox_path_str(normalized_path)) except Exception as exc: - raise WorkspaceArchiveReadError(path=normalized_path, cause=exc) from exc + raise WorkspaceArchiveReadError( + path=normalized_path, + cause=exc, + retryable=_vercel_provider_retryability(exc), + ) from exc if payload is None: raise WorkspaceReadNotFoundError(path=normalized_path) return io.BytesIO(payload) @@ -518,7 +563,11 @@ async def write( [{"path": sandbox_path_str(normalized_path), "content": bytes(payload)}] ) except Exception as exc: - raise WorkspaceArchiveWriteError(path=normalized_path, cause=exc) from exc + raise WorkspaceArchiveWriteError( + path=normalized_path, + cause=exc, + retryable=_vercel_provider_retryability(exc), + ) from exc async def persist_workspace(self) -> io.IOBase: return await with_ephemeral_mounts_removed( @@ -536,7 +585,11 @@ async def _persist_workspace_internal(self) -> io.IOBase: try: snapshot = await sandbox.snapshot(expiration=self.state.snapshot_expiration_ms) except Exception as exc: - raise WorkspaceArchiveReadError(path=root, cause=exc) from exc + raise WorkspaceArchiveReadError( + path=root, + cause=exc, + retryable=_vercel_provider_retryability(exc), + ) from exc return io.BytesIO(_encode_snapshot_ref(snapshot_id=snapshot.snapshot_id)) root = self._workspace_root_path() @@ -572,7 +625,11 @@ async def _persist_workspace_internal(self) -> io.IOBase: except WorkspaceArchiveReadError: raise except Exception as exc: - raise WorkspaceArchiveReadError(path=root, cause=exc) from exc + raise WorkspaceArchiveReadError( + path=root, + cause=exc, + retryable=_vercel_provider_retryability(exc), + ) from exc finally: try: await sandbox.run_command( @@ -612,6 +669,7 @@ async def _hydrate_workspace_internal(self, raw: bytes) -> None: raise WorkspaceArchiveWriteError( path=self._workspace_root_path(), cause=exc, + retryable=_vercel_provider_retryability(exc), ) from exc return @@ -638,7 +696,11 @@ async def _hydrate_workspace_internal(self, raw: bytes) -> None: except WorkspaceArchiveWriteError: raise except Exception as exc: - raise WorkspaceArchiveWriteError(path=root, cause=exc) from exc + raise WorkspaceArchiveWriteError( + path=root, + cause=exc, + retryable=_vercel_provider_retryability(exc), + ) from exc finally: try: await sandbox.run_command( diff --git a/src/agents/sandbox/errors.py b/src/agents/sandbox/errors.py index c90efdefbf..8e7848a39a 100644 --- a/src/agents/sandbox/errors.py +++ b/src/agents/sandbox/errors.py @@ -83,6 +83,8 @@ class SandboxError(Exception): op: The operation where the error occurred. context: Structured metadata to aid debugging. cause: Optional underlying exception. + retryable: Whether retrying the same operation is expected to succeed. + `None` means the SDK cannot safely classify the error. """ message: str @@ -90,8 +92,11 @@ class SandboxError(Exception): op: OpName context: dict[str, object] cause: BaseException | None = None + retryable: bool | None = None def __post_init__(self) -> None: + if self.retryable is None and isinstance(self.cause, SandboxError): + self.retryable = self.cause.retryable super().__init__(self.message) if self.cause is not None: self.__cause__ = self.cause @@ -153,6 +158,7 @@ def __init__( op="materialize", context={"rel": str(rel), "reason": reason, **_as_context(context)}, cause=cause, + retryable=False, ) @@ -178,6 +184,7 @@ def __init__( op="write", context={"path": str(path), "scheme": scheme, **_as_context(context)}, cause=cause, + retryable=False, ) @@ -192,11 +199,13 @@ def __init__( reason: str, context: Mapping[str, object] | None = None, cause: BaseException | None = None, + retryable: bool | None = None, ) -> None: if reason == "not_configured": message = f"port {port} is not configured for host exposure" else: message = f"port {port} could not be resolved for host exposure" + resolved_retryable = False if reason == "not_configured" else retryable super().__init__( message=message, error_code=ErrorCode.EXPOSED_PORT_UNAVAILABLE, @@ -208,6 +217,7 @@ def __init__( **_as_context(context), }, cause=cause, + retryable=resolved_retryable, ) @@ -224,6 +234,7 @@ def __init__( command: Sequence[str | Path], context: Mapping[str, object] | None = None, cause: BaseException | None = None, + retryable: bool | None = None, ) -> None: cmd = tuple(str(c) for c in command) super().__init__( @@ -232,6 +243,7 @@ def __init__( op="exec", context={"command": cmd, "command_str": _format_command(cmd), **_as_context(context)}, cause=cause, + retryable=retryable, ) self.command = cmd @@ -272,6 +284,7 @@ def __init__( **_as_context(context), }, cause=cause, + retryable=False, ) self.exit_code = exec_result.exit_code self.stdout = exec_result.stdout @@ -297,6 +310,7 @@ def __init__( command=command, context={"timeout_s": timeout_s, **_as_context(context)}, cause=cause, + retryable=False, ) self.timeout_s = timeout_s @@ -311,6 +325,7 @@ def __init__( context: Mapping[str, object] | None = None, cause: BaseException | None = None, message: str | None = None, + retryable: bool | None = None, ) -> None: super().__init__( message=message or "exec transport error", @@ -318,6 +333,7 @@ def __init__( command=command, context=_as_context(context), cause=cause, + retryable=retryable, ) @@ -339,6 +355,7 @@ def __init__( op="exec", context={"session_id": session_id, **_as_context(context)}, cause=cause, + retryable=False, ) self.session_id = session_id @@ -370,6 +387,7 @@ def __init__( op="apply_patch", context={"path": str(path), "reason": reason, **_as_context(context)}, cause=cause, + retryable=False, ) @@ -393,6 +411,7 @@ def __init__( op="apply_patch", context=resolved_context, cause=cause, + retryable=False, ) @@ -412,6 +431,7 @@ def __init__( op="apply_patch", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -431,6 +451,7 @@ def __init__( op="apply_patch", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -450,6 +471,7 @@ def __init__( op="read", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -462,6 +484,7 @@ def __init__( path: Path, context: Mapping[str, object] | None = None, cause: BaseException | None = None, + retryable: bool | None = None, ) -> None: super().__init__( message=f"failed to read archive for path: {path}", @@ -469,6 +492,7 @@ def __init__( op="read", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=retryable, ) @@ -481,6 +505,7 @@ def __init__( path: Path, context: Mapping[str, object] | None = None, cause: BaseException | None = None, + retryable: bool | None = None, ) -> None: super().__init__( message=f"failed to write archive for path: {path}", @@ -488,6 +513,7 @@ def __init__( op="write", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=retryable, ) @@ -508,6 +534,7 @@ def __init__( op="write", context={"path": str(path), "actual_type": actual_type, **_as_context(context)}, cause=cause, + retryable=False, ) @@ -520,6 +547,7 @@ def __init__( path: Path, context: Mapping[str, object] | None = None, cause: BaseException | None = None, + retryable: bool | None = None, ) -> None: super().__init__( message="failed to stop session", @@ -527,6 +555,7 @@ def __init__( op="stop", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=retryable, ) @@ -540,6 +569,7 @@ def __init__( context: Mapping[str, object] | None = None, cause: BaseException | None = None, message: str | None = None, + retryable: bool | None = None, ) -> None: super().__init__( message=message or "failed to start session", @@ -547,6 +577,7 @@ def __init__( op="start", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=retryable, ) @@ -566,6 +597,7 @@ def __init__( op="exec", context={"path": str(path), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -589,6 +621,7 @@ def __init__( op="materialize", context={"src": str(src), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -608,6 +641,7 @@ def __init__( op="materialize", context={"src": str(src), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -627,6 +661,7 @@ def __init__( op="materialize", context={"src": str(src), **_as_context(context)}, cause=cause, + retryable=False, ) @@ -649,6 +684,7 @@ def __init__( op="materialize", context=_as_context(context), cause=cause, + retryable=False, ) @@ -670,6 +706,7 @@ def __init__( op="materialize", context={"url": url, "ref": ref, "stderr": stderr, **_as_context(context)}, cause=cause, + retryable=None, ) @@ -696,6 +733,7 @@ def __init__( **_as_context(context), }, cause=cause, + retryable=False, ) @@ -722,6 +760,7 @@ def __init__( **_as_context(context), }, cause=cause, + retryable=None, ) @@ -745,6 +784,7 @@ def __init__( op="materialize", context={"tool": tool, **_as_context(context)}, cause=cause, + retryable=False, ) @@ -762,6 +802,7 @@ def __init__( error_code=ErrorCode.MOUNT_CONFIG_INVALID, op="materialize", context=_as_context(context), + retryable=False, ) @@ -782,6 +823,7 @@ def __init__( op="materialize", context={"command": command, "stderr": stderr, **_as_context(context)}, cause=cause, + retryable=False, ) @@ -801,6 +843,7 @@ def __init__( op="materialize", context=_as_context(context), cause=cause, + retryable=False, ) @@ -821,6 +864,7 @@ def __init__( op="snapshot_persist", context={"snapshot_id": snapshot_id, "path": str(path), **_as_context(context)}, cause=cause, + retryable=None, ) @@ -841,6 +885,7 @@ def __init__( op="snapshot_restore", context={"snapshot_id": snapshot_id, "path": str(path), **_as_context(context)}, cause=cause, + retryable=None, ) @@ -859,4 +904,5 @@ def __init__( error_code=ErrorCode.SNAPSHOT_NOT_RESTORABLE, op="snapshot_restore", context={"snapshot_id": snapshot_id, "path": str(path), **_as_context(context)}, + retryable=False, ) diff --git a/src/agents/sandbox/sandboxes/docker.py b/src/agents/sandbox/sandboxes/docker.py index 2148840566..ae160fc978 100644 --- a/src/agents/sandbox/sandboxes/docker.py +++ b/src/agents/sandbox/sandboxes/docker.py @@ -1219,9 +1219,19 @@ async def persist_workspace(self) -> io.IOBase: ) return strip_tar_member_prefix(root_prefixed_archive, prefix=staging_workspace.name) except docker.errors.NotFound as e: - raise WorkspaceArchiveReadError(path=error_root, cause=e) from e + raise WorkspaceArchiveReadError(path=error_root, cause=e, retryable=False) from e except docker.errors.APIError as e: - raise WorkspaceArchiveReadError(path=error_root, cause=e) from e + status_code = getattr(e, "status_code", None) + retryable = ( + True + if isinstance(status_code, int) and status_code in TRANSIENT_HTTP_STATUS_CODES + else None + ) + raise WorkspaceArchiveReadError( + path=error_root, + cause=e, + retryable=retryable, + ) from e async def hydrate_workspace(self, data: io.IOBase) -> None: root = self._workspace_root_path() diff --git a/src/agents/sandbox/session/events.py b/src/agents/sandbox/session/events.py index c0aa587900..c99709693b 100644 --- a/src/agents/sandbox/session/events.py +++ b/src/agents/sandbox/session/events.py @@ -70,6 +70,7 @@ class SandboxSessionFinishEvent(SandboxSessionEventBase): error_code: ErrorCode | None = None error_type: str | None = None error_message: str | None = None + error_retryable: bool | None = None # Optional exec outputs (truncated / opt-in via policy). stdout: str | None = None diff --git a/src/agents/sandbox/session/sandbox_session.py b/src/agents/sandbox/session/sandbox_session.py index 97dccfa07b..1d41bc02fb 100644 --- a/src/agents/sandbox/session/sandbox_session.py +++ b/src/agents/sandbox/session/sandbox_session.py @@ -351,6 +351,9 @@ def _apply_trace_finish_data( if isinstance(exc, SandboxError): trace_data["error_code"] = exc.error_code error_data["error_code"] = exc.error_code + if exc.retryable is not None: + trace_data["error_retryable"] = exc.retryable + error_data["error_retryable"] = exc.retryable span.set_error({"message": type(exc).__name__, "data": error_data}) return if not ok: @@ -477,6 +480,7 @@ async def _emit_finish_event( event.error_message = str(exc) if isinstance(exc, SandboxError): event.error_code = exc.error_code + event.error_retryable = exc.retryable # Preserve raw bytes so Instrumentation can apply per-op/per-sink policies later. # Decoding here would force one global formatting decision before sink-specific redaction diff --git a/tests/extensions/sandbox/test_blaxel.py b/tests/extensions/sandbox/test_blaxel.py index fcbd2428d8..84ef1e1b72 100644 --- a/tests/extensions/sandbox/test_blaxel.py +++ b/tests/extensions/sandbox/test_blaxel.py @@ -3286,6 +3286,113 @@ async def _raise_500(*args: object, **kw: object) -> None: assert exc_info.value.context["backend"] == "blaxel" assert exc_info.value.context["http_status"] == 500 assert exc_info.value.context["provider_error"] == "HTTP 500: internal error" + assert exc_info.value.retryable is True + + @pytest.mark.asyncio + async def test_exec_uses_structured_blaxel_non_retryable_error_code( + self, fake_sandbox: _FakeSandboxInstance + ) -> None: + from agents.extensions.sandbox.blaxel import sandbox as mod + + session = _make_session(fake_sandbox) + + class FakeApiError(Exception): + def __init__(self) -> None: + super().__init__("route not found") + self.status_code = 404 + self.body = { + "error": { + "code": "ROUTE_NOT_FOUND", + "message": "Preview not found: sandbox", + "retryable": False, + "status": 404, + } + } + + async def _raise_route_not_found(*args: object, **kw: object) -> None: + raise FakeApiError() + + fake_sandbox.process.exec = _raise_route_not_found # type: ignore[assignment] + + with patch.object(mod, "_import_sandbox_api_error", return_value=FakeApiError): + with pytest.raises(ExecTransportError) as exc_info: + await session._exec_internal("echo", "hello") + + assert str(exc_info.value) == "Blaxel exec failed: HTTP 404: route not found" + assert exc_info.value.context["backend"] == "blaxel" + assert exc_info.value.context["http_status"] == 404 + assert exc_info.value.context["provider_error"] == "HTTP 404: route not found" + assert exc_info.value.context["provider_error_code"] == "ROUTE_NOT_FOUND" + assert exc_info.value.retryable is False + + @pytest.mark.asyncio + async def test_exec_uses_structured_blaxel_retryable_error_code( + self, fake_sandbox: _FakeSandboxInstance + ) -> None: + from agents.extensions.sandbox.blaxel import sandbox as mod + + session = _make_session(fake_sandbox) + + class FakeApiError(Exception): + def __init__(self) -> None: + super().__init__("workload unavailable") + self.status_code = 404 + self.body = { + "error": { + "code": "WORKLOAD_UNAVAILABLE", + "message": "No healthy replica is serving workload", + "retryable": True, + "status": 404, + } + } + + async def _raise_workload_unavailable(*args: object, **kw: object) -> None: + raise FakeApiError() + + fake_sandbox.process.exec = _raise_workload_unavailable # type: ignore[assignment] + + with patch.object(mod, "_import_sandbox_api_error", return_value=FakeApiError): + with pytest.raises(ExecTransportError) as exc_info: + await session._exec_internal("echo", "hello") + + assert str(exc_info.value) == "Blaxel exec failed: HTTP 404: workload unavailable" + assert exc_info.value.context["backend"] == "blaxel" + assert exc_info.value.context["http_status"] == 404 + assert exc_info.value.context["provider_error"] == "HTTP 404: workload unavailable" + assert exc_info.value.context["provider_error_code"] == "WORKLOAD_UNAVAILABLE" + assert exc_info.value.retryable is True + + @pytest.mark.parametrize( + ("code", "expected_retryable"), + [ + ("ROUTE_NOT_FOUND", False), + ("WORKLOAD_NOT_FOUND", False), + ("WORKSPACE_NOT_FOUND", False), + ("WORKLOAD_UNAVAILABLE", True), + ("AUTHENTICATION_REQUIRED", False), + ("AUTHENTICATION_FAILED", False), + ("FORBIDDEN", False), + ("BAD_REQUEST", False), + ("USAGE_LIMIT_EXCEEDED", False), + ("POLICY_VIOLATION", False), + ], + ) + def test_blaxel_retryability_error_code_table( + self, + code: str, + expected_retryable: bool, + ) -> None: + from agents.extensions.sandbox.blaxel import sandbox as mod + + class FakeApiError(Exception): + def __init__(self) -> None: + super().__init__(code) + self.body = {"error": {"code": code, "message": code}} + + retryable, provider_error_code = mod._blaxel_provider_retryability(FakeApiError()) + + assert retryable is expected_retryable + assert provider_error_code == code # --------------------------------------------------------------------------- diff --git a/tests/extensions/sandbox/test_cloudflare.py b/tests/extensions/sandbox/test_cloudflare.py index 43c81665a9..84e5ae9f39 100644 --- a/tests/extensions/sandbox/test_cloudflare.py +++ b/tests/extensions/sandbox/test_cloudflare.py @@ -664,6 +664,58 @@ async def test_cloudflare_exec_non_200_includes_provider_error_details() -> None str(exc_info.value) == "POST /exec failed: HTTP 502: pool_error: pool error: Failed to start container" ) + assert exc_info.value.retryable is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("status", "expected_retryable"), + [ + (400, False), + (500, False), + (503, True), + ], +) +async def test_cloudflare_exec_retryability_follows_documented_status_semantics( + status: int, + expected_retryable: bool, +) -> None: + sess = _make_session( + fake_http=_FakeHttp( + { + "POST /exec": _FakeResponse( + status=status, + json_body={ + "error": "cloudflare sandbox error", + "code": "cloudflare_error", + }, + ) + } + ) + ) + + with pytest.raises(ExecTransportError) as exc_info: + await sess._exec_internal("mkdir", "-p", "--", "/workspace", timeout=5.0) + + assert exc_info.value.context["backend"] == "cloudflare" + assert exc_info.value.context["http_status"] == status + assert exc_info.value.context["provider_error"] == "cloudflare_error: cloudflare sandbox error" + assert exc_info.value.retryable is expected_retryable + + +@pytest.mark.parametrize( + ("status", "expected_retryable"), + [ + (400, False), + (500, False), + (503, True), + (418, None), + ], +) +def test_cloudflare_retryability_status_table(status: int, expected_retryable: bool | None) -> None: + from agents.extensions.sandbox.cloudflare import sandbox as mod + + assert mod._cloudflare_retryability_for_status(status) is expected_retryable @pytest.mark.asyncio @@ -965,6 +1017,48 @@ async def test_cloudflare_persist_and_hydrate_use_http_endpoints() -> None: assert "root" not in hydrate_calls[0].get("params", {}) +@pytest.mark.asyncio +async def test_cloudflare_persist_retries_only_documented_503_status() -> None: + fake_http = _FakeHttp( + { + "POST /persist": _FakeResponse( + status=503, + json_body={"error": "container starting"}, + ) + } + ) + sess = _make_session(fake_http=fake_http) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await sess.persist_workspace() + + persist_calls = [c for c in fake_http.calls if c["method"] == "POST" and "/persist" in c["url"]] + assert len(persist_calls) == 3 + assert exc_info.value.context["http_status"] == 503 + assert exc_info.value.retryable is True + + +@pytest.mark.asyncio +async def test_cloudflare_persist_does_not_retry_documented_fail_fast_500() -> None: + fake_http = _FakeHttp( + { + "POST /persist": _FakeResponse( + status=500, + json_body={"error": "configuration error"}, + ) + } + ) + sess = _make_session(fake_http=fake_http) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await sess.persist_workspace() + + persist_calls = [c for c in fake_http.calls if c["method"] == "POST" and "/persist" in c["url"]] + assert len(persist_calls) == 1 + assert exc_info.value.context["http_status"] == 500 + assert exc_info.value.retryable is False + + @pytest.mark.asyncio async def test_cloudflare_persist_unmounts_and_remounts_ephemeral_bucket_mounts() -> None: fake_http = _FakeHttp( diff --git a/tests/extensions/sandbox/test_daytona.py b/tests/extensions/sandbox/test_daytona.py index 70a9015e91..c8276cf545 100644 --- a/tests/extensions/sandbox/test_daytona.py +++ b/tests/extensions/sandbox/test_daytona.py @@ -1110,6 +1110,37 @@ async def test_persist_workspace_remounts_mounts_after_snapshot( assert mount._unmounted_paths == [mount_path] assert mount._mounted_paths == [mount_path] + @pytest.mark.asyncio + async def test_persist_workspace_marks_stopped_sandbox_non_retryable( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Verify stopped Daytona sandboxes expose provider-neutral retryability.""" + + daytona_module = _load_daytona_module(monkeypatch) + sandbox = _FakeDaytonaSandbox() + state = daytona_module.DaytonaSandboxSessionState( + manifest=Manifest(root=daytona_module.DEFAULT_DAYTONA_WORKSPACE_ROOT), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.id, + ) + session = daytona_module.DaytonaSandboxSession.from_state(state, sandbox=sandbox) + + async def _raise_stopped_sandbox(_cmd: str, **_kwargs: object) -> object: + raise RuntimeError( + "bad request: failed to resolve container IP after 3 attempts: " + "no IP address found. Is the Sandbox started?" + ) + + monkeypatch.setattr(sandbox.process, "exec", _raise_stopped_sandbox) + + with pytest.raises(daytona_module.WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert exc_info.value.retryable is False + assert exc_info.value.context["backend"] == "daytona" + assert exc_info.value.context["reason"] == "sandbox_not_running" + @pytest.mark.asyncio async def test_persist_workspace_uses_nested_mount_targets_and_runtime_skip_paths( self, @@ -1342,8 +1373,8 @@ class _FakeTimeout(Exception): monkeypatch.setattr( daytona_module, - "_import_daytona_exceptions", - lambda: {"timeout": _FakeTimeout}, + "_daytona_timeout_error_types", + lambda: (_FakeTimeout,), ) sandbox = _FakeDaytonaSandbox() @@ -1358,6 +1389,111 @@ class _FakeTimeout(Exception): with pytest.raises(ExecTimeoutError): await session.pty_exec_start("python3", shell=False, tty=False, timeout=2.0) + @pytest.mark.asyncio + async def test_pty_start_marks_documented_sdk_not_found_non_retryable( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + daytona_module = _load_daytona_module(monkeypatch) + + class _FakeNotFound(Exception): + status_code = 404 + error_code = "sandbox_not_found" + + monkeypatch.setattr( + daytona_module, + "_daytona_non_retryable_error_types", + lambda: (_FakeNotFound,), + ) + monkeypatch.setattr(daytona_module, "_daytona_retryable_error_types", lambda: ()) + + sandbox = _FakeDaytonaSandbox() + sandbox.process.create_pty_session_error = _FakeNotFound("sandbox not found") + state = daytona_module.DaytonaSandboxSessionState( + manifest=Manifest(root=daytona_module.DEFAULT_DAYTONA_WORKSPACE_ROOT), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.id, + ) + session = daytona_module.DaytonaSandboxSession.from_state(state, sandbox=sandbox) + + with pytest.raises(ExecTransportError) as exc_info: + await session.pty_exec_start("python3", shell=False, tty=True) + + assert exc_info.value.retryable is False + assert exc_info.value.context["backend"] == "daytona" + assert exc_info.value.context["http_status"] == 404 + assert exc_info.value.context["provider_error_code"] == "sandbox_not_found" + assert exc_info.value.context["reason"] == "sandbox_not_found" + + @pytest.mark.asyncio + async def test_pty_start_marks_documented_sdk_rate_limit_retryable( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + daytona_module = _load_daytona_module(monkeypatch) + + class _FakeRateLimit(Exception): + status_code = 429 + error_code = "rate_limit_exceeded" + + monkeypatch.setattr( + daytona_module, + "_daytona_retryable_error_types", + lambda: (_FakeRateLimit,), + ) + monkeypatch.setattr(daytona_module, "_daytona_non_retryable_error_types", lambda: ()) + + sandbox = _FakeDaytonaSandbox() + sandbox.process.create_pty_session_error = _FakeRateLimit("rate limit exceeded") + state = daytona_module.DaytonaSandboxSessionState( + manifest=Manifest(root=daytona_module.DEFAULT_DAYTONA_WORKSPACE_ROOT), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.id, + ) + session = daytona_module.DaytonaSandboxSession.from_state(state, sandbox=sandbox) + + with pytest.raises(ExecTransportError) as exc_info: + await session.pty_exec_start("python3", shell=False, tty=True) + + assert exc_info.value.retryable is True + assert exc_info.value.context["backend"] == "daytona" + assert exc_info.value.context["http_status"] == 429 + assert exc_info.value.context["provider_error_code"] == "rate_limit_exceeded" + assert exc_info.value.context["reason"] == "rate_limit_exceeded" + + @pytest.mark.parametrize( + ("status", "expected_retryable"), + [ + (400, False), + (401, False), + (403, False), + (404, False), + (409, False), + (429, True), + (500, True), + (502, True), + (503, True), + (504, True), + ], + ) + def test_daytona_retryability_status_table( + self, + monkeypatch: pytest.MonkeyPatch, + status: int, + expected_retryable: bool, + ) -> None: + daytona_module = _load_daytona_module(monkeypatch) + monkeypatch.setattr(daytona_module, "_daytona_non_retryable_error_types", lambda: ()) + monkeypatch.setattr(daytona_module, "_daytona_retryable_error_types", lambda: ()) + + class FakeStatusError(Exception): + status_code = status + + retryable, reason = daytona_module._daytona_provider_retryability(FakeStatusError()) + + assert retryable is expected_retryable + assert reason == f"http_{status}" + @pytest.mark.asyncio async def test_session_reader_keeps_entry_live_when_logs_fail_without_exit_code( self, diff --git a/tests/extensions/sandbox/test_e2b.py b/tests/extensions/sandbox/test_e2b.py index a7bbc8bd1f..6a123770c3 100644 --- a/tests/extensions/sandbox/test_e2b.py +++ b/tests/extensions/sandbox/test_e2b.py @@ -1456,6 +1456,7 @@ async def test_e2b_persist_workspace_raises_on_nonzero_snapshot_exit() -> None: assert exc_info.value.context["reason"] == "snapshot_nonzero_exit" assert exc_info.value.context["exit_code"] == 2 + assert exc_info.value.retryable is False @pytest.mark.asyncio @@ -2025,8 +2026,10 @@ async def test_e2b_pty_start_maps_timeout_failures( monkeypatch: pytest.MonkeyPatch, ) -> None: sandbox = _FakeE2BSandbox() - timeout_exc = e2b_module._import_e2b_exceptions().get("timeout") - if timeout_exc is None: + timeout_error_types = e2b_module._e2b_timeout_error_types() + if timeout_error_types: + timeout_exc = timeout_error_types[0] + else: class _FakeTimeout(Exception): pass @@ -2034,8 +2037,8 @@ class _FakeTimeout(Exception): timeout_exc = _FakeTimeout monkeypatch.setattr( e2b_module, - "_import_e2b_exceptions", - lambda: {"timeout": _FakeTimeout}, + "_e2b_timeout_error_types", + lambda: (_FakeTimeout,), ) sandbox.pty.create_error = timeout_exc("timed out") state = E2BSandboxSessionState( @@ -2072,8 +2075,8 @@ def __init__(self) -> None: monkeypatch.setattr( e2b_module, - "_import_e2b_exceptions", - lambda: {"timeout": _FakeTimeout}, + "_e2b_timeout_error_types", + lambda: (_FakeTimeout,), ) async def _raise_timeout(*args: object, **kwargs: object) -> object: @@ -2122,10 +2125,10 @@ async def _raise_timeout(*args: object, **kwargs: object) -> object: @pytest.mark.asyncio -async def test_e2b_exec_maps_missing_sandbox_timeout_to_transport_error( +async def test_e2b_exec_maps_missing_sandbox_not_found_to_transport_error( monkeypatch: pytest.MonkeyPatch, ) -> None: - class _FakeTimeout(Exception): + class _FakeNotFound(Exception): pass sandbox = _FakeE2BSandbox() @@ -2140,21 +2143,94 @@ class _FakeTimeout(Exception): monkeypatch.setattr( e2b_module, - "_import_e2b_exceptions", - lambda: {"timeout": _FakeTimeout}, + "_e2b_non_retryable_error_types", + lambda: (_FakeNotFound,), ) + monkeypatch.setattr(e2b_module, "_e2b_retryable_error_types", lambda: ()) + monkeypatch.setattr(e2b_module, "_e2b_timeout_error_types", lambda: ()) - async def _raise_timeout(*args: object, **kwargs: object) -> object: + async def _raise_not_found(*args: object, **kwargs: object) -> object: _ = (args, kwargs) - raise _FakeTimeout("The sandbox was not found: request failed") + raise _FakeNotFound("The sandbox was not found: request failed") - monkeypatch.setattr(e2b_module, "_sandbox_run_command", _raise_timeout) + monkeypatch.setattr(e2b_module, "_sandbox_run_command", _raise_not_found) with pytest.raises(ExecTransportError) as exc_info: await session._exec_internal("python3", "build.py", timeout=2.0) # noqa: SLF001 assert exc_info.value.context["provider_error"] == "The sandbox was not found: request failed" - assert exc_info.value.context["reason"] == "sandbox_not_found" + assert exc_info.value.context["reason"] == "_FakeNotFound" + assert exc_info.value.retryable is False + + +@pytest.mark.asyncio +async def test_e2b_exec_marks_rate_limit_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeRateLimit(Exception): + pass + + sandbox = _FakeE2BSandbox() + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + monkeypatch.setattr(e2b_module, "_e2b_retryable_error_types", lambda: (_FakeRateLimit,)) + monkeypatch.setattr(e2b_module, "_e2b_non_retryable_error_types", lambda: ()) + monkeypatch.setattr(e2b_module, "_e2b_timeout_error_types", lambda: ()) + + async def _raise_rate_limit(*args: object, **kwargs: object) -> object: + _ = (args, kwargs) + raise _FakeRateLimit("rate limit exceeded") + + monkeypatch.setattr(e2b_module, "_sandbox_run_command", _raise_rate_limit) + + with pytest.raises(ExecTransportError) as exc_info: + await session._exec_internal("python3", "build.py", timeout=2.0) # noqa: SLF001 + + assert exc_info.value.context["provider_error"] == "rate limit exceeded" + assert exc_info.value.context["reason"] == "_FakeRateLimit" + assert exc_info.value.retryable is True + + +@pytest.mark.asyncio +async def test_e2b_exec_marks_deterministic_provider_errors_non_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeGitAuth(Exception): + pass + + sandbox = _FakeE2BSandbox() + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + monkeypatch.setattr(e2b_module, "_e2b_retryable_error_types", lambda: ()) + monkeypatch.setattr(e2b_module, "_e2b_non_retryable_error_types", lambda: (_FakeGitAuth,)) + monkeypatch.setattr(e2b_module, "_e2b_timeout_error_types", lambda: ()) + + async def _raise_git_auth(*args: object, **kwargs: object) -> object: + _ = (args, kwargs) + raise _FakeGitAuth("git authentication failed") + + monkeypatch.setattr(e2b_module, "_sandbox_run_command", _raise_git_auth) + + with pytest.raises(ExecTransportError) as exc_info: + await session._exec_internal("python3", "build.py", timeout=2.0) # noqa: SLF001 + + assert exc_info.value.context["provider_error"] == "git authentication failed" + assert exc_info.value.context["reason"] == "_FakeGitAuth" + assert exc_info.value.retryable is False @pytest.mark.asyncio @@ -2212,20 +2288,22 @@ class ReadTimeout(Exception): @pytest.mark.asyncio -async def test_e2b_pty_start_maps_missing_sandbox_timeout_to_transport_error( +async def test_e2b_pty_start_maps_missing_sandbox_not_found_to_transport_error( monkeypatch: pytest.MonkeyPatch, ) -> None: - class _FakeTimeout(Exception): + class _FakeNotFound(Exception): pass monkeypatch.setattr( e2b_module, - "_import_e2b_exceptions", - lambda: {"timeout": _FakeTimeout}, + "_e2b_non_retryable_error_types", + lambda: (_FakeNotFound,), ) + monkeypatch.setattr(e2b_module, "_e2b_retryable_error_types", lambda: ()) + monkeypatch.setattr(e2b_module, "_e2b_timeout_error_types", lambda: ()) sandbox = _FakeE2BSandbox() - sandbox.pty.create_error = _FakeTimeout("The sandbox was not found: request failed") + sandbox.pty.create_error = _FakeNotFound("The sandbox was not found: request failed") state = E2BSandboxSessionState( session_id=uuid.uuid4(), manifest=Manifest(root="/workspace"), @@ -2239,4 +2317,5 @@ class _FakeTimeout(Exception): await session.pty_exec_start("python3", shell=False, tty=True, timeout=2.0) assert exc_info.value.context["provider_error"] == "The sandbox was not found: request failed" - assert exc_info.value.context["reason"] == "sandbox_not_found" + assert exc_info.value.context["reason"] == "_FakeNotFound" + assert exc_info.value.retryable is False diff --git a/tests/extensions/sandbox/test_modal.py b/tests/extensions/sandbox/test_modal.py index b57e3bfa0c..48c3f5b905 100644 --- a/tests/extensions/sandbox/test_modal.py +++ b/tests/extensions/sandbox/test_modal.py @@ -318,6 +318,24 @@ def override_locally(key: str, value: str) -> None: _FakeConfig.override_calls.append((key, value)) os.environ["MODAL_" + key.upper()] = value + class _FakeModalError(Exception): + pass + + class _FakeModalConnectionError(_FakeModalError): + pass + + class _FakeModalExecTimeoutError(TimeoutError): + pass + + class _FakeModalInternalFailure(_FakeModalError): + pass + + class _FakeModalInvalidError(_FakeModalError): + pass + + class _FakeModalNotFoundError(_FakeModalError): + pass + _FakeSandbox.create = staticmethod(_with_aio(_FakeSandbox._create)) _FakeSandbox.from_id = staticmethod(_with_aio(_FakeSandbox._from_id)) _FakeApp.lookup = staticmethod(_with_aio(_FakeApp._lookup)) @@ -329,6 +347,14 @@ def override_locally(key: str, value: str) -> None: fake_modal.Secret = _FakeSecret fake_modal.CloudBucketMount = _FakeCloudBucketMount + fake_modal_exception: Any = types.ModuleType("modal.exception") + fake_modal_exception.ConnectionError = _FakeModalConnectionError + fake_modal_exception.ExecTimeoutError = _FakeModalExecTimeoutError + fake_modal_exception.InternalFailure = _FakeModalInternalFailure + fake_modal_exception.InvalidError = _FakeModalInvalidError + fake_modal_exception.NotFoundError = _FakeModalNotFoundError + fake_modal.exception = fake_modal_exception + fake_modal_config: Any = types.ModuleType("modal.config") fake_modal_config.config = _FakeConfig @@ -336,6 +362,7 @@ def override_locally(key: str, value: str) -> None: fake_container_process.ContainerProcess = object monkeypatch.setitem(sys.modules, "modal", fake_modal) + monkeypatch.setitem(sys.modules, "modal.exception", fake_modal_exception) monkeypatch.setitem(sys.modules, "modal.config", fake_modal_config) monkeypatch.setitem(sys.modules, "modal.container_process", fake_container_process) sys.modules.pop("agents.extensions.sandbox.modal.sandbox", None) @@ -3298,6 +3325,72 @@ def _exec(self, *command: object, **kwargs: object) -> object: assert exc_info.value.context["provider_error"] == "FileNotFoundError: missing-shell" +@pytest.mark.asyncio +async def test_modal_pty_start_marks_typed_not_found_non_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FailingSandbox: + object_id = "sb-fail" + + def __init__(self) -> None: + self.exec = _with_aio(self._exec) + + def _exec(self, *command: object, **kwargs: object) -> object: + _ = (command, kwargs) + raise modal_module.modal.exception.NotFoundError("sandbox not found") + + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id="sb-fail", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=_FailingSandbox()) + + with pytest.raises(modal_module.ExecTransportError) as exc_info: + await session.pty_exec_start("python3", shell=False, tty=True) + + assert exc_info.value.retryable is False + assert exc_info.value.context["backend"] == "modal" + assert exc_info.value.context["reason"] == "_FakeModalNotFoundError" + assert exc_info.value.context["provider_error"] == "_FakeModalNotFoundError: sandbox not found" + + +@pytest.mark.asyncio +async def test_modal_pty_start_marks_typed_internal_failure_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _FailingSandbox: + object_id = "sb-fail" + + def __init__(self) -> None: + self.exec = _with_aio(self._exec) + + def _exec(self, *command: object, **kwargs: object) -> object: + _ = (command, kwargs) + raise modal_module.modal.exception.InternalFailure("internal failure") + + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id="sb-fail", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=_FailingSandbox()) + + with pytest.raises(modal_module.ExecTransportError) as exc_info: + await session.pty_exec_start("python3", shell=False, tty=True) + + assert exc_info.value.retryable is True + assert exc_info.value.context["backend"] == "modal" + assert exc_info.value.context["reason"] == "_FakeModalInternalFailure" + assert exc_info.value.context["provider_error"] == "_FakeModalInternalFailure: internal failure" + + @pytest.mark.asyncio async def test_modal_start_wraps_exec_details( monkeypatch: pytest.MonkeyPatch, @@ -3360,6 +3453,36 @@ def _exec(self, *command: object, **kwargs: object) -> object: await session.pty_exec_start("python3", shell=False, tty=True, timeout=2.0) +@pytest.mark.asyncio +async def test_modal_pty_start_maps_modal_exec_timeout_failures( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + class _TimeoutSandbox: + object_id = "sb-timeout" + + def __init__(self) -> None: + self.exec = _with_aio(self._exec) + + def _exec(self, *command: object, **kwargs: object) -> object: + _ = (command, kwargs) + raise modal_module.modal.exception.ExecTimeoutError("command timed out") + + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + sandbox_id="sb-timeout", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=_TimeoutSandbox()) + + with pytest.raises(modal_module.ExecTimeoutError) as exc_info: + await session.pty_exec_start("python3", shell=False, tty=True, timeout=2.0) + + assert exc_info.value.retryable is False + + @pytest.mark.asyncio async def test_modal_pty_start_cleans_up_unregistered_process_on_cancellation( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/extensions/sandbox/test_runloop.py b/tests/extensions/sandbox/test_runloop.py index b28965a43c..ec6999d252 100644 --- a/tests/extensions/sandbox/test_runloop.py +++ b/tests/extensions/sandbox/test_runloop.py @@ -161,6 +161,42 @@ def __init__( ) +class _FakeAuthenticationError(_FakeAPIStatusError): + def __init__( + self, + message: str = "authentication failed", + *, + body: object | None = None, + url: str = "https://api.runloop.ai/v1/test", + method: str = "POST", + ) -> None: + super().__init__(401, body=body, url=url, method=method, message=message) + + +class _FakeBadRequestError(_FakeAPIStatusError): + def __init__( + self, + message: str = "bad request", + *, + body: object | None = None, + url: str = "https://api.runloop.ai/v1/test", + method: str = "POST", + ) -> None: + super().__init__(400, body=body, url=url, method=method, message=message) + + +class _FakeInternalServerError(_FakeAPIStatusError): + def __init__( + self, + message: str = "internal server error", + *, + body: object | None = None, + url: str = "https://api.runloop.ai/v1/test", + method: str = "POST", + ) -> None: + super().__init__(500, body=body, url=url, method=method, message=message) + + class _FakeNotFoundError(_FakeAPIStatusError): def __init__( self, @@ -173,6 +209,42 @@ def __init__( super().__init__(404, body=body, url=url, method=method, message=message) +class _FakePermissionDeniedError(_FakeAPIStatusError): + def __init__( + self, + message: str = "permission denied", + *, + body: object | None = None, + url: str = "https://api.runloop.ai/v1/test", + method: str = "POST", + ) -> None: + super().__init__(403, body=body, url=url, method=method, message=message) + + +class _FakeRateLimitError(_FakeAPIStatusError): + def __init__( + self, + message: str = "rate limited", + *, + body: object | None = None, + url: str = "https://api.runloop.ai/v1/test", + method: str = "POST", + ) -> None: + super().__init__(429, body=body, url=url, method=method, message=message) + + +class _FakeUnprocessableEntityError(_FakeAPIStatusError): + def __init__( + self, + message: str = "unprocessable entity", + *, + body: object | None = None, + url: str = "https://api.runloop.ai/v1/test", + method: str = "POST", + ) -> None: + super().__init__(422, body=body, url=url, method=method, message=message) + + class _FakeExecutionResult: def __init__(self, *, stdout: str = "", stderr: str = "", exit_code: int | None = 0) -> None: self._stdout = stdout @@ -1165,8 +1237,14 @@ def _load_runloop_module(monkeypatch: pytest.MonkeyPatch) -> Any: fake_runloop.APIResponseValidationError = _FakeAPIResponseValidationError fake_runloop.APITimeoutError = _FakeAPITimeoutError fake_runloop.APIStatusError = _FakeAPIStatusError + fake_runloop.AuthenticationError = _FakeAuthenticationError + fake_runloop.BadRequestError = _FakeBadRequestError + fake_runloop.InternalServerError = _FakeInternalServerError fake_runloop.NotFoundError = _FakeNotFoundError + fake_runloop.PermissionDeniedError = _FakePermissionDeniedError + fake_runloop.RateLimitError = _FakeRateLimitError fake_runloop.RunloopError = _FakeRunloopError + fake_runloop.UnprocessableEntityError = _FakeUnprocessableEntityError fake_sdk: Any = types.ModuleType("runloop_api_client.sdk") fake_sdk.AsyncRunloopSDK = _FakeAsyncRunloopSDK @@ -2338,6 +2416,66 @@ async def _raise_rate_limit(*args: object, **kwargs: object) -> object: assert exc_info.value.context["cause_type"] == "_FakeAPIStatusError" assert exc_info.value.context["provider_body"] == {"error": "rate limited"} assert exc_info.value.context["detail"] == "exec_failed" + assert exc_info.value.retryable is True + + @pytest.mark.asyncio + async def test_exec_marks_typed_runloop_bad_request_non_retryable( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + runloop_module = _load_runloop_module(monkeypatch) + + async with runloop_module.RunloopSandboxClient() as client: + session = await client.create(options=runloop_module.RunloopSandboxClientOptions()) + await session.start() + sdk = _FakeAsyncRunloopSDK.created_instances[-1] + devbox = sdk.devbox.devboxes[session.state.devbox_id] + + async def _raise_bad_request(*args: object, **kwargs: object) -> object: + _ = (args, kwargs) + raise _FakeBadRequestError( + body={"error": "invalid command"}, + url=f"https://api.runloop.ai/v1/devboxes/{devbox.id}/execute", + method="POST", + ) + + monkeypatch.setattr(devbox.cmd, "exec", _raise_bad_request) + + with pytest.raises(runloop_module.ExecTransportError) as exc_info: + await session.exec("pwd", shell=False) + + assert exc_info.value.context["http_status"] == 400 + assert exc_info.value.context["cause_type"] == "_FakeBadRequestError" + assert exc_info.value.context["provider_body"] == {"error": "invalid command"} + assert exc_info.value.context["detail"] == "exec_failed" + assert exc_info.value.retryable is False + + @pytest.mark.parametrize( + ("status", "expected_retryable"), + [ + (400, False), + (401, False), + (403, False), + (404, False), + (408, True), + (422, False), + (429, True), + (500, True), + (502, True), + (503, True), + (504, True), + ], + ) + def test_runloop_retryability_status_table( + self, + monkeypatch: pytest.MonkeyPatch, + status: int, + expected_retryable: bool, + ) -> None: + runloop_module = _load_runloop_module(monkeypatch) + error = _FakeAPIStatusError(status, body={"error": f"HTTP {status}"}) + + assert runloop_module._runloop_provider_retryability(error) is expected_retryable @pytest.mark.asyncio async def test_exec_wraps_command_with_workspace_context( @@ -2486,6 +2624,7 @@ async def _raise_download_error(**kwargs: object) -> bytes: assert exc_info.value.context["cause_type"] == "_FakeAPIStatusError" assert exc_info.value.context["provider_body"] == {"error": "download failed"} assert exc_info.value.context["detail"] == "file_download_failed" + assert exc_info.value.retryable is True @pytest.mark.asyncio async def test_write_wraps_runloop_http_error_with_provider_context( @@ -2518,6 +2657,7 @@ async def _raise_upload_error(**kwargs: object) -> object: assert exc_info.value.context["cause_type"] == "_FakeAPIStatusError" assert exc_info.value.context["provider_body"] == {"error": "upload rate limited"} assert exc_info.value.context["detail"] == "file_upload_failed" + assert exc_info.value.retryable is True @pytest.mark.asyncio async def test_manifest_apply_preserves_existing_files_in_non_empty_directory( diff --git a/tests/extensions/sandbox/test_vercel.py b/tests/extensions/sandbox/test_vercel.py index 71c4130b4c..0430c23d13 100644 --- a/tests/extensions/sandbox/test_vercel.py +++ b/tests/extensions/sandbox/test_vercel.py @@ -55,6 +55,48 @@ class SnapshotSource(BaseModel): snapshot_id: str +class _FakeVercelSandboxError(Exception): + pass + + +class _FakeVercelAPIError(_FakeVercelSandboxError): + def __init__(self, message: str, *, status_code: int, data: object | None = None) -> None: + super().__init__(message) + self.status_code = status_code + self.response = types.SimpleNamespace(status_code=status_code) + self.data = data + + +class _FakeVercelSandboxAuthError(_FakeVercelAPIError): + def __init__(self, message: str = "auth failed", *, data: object | None = None) -> None: + super().__init__(message, status_code=401, data=data) + + +class _FakeVercelSandboxNotFoundError(_FakeVercelAPIError): + def __init__(self, message: str = "not found", *, data: object | None = None) -> None: + super().__init__(message, status_code=404, data=data) + + +class _FakeVercelSandboxPermissionError(_FakeVercelAPIError): + def __init__(self, message: str = "permission denied", *, data: object | None = None) -> None: + super().__init__(message, status_code=403, data=data) + + +class _FakeVercelSandboxRateLimitError(_FakeVercelAPIError): + def __init__(self, message: str = "rate limited", *, data: object | None = None) -> None: + super().__init__(message, status_code=429, data=data) + + +class _FakeVercelSandboxServerError(_FakeVercelAPIError): + def __init__(self, message: str = "server error", *, data: object | None = None) -> None: + super().__init__(message, status_code=500, data=data) + + +class _FakeVercelSandboxValidationError(_FakeVercelSandboxError): + def __init__(self, message: str = "validation failed") -> None: + super().__init__(message) + + class _MemorySnapshot(SnapshotBase): type: Literal["test-vercel-memory"] = "test-vercel-memory" payload: bytes = b"" @@ -371,8 +413,15 @@ def _load_vercel_module(monkeypatch: pytest.MonkeyPatch) -> Any: fake_vercel_sandbox.NetworkPolicyRule = NetworkPolicyRule fake_vercel_sandbox.NetworkPolicySubnets = NetworkPolicySubnets fake_vercel_sandbox.Resources = Resources + fake_vercel_sandbox.SandboxAuthError = _FakeVercelSandboxAuthError + fake_vercel_sandbox.SandboxNotFoundError = _FakeVercelSandboxNotFoundError + fake_vercel_sandbox.SandboxPermissionError = _FakeVercelSandboxPermissionError + fake_vercel_sandbox.SandboxRateLimitError = _FakeVercelSandboxRateLimitError + fake_vercel_sandbox.SandboxServerError = _FakeVercelSandboxServerError fake_vercel_sandbox.SandboxStatus = types.SimpleNamespace(RUNNING="running") + fake_vercel_sandbox.SandboxValidationError = _FakeVercelSandboxValidationError fake_vercel_sandbox.SnapshotSource = SnapshotSource + cast(Any, fake_vercel).sandbox = fake_vercel_sandbox monkeypatch.setitem(sys.modules, "vercel", fake_vercel) monkeypatch.setitem(sys.modules, "vercel.sandbox", fake_vercel_sandbox) @@ -541,6 +590,112 @@ async def test_vercel_exec_read_write_and_port_resolution(monkeypatch: pytest.Mo assert payload.read() == b"payload" +@pytest.mark.asyncio +async def test_vercel_exec_marks_typed_not_found_non_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + vercel_module = _load_vercel_module(monkeypatch) + state = vercel_module.VercelSandboxSessionState( + session_id="00000000-0000-0000-0000-000000000120", + manifest=Manifest(), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id="sandbox-exec-missing", + ) + sandbox = _FakeAsyncSandbox(sandbox_id="sandbox-exec-missing") + session = vercel_module.VercelSandboxSession.from_state(state, sandbox=sandbox) + + async def _raise_not_found(*args: object, **kwargs: object) -> object: + _ = (args, kwargs) + raise vercel_module.vercel_sandbox.SandboxNotFoundError("sandbox missing") + + monkeypatch.setattr(sandbox, "run_command", _raise_not_found) + + with pytest.raises(vercel_module.ExecTransportError) as exc_info: + await session.exec("pwd", shell=False) + + assert exc_info.value.retryable is False + + +@pytest.mark.asyncio +async def test_vercel_exec_marks_typed_rate_limit_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + vercel_module = _load_vercel_module(monkeypatch) + state = vercel_module.VercelSandboxSessionState( + session_id="00000000-0000-0000-0000-000000000121", + manifest=Manifest(), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id="sandbox-exec-rate-limit", + ) + sandbox = _FakeAsyncSandbox(sandbox_id="sandbox-exec-rate-limit") + session = vercel_module.VercelSandboxSession.from_state(state, sandbox=sandbox) + + async def _raise_rate_limit(*args: object, **kwargs: object) -> object: + _ = (args, kwargs) + raise vercel_module.vercel_sandbox.SandboxRateLimitError("rate limited") + + monkeypatch.setattr(sandbox, "run_command", _raise_rate_limit) + + with pytest.raises(vercel_module.ExecTransportError) as exc_info: + await session.exec("pwd", shell=False) + + assert exc_info.value.retryable is True + + +@pytest.mark.asyncio +async def test_vercel_write_marks_typed_validation_error_non_retryable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + vercel_module = _load_vercel_module(monkeypatch) + monkeypatch.setattr("agents.sandbox.util.retry.asyncio.sleep", _noop_sleep) + + state = vercel_module.VercelSandboxSessionState( + session_id="00000000-0000-0000-0000-000000000122", + manifest=Manifest(), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id="sandbox-write-validation", + ) + sandbox = _FakeAsyncSandbox(sandbox_id="sandbox-write-validation") + sandbox.write_failures = [vercel_module.vercel_sandbox.SandboxValidationError("invalid write")] + session = vercel_module.VercelSandboxSession.from_state(state, sandbox=sandbox) + + with pytest.raises(vercel_module.WorkspaceArchiveWriteError) as exc_info: + await session.write(Path("hello.txt"), io.BytesIO(b"world")) + + assert len(sandbox.write_files_calls) == 1 + assert exc_info.value.retryable is False + + +@pytest.mark.parametrize( + ("status", "expected_retryable"), + [ + (400, False), + (401, False), + (403, False), + (404, False), + (408, True), + (425, True), + (422, False), + (429, True), + (500, True), + (502, True), + (503, True), + (504, True), + ], +) +def test_vercel_retryability_status_table( + monkeypatch: pytest.MonkeyPatch, + status: int, + expected_retryable: bool, +) -> None: + vercel_module = _load_vercel_module(monkeypatch) + + class FakeStatusError(Exception): + status_code = status + + assert vercel_module._vercel_provider_retryability(FakeStatusError()) is expected_retryable + + @pytest.mark.asyncio async def test_vercel_start_uses_base_session_contract_and_materializes_workspace( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/sandbox/test_errors.py b/tests/sandbox/test_errors.py new file mode 100644 index 0000000000..5ce3635c54 --- /dev/null +++ b/tests/sandbox/test_errors.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from pathlib import Path + +from agents.sandbox.errors import ( + ErrorCode, + ExecTimeoutError, + GitCloneError, + GitCopyError, + SandboxError, + SnapshotPersistError, + SnapshotRestoreError, + WorkspaceArchiveReadError, + WorkspaceReadNotFoundError, + WorkspaceStopError, + WorkspaceWriteTypeError, +) + + +def test_sandbox_error_retryable_can_be_set_explicitly() -> None: + error = SandboxError( + message="backend is unavailable", + error_code=ErrorCode.EXEC_TRANSPORT_ERROR, + op="exec", + context={}, + retryable=True, + ) + + assert error.retryable is True + + +def test_wrapped_sandbox_error_inherits_retryable_from_cause() -> None: + cause = WorkspaceArchiveReadError( + path=Path("/workspace"), + retryable=False, + ) + + error = WorkspaceStopError(path=Path("/workspace"), cause=cause) + + assert error.retryable is False + + +def test_deterministic_sandbox_errors_are_non_retryable() -> None: + assert WorkspaceReadNotFoundError(path=Path("/workspace/missing.txt")).retryable is False + assert ( + WorkspaceWriteTypeError(path=Path("/workspace/out.txt"), actual_type="str").retryable + is False + ) + assert ExecTimeoutError(command=("python", "script.py"), timeout_s=1.0).retryable is False + + +def test_broad_archive_errors_default_to_unknown_retryability() -> None: + error = WorkspaceArchiveReadError(path=Path("/workspace")) + + assert error.retryable is None + + +def test_broad_materialization_and_snapshot_errors_default_to_unknown_retryability() -> None: + assert GitCloneError(url="https://example.test/repo.git", ref="main").retryable is None + assert GitCopyError(src_root="/tmp/repo", dest=Path("/workspace")).retryable is None + assert SnapshotPersistError(snapshot_id="snap", path=Path("/tmp/snap")).retryable is None + assert SnapshotRestoreError(snapshot_id="snap", path=Path("/tmp/snap")).retryable is None diff --git a/tests/sandbox/test_session_sinks.py b/tests/sandbox/test_session_sinks.py index 6c58a76c30..420f142f3e 100644 --- a/tests/sandbox/test_session_sinks.py +++ b/tests/sandbox/test_session_sinks.py @@ -11,6 +11,7 @@ from inline_snapshot import snapshot from agents.sandbox.entries import Dir, File +from agents.sandbox.errors import WorkspaceReadNotFoundError from agents.sandbox.manifest import Manifest from agents.sandbox.sandboxes.unix_local import ( UnixLocalSandboxSession, @@ -31,7 +32,7 @@ from agents.sandbox.session.base_sandbox_session import BaseSandboxSession from agents.sandbox.snapshot import LocalSnapshot from agents.tracing import custom_span, trace -from tests.testing_processor import fetch_normalized_spans +from tests.testing_processor import fetch_normalized_spans, fetch_ordered_spans def _build_unix_local_session( @@ -362,6 +363,45 @@ def _callback(event: SandboxSessionEvent, session: BaseSandboxSession) -> None: assert all(session is inner for _op, session in seen) +@pytest.mark.asyncio +async def test_sandbox_session_error_events_and_traces_include_retryability( + tmp_path: Path, +) -> None: + events: list[SandboxSessionEvent] = [] + instrumentation = Instrumentation( + sinks=[CallbackSink(lambda e, _sess: events.append(e), mode="sync")] + ) + inner = _build_unix_local_session(tmp_path) + + with trace("sandbox_retryability_test"): + async with SandboxSession(inner, instrumentation=instrumentation) as session: + with pytest.raises(WorkspaceReadNotFoundError): + await session.read(Path("missing.txt")) + + read_finish = [event for event in events if event.op == "read" and event.phase == "finish"][0] + assert isinstance(read_finish, SandboxSessionFinishEvent) + assert read_finish.error_retryable is False + + spans = fetch_normalized_spans() + read_span = next( + child for child in spans[0]["children"] if child["data"]["name"] == "sandbox.read" + ) + span_data = read_span["data"] + assert isinstance(span_data, dict) + span_payload = span_data["data"] + assert isinstance(span_payload, dict) + assert span_payload["error_retryable"] is False + + raw_read_span = next( + span for span in fetch_ordered_spans() if span.span_data.export()["name"] == "sandbox.read" + ) + span_error = raw_read_span.error + assert span_error is not None + error_payload = span_error["data"] + assert isinstance(error_payload, dict) + assert error_payload["error_retryable"] is False + + @pytest.mark.asyncio async def test_sandbox_session_ops_nest_under_sdk_trace_and_events_carry_trace_ids( tmp_path: Path, From cbd9b62d4325ae92dce3e002799c8e2b9a58f337 Mon Sep 17 00:00:00 2001 From: Qiyao Qin Date: Thu, 4 Jun 2026 18:11:03 -0700 Subject: [PATCH 2/2] Fix Modal tar persist retry classification --- .../extensions/sandbox/modal/sandbox.py | 20 +++-- tests/extensions/sandbox/test_modal.py | 81 +++++++++++++++++++ 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/src/agents/extensions/sandbox/modal/sandbox.py b/src/agents/extensions/sandbox/modal/sandbox.py index 3468ae4b00..804a745aee 100644 --- a/src/agents/extensions/sandbox/modal/sandbox.py +++ b/src/agents/extensions/sandbox/modal/sandbox.py @@ -40,6 +40,7 @@ ExecTransportError, ExposedPortUnavailableError, MountConfigError, + SandboxError, WorkspaceArchiveReadError, WorkspaceArchiveWriteError, WorkspaceReadNotFoundError, @@ -187,6 +188,17 @@ def _modal_provider_retryability(error: BaseException) -> tuple[bool | None, str return None, None +def _modal_tar_persist_retryable(exc: BaseException) -> bool: + for candidate in iter_exception_chain(exc): + if isinstance(candidate, SandboxError) and candidate.retryable is False: + return False + + if exception_chain_contains_type(exc, (ExecTransportError,)): + return True + + return exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) + + def _modal_exec_transport_error( *, command: tuple[str | Path, ...], @@ -1629,13 +1641,7 @@ def _modal_tar_skip_relpaths(self, root: Path) -> set[Path]: continue return skip - @retry_async( - retry_if=lambda exc, self: ( - isinstance(exc, ExecTransportError) - and exc.retryable is not False - or exception_chain_has_status_code(exc, TRANSIENT_HTTP_STATUS_CODES) - ) - ) + @retry_async(retry_if=lambda exc, self: _modal_tar_persist_retryable(exc)) async def _persist_workspace_via_tar(self) -> io.IOBase: # Existing tar implementation extracted so snapshot_filesystem mode can fall back cleanly. root = self._workspace_root_path() diff --git a/tests/extensions/sandbox/test_modal.py b/tests/extensions/sandbox/test_modal.py index 48c3f5b905..bde186d9aa 100644 --- a/tests/extensions/sandbox/test_modal.py +++ b/tests/extensions/sandbox/test_modal.py @@ -2542,6 +2542,87 @@ async def _fake_exec( ] +@pytest.mark.asyncio +async def test_modal_tar_persist_retries_wrapped_exec_transport_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=None) + commands: list[list[str]] = [] + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + commands.append(rendered) + if len(commands) == 1: + raise modal_module.ExecTransportError( + command=tuple(rendered), + message="modal transport failed", + ) + return ExecResult(stdout=b"tar-bytes", stderr=b"", exit_code=0) + + monkeypatch.setattr(session, "exec", _fake_exec) + + archive = await session.persist_workspace() + + assert archive.read() == b"tar-bytes" + assert commands == [ + ["tar", "cf", "-", "-C", "/workspace", "."], + ["tar", "cf", "-", "-C", "/workspace", "."], + ] + + +@pytest.mark.asyncio +async def test_modal_tar_persist_does_not_retry_wrapped_non_retryable_exec_transport_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + modal_module, _create_calls, _registry_tags = _load_modal_module(monkeypatch) + + state = modal_module.ModalSandboxSessionState( + manifest=Manifest(root="/workspace"), + snapshot=modal_module.resolve_snapshot(None, "snapshot"), + app_name="sandbox-tests", + ) + session = modal_module.ModalSandboxSession.from_state(state, sandbox=None) + commands: list[list[str]] = [] + + async def _fake_exec( + *command: object, + timeout: float | None = None, + shell: bool | list[str] = True, + user: object | None = None, + ) -> ExecResult: + _ = (timeout, shell, user) + rendered = [str(part) for part in command] + commands.append(rendered) + raise modal_module.ExecTransportError( + command=tuple(rendered), + message="modal transport failed permanently", + retryable=False, + ) + + monkeypatch.setattr(session, "exec", _fake_exec) + + with pytest.raises(WorkspaceArchiveReadError) as exc_info: + await session.persist_workspace() + + assert str(exc_info.value) == "failed to read archive for path: /workspace" + assert isinstance(exc_info.value.cause, modal_module.ExecTransportError) + assert str(exc_info.value.cause) == "modal transport failed permanently" + assert commands == [["tar", "cf", "-", "-C", "/workspace", "."]] + + @pytest.mark.asyncio async def test_modal_snapshot_filesystem_rejects_escaping_mount_paths_before_exec( monkeypatch: pytest.MonkeyPatch,