diff --git a/src/livepeer_gateway/__init__.py b/src/livepeer_gateway/__init__.py index 3baf9f5..0145d8e 100644 --- a/src/livepeer_gateway/__init__.py +++ b/src/livepeer_gateway/__init__.py @@ -43,6 +43,7 @@ from .lv2v import LiveVideoToVideo, StartJobRequest, start_lv2v from .live_runner import ( LiveRunnerCallResult, + LiveRunnerCallStream, LiveRunnerGPU, LiveRunnerInstance, LiveRunnerPriceInfo, @@ -96,6 +97,7 @@ "get_orch_info", "LiveVideoToVideo", "LiveRunnerCallResult", + "LiveRunnerCallStream", "LiveRunnerGPU", "LiveRunnerInstance", "LiveRunnerPriceInfo", diff --git a/src/livepeer_gateway/http.py b/src/livepeer_gateway/http.py index 0b3566d..043d88a 100644 --- a/src/livepeer_gateway/http.py +++ b/src/livepeer_gateway/http.py @@ -300,6 +300,46 @@ async def request_json( return data +async def open_stream( + url: str, + *, + method: Optional[str] = None, + payload: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + connect_timeout: float = 10.0, +) -> "tuple[aiohttp.ClientSession, aiohttp.ClientResponse]": + """ + Open an HTTP request and return the live (session, response) without reading the + body, for streaming responses (SSE, chunked). The caller owns both and must close + them. + + No total timeout (streams run indefinitely) only connect/first-byte are bounded. + Raises LivepeerHTTPError on >= 400 (e.g. the 402 payment retry). + """ + resolved_method, req_headers, body = _json_request_parts( + url, + method=method, + payload=payload, + headers=headers, + ) + + timeout = aiohttp.ClientTimeout(total=None, sock_connect=connect_timeout, sock_read=None) + session = aiohttp.ClientSession(timeout=timeout, connector=aiohttp.TCPConnector(ssl=False)) + try: + resp = await session.request(resolved_method, url, data=body, headers=req_headers) + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + await session.close() + raise LivepeerGatewayError( + f"HTTP stream error: failed to reach endpoint: {getattr(e, 'message', e)} (url={url})" + ) from e + if resp.status >= 400: + raw = await resp.text() + resp.release() + await session.close() + _raise_http_json_error(resp.status, url, raw, dict(resp.headers.items())) + return session, resp + + async def post_json( url: str, payload: dict[str, Any], diff --git a/src/livepeer_gateway/live_runner.py b/src/livepeer_gateway/live_runner.py index 5a38e2e..f9a1a27 100644 --- a/src/livepeer_gateway/live_runner.py +++ b/src/livepeer_gateway/live_runner.py @@ -9,14 +9,27 @@ import shutil import subprocess from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Literal, NotRequired, Optional, Protocol, TypedDict, cast +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Literal, + Mapping, + NotRequired, + Optional, + Protocol, + TypedDict, + cast, + overload, +) from urllib.parse import quote, urlparse, urlunparse import aiohttp from .channel_reader import ChannelReader from .errors import LivepeerGatewayError, LivepeerHTTPError, SignerRefreshRequired -from .http import post_json, request_json +from .http import open_stream, post_json, request_json from .remote_signer import ( GetPaymentResponse, LivePaymentSession, @@ -104,6 +117,46 @@ class LiveRunnerCallResult: ) +@dataclass +class LiveRunnerCallStream: + """A streaming live-runner response (SSE / chunked). + + Returned by ``call_runner(..., stream=True)``. It owns the underlying HTTP session + and response, so use it as an async context manager (or call ``aclose()``) to + release the connection. + """ + + status: int + headers: Mapping[str, str] + runner_url: str + runner: Optional[LiveRunnerInstance] + payment_session: Optional[LivePaymentSession] + _session: aiohttp.ClientSession = field(repr=False, compare=False) + _response: aiohttp.ClientResponse = field(repr=False, compare=False) + + @property + def content_type(self) -> str: + return self._response.content_type + + async def aiter_bytes(self) -> AsyncIterator[bytes]: + async for chunk in self._response.content.iter_any(): + yield chunk + + async def aiter_lines(self) -> AsyncIterator[str]: + async for line in self._response.content: + yield line.decode(errors="replace").rstrip("\n") + + async def aclose(self) -> None: + self._response.release() + await self._session.close() + + async def __aenter__(self) -> LiveRunnerCallStream: + return self + + async def __aexit__(self, *exc: object) -> None: + await self.aclose() + + @dataclass(frozen=True) class LiveRunnerGPU: id: str = "" @@ -185,7 +238,7 @@ def __init__( self._task: Optional[asyncio.Task[None]] = None self._o2r_task: Optional[asyncio.Task[None]] = None - async def start(self) -> "LiveRunnerRegistration": + async def start(self) -> LiveRunnerRegistration: await self._send_heartbeat() self._task = asyncio.create_task(self._heartbeat_loop()) return self @@ -233,7 +286,7 @@ async def close(self) -> None: except Exception: _LOG.debug("Live runner unregister failed", exc_info=True) - async def __aenter__(self) -> "LiveRunnerRegistration": + async def __aenter__(self) -> LiveRunnerRegistration: return self async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: @@ -554,6 +607,36 @@ async def remove_trickle_channels( return deleted +@overload +async def call_runner( + runner_url: str = ..., + *, + runner: Optional[LiveRunnerInstance] = ..., + payload: Optional[dict[str, Any]] = ..., + method: str = ..., + signer_url: Optional[str] = ..., + signer_headers: Optional[dict[str, str]] = ..., + timeout: float = ..., + max_payment_challenge_retries: int = ..., + stream: Literal[False] = False, +) -> LiveRunnerCallResult: ... + + +@overload +async def call_runner( + runner_url: str = ..., + *, + runner: Optional[LiveRunnerInstance] = ..., + payload: Optional[dict[str, Any]] = ..., + method: str = ..., + signer_url: Optional[str] = ..., + signer_headers: Optional[dict[str, str]] = ..., + timeout: float = ..., + max_payment_challenge_retries: int = ..., + stream: Literal[True], +) -> LiveRunnerCallStream: ... + + async def call_runner( runner_url: str = "", *, @@ -564,7 +647,14 @@ async def call_runner( signer_headers: Optional[dict[str, str]] = None, timeout: float = 5.0, max_payment_challenge_retries: int = 3, -) -> LiveRunnerCallResult: + stream: bool = False, +) -> LiveRunnerCallResult | LiveRunnerCallStream: + """Call a runner once and return its result (or a live stream if ``stream=True``). + + With ``signer_url`` set, payment is automatic and **per call**: a 402 challenge is + paid via the signer and retried (up to ``max_payment_challenge_retries``), one job, + one upfront payment. Raises ``LivepeerHTTPError`` on non-402 errors. + """ runner_url = runner_url.strip() or (runner.url.strip() if runner is not None else "") if not runner_url: raise LivepeerGatewayError("Live runner call requires runner_url") @@ -607,6 +697,20 @@ async def call_runner( request_kwargs: dict[str, Any] = {"timeout": timeout} if request_headers: request_kwargs["headers"] = request_headers + + if stream: + # Hand back the live response unbuffered. open_stream raises on a 402 + # before any body, so the payment retry below still catches it. + session, resp = await open_stream( + runner_url, + method=method, + payload=request_payload, + headers=request_headers or None, + ) + return LiveRunnerCallStream( + resp.status, resp.headers, runner_url, runner, payment_session, session, resp, + ) + data = await request_json( runner_url, method=method,