From 88e845edcd665d2fa12ad6955bbb203ca4fa0d4b Mon Sep 17 00:00:00 2001 From: Rasmus Faber-Espensen Date: Mon, 25 May 2026 23:04:06 +0200 Subject: [PATCH] Add bounded transcript checkpoint event store --- .../_display/textual/widgets/transcript.py | 6 +- src/inspect_ai/_eval/score.py | 2 +- src/inspect_ai/_eval/task/run.py | 63 +- src/inspect_ai/agent/_acp/transport_live.py | 6 +- .../log/_recorders/buffer/database.py | 83 +- .../log/_recorders/buffer/filestore.py | 6 + .../log/_recorders/buffer/history_provider.py | 119 +++ src/inspect_ai/log/_recorders/buffer/types.py | 9 + src/inspect_ai/log/_recorders/streaming.py | 29 +- src/inspect_ai/log/_samples.py | 6 +- src/inspect_ai/log/_transcript.py | 530 +++++++++++-- .../util/_checkpoint/_event_store.py | 454 +++++++++++ .../util/_checkpoint/_layout/host_context.py | 23 - .../util/_checkpoint/checkpointer.py | 15 + .../util/_checkpoint/checkpointer_factory.py | 5 +- .../util/_checkpoint/checkpointer_impl.py | 377 ++++----- .../util/_checkpoint/checkpointer_noop.py | 3 + src/inspect_ai/util/_checkpoint/hydrate.py | 82 +- tests/_eval/test_retry_error_events.py | 275 +++++++ .../checkpoint/test_checkpoint_event_store.py | 447 +++++++++++ tests/checkpoint/test_checkpointer.py | 742 +++++++++++++++--- tests/display/test_textual_transcript.py | 99 +++ tests/log/test_transcript_bounded.py | 684 ++++++++++++++++ tests/test_helpers/transcript.py | 52 ++ 24 files changed, 3626 insertions(+), 491 deletions(-) create mode 100644 src/inspect_ai/log/_recorders/buffer/history_provider.py create mode 100644 src/inspect_ai/util/_checkpoint/_event_store.py create mode 100644 tests/_eval/test_retry_error_events.py create mode 100644 tests/checkpoint/test_checkpoint_event_store.py create mode 100644 tests/display/test_textual_transcript.py create mode 100644 tests/log/test_transcript_bounded.py create mode 100644 tests/test_helpers/transcript.py diff --git a/src/inspect_ai/_display/textual/widgets/transcript.py b/src/inspect_ai/_display/textual/widgets/transcript.py index bcfacd5bb1..9edbbc3e1f 100644 --- a/src/inspect_ai/_display/textual/widgets/transcript.py +++ b/src/inspect_ai/_display/textual/widgets/transcript.py @@ -81,7 +81,7 @@ async def sync_sample(self, sample: ActiveSample | None) -> None: # if we have either a new sample or a new event count then proceed if ( sample.id != self._sample_id - or len(sample.transcript.events) != self._sample_events + or sample.transcript.event_count != self._sample_events ): # update (scrolling to end if we are already close to it) new_sample = sample.id != self._sample_id @@ -92,14 +92,14 @@ async def sync_sample(self, sample: ActiveSample | None) -> None: async with self.batch(): await self.remove_children() await self.mount_all( - self._widgets_for_events(sample.transcript.events) + self._widgets_for_events(sample.transcript.resident_events) ) if scroll_to_end: self.scroll_end(animate=not new_sample) # set members self._sample_id = sample.id - self._sample_events = len(sample.transcript.events) + self._sample_events = sample.transcript.event_count # if we aren't active then save as a pending sample else: diff --git a/src/inspect_ai/_eval/score.py b/src/inspect_ai/_eval/score.py index ecd5882e38..51f69f7dbd 100644 --- a/src/inspect_ai/_eval/score.py +++ b/src/inspect_ai/_eval/score.py @@ -371,7 +371,7 @@ async def _run_score_task( init_subtask_store(state.store) # load a copy of the current sample events into the transcript - init_transcript(Transcript([*sample.events], log_model_api=False)) + init_transcript(Transcript([*sample.events], log_model_api=False, bounded=False)) if state.scores is None: state.scores = {} diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index a62891d525..cf8c7ae149 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -75,14 +75,22 @@ EvalSampleSummary, eval_error, ) -from inspect_ai.log._recorders.streaming import materialize_streaming_sample +from inspect_ai.log._recorders.buffer.history_provider import ( + BufferTranscriptHistoryProvider, +) +from inspect_ai.log._recorders.streaming import ( + eval_retry_error_from_history, + materialize_streaming_sample, +) from inspect_ai.log._samples import ( active_sample, ) from inspect_ai.log._transcript import ( Transcript, + TranscriptHistoryProvider, init_transcript, transcript, + transcript_bounded_enabled, ) from inspect_ai.model import ( GenerateConfig, @@ -174,6 +182,18 @@ SAMPLE_TOTAL_PROGRESS_UNITS = 1 +def _sample_transcript_config( + logger: TaskLogger | None, sample_id: str | int, epoch: int +) -> tuple[bool, TranscriptHistoryProvider | None]: + if logger is not None and logger.buffer_db is not None: + return ( + transcript_bounded_enabled(), + BufferTranscriptHistoryProvider(logger.buffer_db, sample_id, epoch), + ) + else: + return False, None + + @dataclass class TaskRunOptions: task: Task @@ -881,7 +901,15 @@ def on_sample_event(event: Event) -> None: init_sample_model_usage() init_sample_role_usage() set_sample_state(state) - sample_transcript = Transcript(log_model_api=log_model_api) + sample_transcript_bounded, history_provider = _sample_transcript_config( + logger, sample_id, state.epoch + ) + sample_transcript = Transcript( + log_model_api=log_model_api, + bounded=sample_transcript_bounded, + resident_tail=100, + history_provider=history_provider, + ) init_transcript(sample_transcript) init_subtask_store(state.store) sample_transcript.subscribe(on_sample_event) @@ -1430,6 +1458,8 @@ def make_eval_sample(include_events: bool = True) -> EvalSample: ): await emit_attempt_end(will_retry=True) + retry_error = _eval_retry_error(error, logger, state.sample_id, state.epoch) + # remove any buffered sample events if logger is not None: logger.remove_sample(state.sample_id, state.epoch) @@ -1463,7 +1493,7 @@ def make_eval_sample(include_events: bool = True) -> EvalSample: retry_on_error=retry_on_error - 1, score_on_error=score_on_error, # forward on error that caused retry - error_retries=copy(error_retries) + [_eval_retry_error(error)], + error_retries=copy(error_retries) + [retry_error], time_limit=time_limit, working_limit=working_limit, semaphore=semaphore, @@ -1745,16 +1775,29 @@ def init_sample_assistant_internal() -> None: pass -def _eval_retry_error(error: EvalError) -> EvalRetryError: +def _eval_retry_error( + error: EvalError, + logger: TaskLogger | None = None, + sample_id: str | int | None = None, + epoch: int | None = None, +) -> EvalRetryError: """Create retry error with events from the most recent ModelEvent onward.""" from inspect_ai.event._model import ModelEvent - events = transcript().events - recent_events = list(events) - for i in range(len(events) - 1, -1, -1): - if isinstance(events[i], ModelEvent): - recent_events = list(events[i:]) - break + if logger is not None and logger.buffer_db is not None and sample_id is not None: + if epoch is None: + raise ValueError( + "epoch is required when reading retry events from buffer DB" + ) + with logger.buffer_db.open_sample_history(sample_id, epoch) as history: + return eval_retry_error_from_history(error, history) + + sample_transcript = transcript() + recent_events = ( + sample_transcript.events_since_last(ModelEvent) + if sample_transcript.full_history_available + else [] + ) return EvalRetryError( message=error.message, traceback=error.traceback, diff --git a/src/inspect_ai/agent/_acp/transport_live.py b/src/inspect_ai/agent/_acp/transport_live.py index e35479e011..2ac43034e0 100644 --- a/src/inspect_ai/agent/_acp/transport_live.py +++ b/src/inspect_ai/agent/_acp/transport_live.py @@ -263,9 +263,9 @@ def _noop_unsubscribe() -> None: def snapshot(self) -> Sequence[Any]: if self._captured is None: return [] - # list() the events sequence so callers iterating concurrently - # with new ``_event`` appends don't see size changes mid-iteration. - return list(self._captured.events)[self._attach_index :] + # Slice the transcript view directly so bounded/provider-backed + # transcripts can read only the suffix since attach. + return self._captured.events[self._attach_index :] class _CancelSnapshot: diff --git a/src/inspect_ai/log/_recorders/buffer/database.py b/src/inspect_ai/log/_recorders/buffer/database.py index b36be8feda..8f475f1c65 100644 --- a/src/inspect_ai/log/_recorders/buffer/database.py +++ b/src/inspect_ai/log/_recorders/buffer/database.py @@ -10,7 +10,7 @@ from logging import getLogger from pathlib import Path from sqlite3 import Connection, OperationalError -from typing import Callable, Iterable, Iterator, Literal +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Literal, cast import psutil from pydantic import BaseModel, JsonValue @@ -67,6 +67,9 @@ logger = getLogger(__name__) SYNC_CLEANUP_TIMEOUT = 30 +if TYPE_CHECKING: + from inspect_ai.util._checkpoint._event_store import CheckpointEventStore + class TaskData(BaseModel): version: int @@ -488,6 +491,84 @@ def sample_attachment(self, id: str | int, epoch: int, hash: str) -> str | None: conn.rollback() raise + def import_checkpoint_events( + self, id: str | int, epoch: int, event_store: "CheckpointEventStore" + ) -> int: + seed_count = 0 + with self._acquire_sample_read_lease(id, epoch): + with self._get_connection() as conn: + conn.execute("BEGIN IMMEDIATE") + try: + pool_attachment_refs: set[str] = set() + message_pos_map: dict[int, int] = {} + for pos, message_entry in enumerate( + self._get_message_pool(conn, id, epoch) + ): + message_pos_map[pos] = event_store.merge_message_pool_entry( + message_entry.msg_id, message_entry.data + ) + pool_attachment_refs.update( + event_store.attachment_refs_from_json(message_entry.data) + ) + call_pos_map: dict[int, int] = {} + for pos, call_entry in enumerate( + self._get_call_pool(conn, id, epoch) + ): + call_pos_map[pos] = event_store.merge_call_pool_entry( + call_entry.hash, call_entry.data + ) + pool_attachment_refs.update( + event_store.attachment_refs_from_json(call_entry.data) + ) + + def attachment_lookup(hash: str) -> str | None: + row = conn.execute( + """ + SELECT content FROM attachments + WHERE sample_id = ? AND sample_epoch = ? AND hash = ? + """, + [str(id), epoch, hash], + ).fetchone() + return None if row is None else str(row["content"]) + + event_store.merge_attachment_refs( + pool_attachment_refs, attachment_lookup + ) + for row in self._get_events(conn, id, epoch, latest_only=True): + event_store.merge_condensed_event( + row.event_id, + self._remap_pool_refs( + row.event, message_pos_map, call_pos_map + ), + attachment_lookup, + ) + seed_count += 1 + conn.commit() + return seed_count + except Exception: + conn.rollback() + raise + + @staticmethod + def _remap_pool_refs( + event: JsonData, message_pos_map: dict[int, int], call_pos_map: dict[int, int] + ) -> JsonData: + remapped = dict(event) + input_refs = remapped.get("input_refs") + if isinstance(input_refs, list): + remapped["input_refs"] = cast( + JsonValue, _remap_refs(input_refs, message_pos_map) + ) + call = remapped.get("call") + if isinstance(call, dict): + call_refs = call.get("call_refs") + if isinstance(call_refs, list): + remapped["call"] = { + **call, + "call_refs": cast(JsonValue, _remap_refs(call_refs, call_pos_map)), + } + return remapped + @contextmanager def open_sample_history_tail( self, diff --git a/src/inspect_ai/log/_recorders/buffer/filestore.py b/src/inspect_ai/log/_recorders/buffer/filestore.py index 6ab5ffc035..9bad103d79 100644 --- a/src/inspect_ai/log/_recorders/buffer/filestore.py +++ b/src/inspect_ai/log/_recorders/buffer/filestore.py @@ -353,6 +353,12 @@ def get_sample_data( def sample_event_count(self, id: str | int, epoch: int) -> int: raise NotImplementedError("Sample history is only available for buffer DBs") + @override + def import_checkpoint_events( + self, id: str | int, epoch: int, event_store: object + ) -> int: + raise NotImplementedError("Sample history is only available for buffer DBs") + @override def open_sample_history_tail( self, diff --git a/src/inspect_ai/log/_recorders/buffer/history_provider.py b/src/inspect_ai/log/_recorders/buffer/history_provider.py new file mode 100644 index 0000000000..77a31f395c --- /dev/null +++ b/src/inspect_ai/log/_recorders/buffer/history_provider.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from collections.abc import Iterator, Mapping, Sequence +from typing import TYPE_CHECKING + +from inspect_ai.event._event import Event +from inspect_ai.event._pool import ( + materialize_pooled_events, + resolve_model_event_calls, + resolve_model_event_inputs, +) +from inspect_ai.event._validate import validate_events + +if TYPE_CHECKING: + from inspect_ai.log._recorders.buffer.database import SampleBufferDatabase + from inspect_ai.log._recorders.buffer.history import SampleHistory + from inspect_ai.log._transcript import TranscriptHistoryProvider + from inspect_ai.util._checkpoint._event_store import CheckpointEventStore + + +class BufferTranscriptHistoryProvider: + def __init__( + self, + buffer_db: SampleBufferDatabase, + sample_id: str | int, + epoch: int, + ) -> None: + self._buffer_db = buffer_db + self._sample_id = sample_id + self._epoch = epoch + + @property + def event_count(self) -> int: + return self._buffer_db.sample_event_count(self._sample_id, self._epoch) + + def events(self) -> Sequence[Event]: + return self._events() + + def iter_events(self) -> Iterator[Event]: + return self._iter_events() + + def recent_events(self, n: int | None = None) -> Sequence[Event]: + if n is None: + return self._events() + if n <= 0: + return [] + with self._buffer_db.open_sample_history_tail( + self._sample_id, self._epoch, n + ) as history: + return _materialize_events(history) + + def events_from(self, start: int) -> Sequence[Event]: + if start <= 0: + return self._events() + with self._buffer_db.open_sample_history_from( + self._sample_id, self._epoch, start + ) as history: + return _materialize_events(history) + + def events_since_last(self, event_type: type[Event]) -> list[Event]: + suffix: list[Event] = [] + with self._buffer_db.open_sample_history( + self._sample_id, self._epoch + ) as history: + for event in _iter_materialized_events(history): + if isinstance(event, event_type): + suffix = [event] + else: + suffix.append(event) + return suffix + + def attachments(self) -> Mapping[str, str]: + with self._buffer_db.open_sample_history( + self._sample_id, self._epoch + ) as history: + return dict(history.attachments) + + def attachment(self, hash: str) -> str | None: + return self._buffer_db.sample_attachment(self._sample_id, self._epoch, hash) + + def import_checkpoint_events(self, event_store: "CheckpointEventStore") -> int: + return self._buffer_db.import_checkpoint_events( + self._sample_id, self._epoch, event_store + ) + + def _events(self) -> list[Event]: + with self._buffer_db.open_sample_history( + self._sample_id, self._epoch + ) as history: + return _materialize_events(history) + + def _iter_events(self) -> Iterator[Event]: + with self._buffer_db.open_sample_history( + self._sample_id, self._epoch + ) as history: + yield from _iter_materialized_events(history) + + +def _materialize_events(history: SampleHistory) -> list[Event]: + return materialize_pooled_events( + history.iter_events(), + history.events_data["messages"], + history.events_data["calls"], + ) + + +def _iter_materialized_events(history: SampleHistory) -> Iterator[Event]: + message_pool = history.events_data["messages"] + call_pool = history.events_data["calls"] + for raw_event in history.iter_events(): + event = validate_events([raw_event])[0] + event = resolve_model_event_inputs([event], message_pool)[0] + yield resolve_model_event_calls([event], call_pool)[0] + + +if TYPE_CHECKING: + _buffer_transcript_history_provider: type[TranscriptHistoryProvider] = ( + BufferTranscriptHistoryProvider + ) diff --git a/src/inspect_ai/log/_recorders/buffer/types.py b/src/inspect_ai/log/_recorders/buffer/types.py index 21dcd8c528..2b6b1be387 100644 --- a/src/inspect_ai/log/_recorders/buffer/types.py +++ b/src/inspect_ai/log/_recorders/buffer/types.py @@ -9,6 +9,8 @@ from ..._log import EvalSampleSummary if TYPE_CHECKING: + from inspect_ai.util._checkpoint._event_store import CheckpointEventStore + from .history import SampleHistory JsonData: TypeAlias = dict[str, JsonValue] @@ -132,6 +134,13 @@ def sample_event_count(self, id: str | int, epoch: int) -> int: """Return the number of distinct events recorded for a sample.""" ... + @abc.abstractmethod + def import_checkpoint_events( + self, id: str | int, epoch: int, event_store: "CheckpointEventStore" + ) -> int: + """Import a sample's full event history into a checkpoint event store.""" + ... + @abc.abstractmethod def open_sample_history_tail( self, diff --git a/src/inspect_ai/log/_recorders/streaming.py b/src/inspect_ai/log/_recorders/streaming.py index eec0aca618..07c87a607b 100644 --- a/src/inspect_ai/log/_recorders/streaming.py +++ b/src/inspect_ai/log/_recorders/streaming.py @@ -2,8 +2,10 @@ from typing import TYPE_CHECKING +from inspect_ai._util.error import EvalError +from inspect_ai.event._pool import materialize_pooled_events from inspect_ai.event._validate import validate_events -from inspect_ai.log._log import EvalSample +from inspect_ai.log._log import EvalRetryError, EvalSample from inspect_ai.log._resolve import rebind_sample_timelines, resolve_sample_events_data if TYPE_CHECKING: @@ -26,3 +28,28 @@ def materialize_streaming_sample( } ) return rebind_sample_timelines(materialized) + + +def eval_retry_error_from_history( + error: EvalError, history: "SampleHistory" +) -> EvalRetryError: + """Create retry error from full history since the latest ModelEvent.""" + suffix: list[object] = [] + for event in history.iter_events(): + if event.get("event") == "model": + suffix = [event] + elif suffix: + suffix.append(event) + + events = materialize_pooled_events( + suffix, + history.events_data["messages"], + history.events_data["calls"], + ) + + return EvalRetryError( + message=error.message, + traceback=error.traceback, + traceback_ansi=error.traceback_ansi, + events=events, + ) diff --git a/src/inspect_ai/log/_samples.py b/src/inspect_ai/log/_samples.py index 41ae7ffdc6..7485d00c70 100644 --- a/src/inspect_ai/log/_samples.py +++ b/src/inspect_ai/log/_samples.py @@ -1,5 +1,4 @@ import contextlib -from contextlib import AbstractAsyncContextManager from contextvars import ContextVar from datetime import datetime, timezone from logging import getLogger @@ -26,7 +25,7 @@ from shortuuid import uuid from inspect_ai.dataset._dataset import Sample -from inspect_ai.util._checkpoint.checkpointer import Checkpointer, ResumeCheckpoint +from inspect_ai.util._checkpoint.checkpointer import CheckpointerSetup, ResumeCheckpoint from inspect_ai.util._checkpoint.checkpointer_factory import create_checkpointer from inspect_ai.util._checkpoint.config import ResolvedCheckpointConfig from inspect_ai.util._limit import LimitExceededError @@ -56,7 +55,7 @@ def __init__( fails_on_error: bool, transcript: Transcript, sandboxes: dict[str, SandboxConnection], - checkpointer: AbstractAsyncContextManager[Checkpointer], + checkpointer: CheckpointerSetup, eval_id: str, eval_set_id: str | None = None, run_id: str | None = None, @@ -283,6 +282,7 @@ async def active_sample( "ActiveSample on_complete hook raised", exc_info=True, ) + active.checkpointer.close() active.complete() _active_samples.remove(active) _sample_active.set(None) diff --git a/src/inspect_ai/log/_transcript.py b/src/inspect_ai/log/_transcript.py index 0844bf6155..9f0ebc32ea 100644 --- a/src/inspect_ai/log/_transcript.py +++ b/src/inspect_ai/log/_transcript.py @@ -1,11 +1,16 @@ import contextlib +import os +from collections import deque from contextvars import ContextVar -from dataclasses import dataclass from logging import getLogger from typing import ( + TYPE_CHECKING, Callable, + Deque, Iterator, Literal, + Mapping, + Protocol, Sequence, TypeVar, overload, @@ -14,75 +19,220 @@ from pydantic import ( JsonValue, ) +from shortuuid import uuid +from inspect_ai._util.list import find_last_match from inspect_ai._util.logger import warn_once from inspect_ai.event._base import BaseEvent from inspect_ai.event._event import Event from inspect_ai.event._info import InfoEvent from inspect_ai.event._interrupt import InterruptEvent from inspect_ai.event._model import ModelEvent +from inspect_ai.event._sample_init import SampleInitEvent from inspect_ai.event._store import StoreEvent from inspect_ai.event._timeline import Timeline from inspect_ai.log._condense import ( WalkContext, + attachment_refs_from_value, events_attachment_fn, walk_model_call, ) from inspect_ai.util._store import store, store_changes, store_jsonable +if TYPE_CHECKING: + from inspect_ai.util._checkpoint._event_store import CheckpointEventStore + logger = getLogger(__name__) ET = TypeVar("ET", bound=BaseEvent) -@dataclass(frozen=True) -class _TranscriptSubscription: - id: int - callback: Callable[[Event], None] +def transcript_bounded_enabled() -> bool: + value = os.environ.get("INSPECT_TRANSCRIPT_BOUNDED") + if value is None: + return False + return value.strip().lower() not in ("0", "false", "no", "off") + + +class TranscriptHistoryProvider(Protocol): + @property + def event_count(self) -> int: ... + + def iter_events(self) -> Iterator[Event]: ... + + def events(self) -> Sequence[Event]: ... + + def recent_events(self, n: int | None = None) -> Sequence[Event]: ... + + def events_from(self, start: int) -> Sequence[Event]: ... + + def events_since_last(self, event_type: type[Event]) -> list[Event]: ... + + def attachments(self) -> Mapping[str, str]: ... + + def attachment(self, hash: str) -> str | None: ... + + def import_checkpoint_events(self, event_store: "CheckpointEventStore") -> int: ... + + +class _TranscriptEventsView(Sequence[Event]): + def __init__(self, transcript: "Transcript") -> None: + self._transcript = transcript + + def __len__(self) -> int: + return self._transcript.event_count + + def __iter__(self) -> Iterator[Event]: + provider = self._transcript._history_provider + if provider is None: + return iter(self._transcript._events) + return provider.iter_events() + + def __contains__(self, item: object) -> bool: + if not isinstance(item, BaseEvent): + return False + item_key = item.uuid + if item_key is None: + return any(event is item for event in self._transcript._events) + return any(event.uuid == item_key for event in self._transcript._events) + + @overload + def __getitem__(self, index: int) -> Event: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[Event]: ... + + def __getitem__(self, index: int | slice) -> Event | Sequence[Event]: + if isinstance(index, slice): + return self._slice(index) + if index == -1: + provider = self._transcript._history_provider + if provider is None: + last_event = self._transcript.last_event + if last_event is not None: + return last_event + else: + recent_events = self._transcript.recent_events(1) + if recent_events: + return recent_events[-1] + raise IndexError("Transcript events index out of range") + if index >= 0: + provider = self._transcript._history_provider + if provider is not None: + for event_index, event in enumerate(provider.iter_events()): + if event_index == index: + return event + raise IndexError("Transcript events index out of range") + events = self._materialize() + return events[index] + + def _slice(self, index: slice) -> Sequence[Event]: + if index == slice(None, None, None): + provider = self._transcript._history_provider + if provider is None: + return self._transcript._events[:] + return provider.events() + if index.step is None and index.stop is None and index.start is not None: + start = index.start + provider = self._transcript._history_provider + if provider is not None: + if start >= 0: + return provider.events_from(start) + if -start <= min( + self._transcript._resident_tail, len(self._transcript._events) + ): + return self._transcript._events[start:] + return provider.recent_events(-start) + return self._materialize()[index] + + def _materialize(self) -> list[Event]: + provider = self._transcript._history_provider + if provider is None: + return list(self._transcript._events) + return list(provider.events()) class Transcript: """Transcript of events.""" - _event_logger: _TranscriptSubscription | None - _event_loggers: list[_TranscriptSubscription] + _event_logger: Callable[[Event], None] | None + _event_loggers: list[Callable[[Event], None]] + _additional_subscribers: list[Callable[[Event], None]] _notifying_subscribers: set[int] _context: WalkContext @overload - def __init__(self, *, log_model_api: bool | None = None) -> None: ... + def __init__( + self, + *, + log_model_api: bool | None = None, + bounded: bool = False, + resident_tail: int = 100, + history_provider: TranscriptHistoryProvider | None = None, + ) -> None: ... @overload def __init__( - self, events: list[Event], log_model_api: bool | None = None + self, + events: list[Event], + log_model_api: bool | None = None, + bounded: bool = False, + resident_tail: int = 100, + history_provider: TranscriptHistoryProvider | None = None, ) -> None: ... def __init__( - self, events: list[Event] | None = None, log_model_api: bool | None = None + self, + events: list[Event] | None = None, + log_model_api: bool | None = None, + bounded: bool = False, + resident_tail: int = 100, + history_provider: TranscriptHistoryProvider | None = None, ) -> None: self._event_logger = None + self._event_loggers = [] + self._additional_subscribers = self._event_loggers self._log_model_api = log_model_api self._context = WalkContext(message_cache={}, only_core=False) - self._events: list[Event] = events if events is not None else [] + self._events: list[Event] = self._copy_uuidless_events(events or []) + self._history_provider = history_provider + self._events_view = _TranscriptEventsView(self) self._attachments: dict[str, str] = {} + self._attachment_refcount: dict[str, int] = {} + self._event_attachment_refs: dict[str, set[str]] = {} self._timelines: list[Timeline] = [] self._model_call_counts: dict[str, int] = {} - self._kept_event_ids: set[int] = set() - self._event_loggers = [] - self._next_event_logger_id = 0 - # Sidecar of currently-pending events keyed by ``event.uuid`` so - # consumers (live TUI toolbar, future DB-backed transcripts) can - # query in-flight state in O(in-flight) without scanning all - # events. Maintained by ``_event``/``_event_updated``. Insertion - # order = declared order; dict preserves it so the "earliest - # pending" is the first value. - self._pending_events: dict[str, Event] = {} - if events is not None: - for ev in events: - if ev.pending and ev.uuid is not None: - self._pending_events[ev.uuid] = ev + self._kept_event_ids: set[str] = set() + self._bounded = bounded + self._resident_tail = resident_tail + self._event_count = len(self._events) + self._events_truncated = False + self._pinned_event_ids: set[str] = { + self._event_key(event) + for event in self._events + if isinstance(event, SampleInitEvent) + } + self._pending_event_ids: set[str] = { + self._event_key(event) for event in self._events if event.pending + } + self._pending_events: dict[str, Event] = { + self._event_key(event): event for event in self._events if event.pending + } + self._resident_event_ids: set[str] = { + self._event_key(event) for event in self._events + } + self._evictable_event_ids: Deque[str] = deque( + event_key + for event in self._events + if (event_key := self._event_key(event)) + not in self._pinned_event_ids | self._pending_event_ids + ) + # Re-entry guard for subscriber callbacks. If a subscriber logs while + # handling an event, the resulting LoggerEvent should still reach all + # other subscribers, but not recursively notify the same subscriber. self._notifying_subscribers = set() + self._evict_events() def info(self, data: JsonValue, *, source: str | None = None) -> None: """Add an `InfoEvent` to the transcript. @@ -114,20 +264,66 @@ def step(self, name: str, type: str | None = None) -> Iterator[None]: @property def events(self) -> Sequence[Event]: + """Compatibility view of the logical event history. + + For unbounded or provider-free transcripts this returns resident events. + For bounded transcripts with a history provider this returns a lazy view + over the full logical history. Iteration, random indexing, and some slices + may read and materialize events from the provider; hot paths should use + `resident_events`, `event_count`, `last_event`, or `recent_events()`. + """ + if self._history_provider is None: + return self._events + return self._events_view + + @property + def resident_events(self) -> Sequence[Event]: + """Events currently resident in memory for live/hot-path consumers.""" return self._events @property def pending_events(self) -> Sequence[Event]: - """Currently-pending events in declared (insertion) order. - - Returns a snapshot of events with ``pending=True`` keyed by - their ``uuid``. Updates whenever ``_event`` records a new - pending event or ``_event_updated`` flips one to a terminal - state. Bounded by the number of in-flight operations - (typically 0-1, up to the stage size under parallel tools). - """ + """Currently-pending events in insertion order.""" return list(self._pending_events.values()) + @property + def event_count(self) -> int: + return self._event_count + + @property + def resident_events_truncated(self) -> bool: + return self._events_truncated + + @property + def full_history_available(self) -> bool: + return not self._events_truncated or self._history_provider is not None + + @property + def last_event(self) -> Event | None: + return self._events[-1] if self._events else None + + def recent_events(self, n: int | None = None) -> Sequence[Event]: + if n is not None and n <= 0: + return [] + if self._history_provider is None: + return self._events if n is None else self._events[-n:] + if n is not None and n <= min(self._resident_tail, len(self._events)): + return self._events[-n:] + return self._history_provider.recent_events(n) + + def events_since_last(self, event_type: type[Event]) -> list[Event]: + if self._events_truncated: + if self._history_provider is not None: + return self._history_provider.events_since_last(event_type) + raise RuntimeError( + "Full transcript history is not available from this Transcript" + ) + events = list(self._events) + index = find_last_match(events, lambda event: isinstance(event, event_type)) + if index is not None: + return events[index:] + return events + @property def attachments(self) -> dict[str, str]: return self._attachments @@ -153,31 +349,71 @@ def add_timeline(self, timeline: Timeline) -> None: self._timelines.append(timeline) def _event(self, event: Event) -> None: + event_key = self._ensure_event_key(event) + if event_key in self._resident_event_ids: + raise ValueError(f"Duplicate event uuid: {event_key}") self._process_event(event) self._events.append(event) + self._resident_event_ids.add(event_key) + self._set_attachment_refs(event) + self._event_count += 1 + self._update_pin_state(event) self._update_pending(event) + self._update_evictable_state(event) + self._evict_events() + + def _extend_restored_events( + self, + events: Sequence[Event], + attachments: Mapping[str, str], + *, + notify_subscribers: bool = False, + ) -> None: + events = self._copy_uuidless_events(events) + event_keys: list[str] = [] + new_event_keys: set[str] = set() + for event in events: + event_key = self._ensure_event_key(event) + if event_key in self._resident_event_ids or event_key in new_event_keys: + raise ValueError(f"Duplicate event uuid: {event_key}") + event_keys.append(event_key) + new_event_keys.add(event_key) + + self._attachments.update(attachments) + for event, event_key in zip(events, event_keys): + self._events.append(event) + self._resident_event_ids.add(event_key) + self._set_attachment_refs(event) + self._event_count += 1 + self._update_pin_state(event) + self._update_pending(event) + self._update_evictable_state(event) + if notify_subscribers: + self._notify_subscribers(event) + self._evict_events() def _event_updated(self, event: Event) -> None: - self._process_event(event) - self._update_pending(event) + if self._is_resident(event): + self._process_event(event) + self._set_attachment_refs(event) + self._update_pin_state(event) + self._update_pending(event) + self._update_evictable_state(event) + self._evict_events() + else: + self._process_event(event, retain_attachments=False) + self._update_pending(event) + self._prune_unreferenced_attachments() def _update_pending(self, event: Event) -> None: - """Reflect ``event``'s current pending state in the sidecar. - - Adds the event on first emission with ``pending=True``; removes - on the subsequent ``_event_updated`` that flips it to a terminal - state. Events without a ``uuid`` (synthetic step/span events) - are skipped — they can't be deduplicated and don't represent - an in-flight operation. - """ - if event.uuid is None: - return + """Reflect ``event``'s current pending state in the sidecar.""" + event_key = self._event_key(event) if event.pending: - self._pending_events[event.uuid] = event + self._pending_events[event_key] = event else: - self._pending_events.pop(event.uuid, None) + self._pending_events.pop(event_key, None) - def _process_event(self, event: Event) -> None: + def _process_event(self, event: Event, *, retain_attachments: bool = True) -> None: if isinstance(event, ModelEvent) and event.call is not None: is_error = bool(event.call.error) if not is_error: @@ -186,8 +422,8 @@ def _process_event(self, event: Event) -> None: elif self._log_model_api is False: event.call = None else: - event_id = id(event) - if event_id not in self._kept_event_ids: + event_key = self._event_key(event) + if event_key not in self._kept_event_ids: from inspect_ai._util.constants import ( DEFAULT_LOG_MODEL_API_CALLS, ) @@ -195,60 +431,193 @@ def _process_event(self, event: Event) -> None: count = self._model_call_counts.get(event.model, 0) if count < DEFAULT_LOG_MODEL_API_CALLS: self._model_call_counts[event.model] = count + 1 - self._kept_event_ids.add(event_id) + self._kept_event_ids.add(event_key) else: event.call = None + if retain_attachments and event.call is not None: + event_fn = events_attachment_fn(self.attachments) + event.call = walk_model_call(event.call, event_fn, self._context) + + self._notify_subscribers(event) + + def _notify_subscribers(self, event: Event) -> None: for event_logger in list(self._event_loggers): - subscriber_id = event_logger.id + subscriber_id = id(event_logger) if subscriber_id in self._notifying_subscribers: continue self._notifying_subscribers.add(subscriber_id) try: try: - event_logger.callback(event) + event_logger(event) except Exception: logger.warning("Transcript subscriber failed", exc_info=True) finally: self._notifying_subscribers.remove(subscriber_id) - # condense model event calls immediately to prevent O(N) memory usage - if isinstance(event, ModelEvent) and event.call is not None: - event_fn = events_attachment_fn(self.attachments) - event.call = walk_model_call(event.call, event_fn, self._context) + def _set_attachment_refs(self, event: Event) -> None: + if not self._bounded: + return + + event_key = self._event_key(event) + previous_refs = self._event_attachment_refs.get(event_key, set()) + current_refs = self._attachment_refs(event) + for ref in previous_refs - current_refs: + self._decrement_attachment_ref(ref) + for ref in current_refs - previous_refs: + self._attachment_refcount[ref] = self._attachment_refcount.get(ref, 0) + 1 + if current_refs: + self._event_attachment_refs[event_key] = current_refs + else: + self._event_attachment_refs.pop(event_key, None) + + def _attachment_refs(self, event: Event) -> set[str]: + return attachment_refs_from_value(event.model_dump(mode="python")) + + def _decrement_attachment_ref(self, ref: str) -> None: + count = self._attachment_refcount.get(ref, 0) - 1 + if count > 0: + self._attachment_refcount[ref] = count + else: + self._attachment_refcount.pop(ref, None) + self._attachments.pop(ref, None) + + def _prune_unreferenced_attachments(self) -> None: + if not self._bounded: + return + + for ref in list(self._attachments): + if ref not in self._attachment_refcount: + self._attachments.pop(ref, None) + + def _evict_events(self) -> None: + if not self._bounded: + return + + resident_tail = max(self._resident_tail, 0) + evicted_event_ids: set[str] = set() + while len(self._evictable_event_ids) > resident_tail: + event_key = self._evictable_event_ids.popleft() + if not self._is_evictable_event_key(event_key): + continue + evicted_event_ids.add(event_key) + + if evicted_event_ids: + self._events = [ + event + for event in self._events + if self._event_key(event) not in evicted_event_ids + ] + self._resident_event_ids.difference_update(evicted_event_ids) + self._events_truncated = True + self._prune_pin_state() + + def _prune_pin_state(self) -> None: + resident_event_keys = self._resident_event_ids + self._pinned_event_ids.intersection_update(resident_event_keys) + self._pending_event_ids.intersection_update(resident_event_keys) + self._evictable_event_ids = deque( + event_key + for event_key in self._evictable_event_ids + if self._is_evictable_event_key(event_key) + ) + if self._bounded: + self._kept_event_ids.intersection_update(resident_event_keys) + self._prune_attachment_refs(resident_event_keys) + + def _prune_attachment_refs(self, resident_event_keys: set[str]) -> None: + for event_key in list(self._event_attachment_refs): + if event_key in resident_event_keys: + continue + for ref in self._event_attachment_refs.pop(event_key): + self._decrement_attachment_ref(ref) + + def _is_resident(self, event: Event) -> bool: + return event.uuid is not None and event.uuid in self._resident_event_ids + + def _update_pin_state(self, event: Event) -> None: + event_key = self._event_key(event) + if isinstance(event, SampleInitEvent): + self._pinned_event_ids.add(event_key) + if event.pending: + self._pending_event_ids.add(event_key) + else: + self._pending_event_ids.discard(event_key) + + def _update_evictable_state(self, event: Event) -> None: + event_key = self._event_key(event) + if not self._is_evictable_event_key(event_key): + return + if event_key in self._evictable_event_ids: + return + + for index, resident_event in enumerate(self._events): + if resident_event is event: + self._evictable_event_ids.insert(index, event_key) + return + self._evictable_event_ids.append(event_key) + + def _is_evictable_event_key(self, event_key: str) -> bool: + return ( + event_key in self._resident_event_ids + and event_key not in self._pinned_event_ids + and event_key not in self._pending_event_ids + ) + + def _event_key(self, event: Event) -> str: + if event.uuid is None: + raise ValueError("Transcript event is missing uuid") + return event.uuid + + def _ensure_event_key(self, event: Event) -> str: + if event.uuid is None: + event.uuid = uuid() + return event.uuid + + @staticmethod + def _copy_uuidless_events(events: Sequence[Event]) -> list[Event]: + copied_events: list[Event] = [] + for event in events: + if event.uuid is None: + event = event.model_copy() + event.uuid = uuid() + copied_events.append(event) + return copied_events def subscribe(self, event_logger: Callable[[Event], None]) -> Callable[[], None]: - subscription = self._create_subscription(event_logger) - self._event_loggers.append(subscription) - unsubscribed = False + """Subscribe to transcript event notifications. + + The callback is invoked when an event is added and when a resident event + is updated. Subscriber exceptions are logged and do not prevent other + subscribers or normal transcript processing. Returns an unsubscribe + callback that removes the subscription. + """ + self._event_loggers.append(event_logger) def unsubscribe() -> None: - nonlocal unsubscribed - if not unsubscribed: - unsubscribed = True - self._event_loggers.remove(subscription) + if event_logger in self._event_loggers: + self._event_loggers.remove(event_logger) return unsubscribe def _subscribe(self, event_logger: Callable[[Event], None]) -> None: - if self._event_logger is not None: + """Legacy subscription API for eval logging.""" + if self._event_logger is not None and self._event_logger in self._event_loggers: self._event_loggers.remove(self._event_logger) - self._event_logger = self._create_subscription(event_logger) - self._event_loggers.append(self._event_logger) + self._event_logger = event_logger + self._event_loggers.append(event_logger) def _add_subscriber(self, callback: Callable[[Event], None]) -> Callable[[], None]: return self.subscribe(callback) - def _create_subscription( - self, callback: Callable[[Event], None] - ) -> _TranscriptSubscription: - self._next_event_logger_id += 1 - return _TranscriptSubscription(id=self._next_event_logger_id, callback=callback) - def transcript() -> Transcript: """Get the current `Transcript`.""" - return _transcript.get() + active_transcript = _transcript.get() + if active_transcript is None: + active_transcript = Transcript() + _transcript.set(active_transcript) + return active_transcript def record_interrupt_event( @@ -263,15 +632,6 @@ def record_interrupt_event( Internal helper used by Inspect's cancellation machinery — the ACP `cancel_current_turn` path, sample-level limit handlers, and system shutdown. Not a public API for agent authors. - - Args: - source: What caused the interrupt — an ACP client cancel, - a sample limit, or system shutdown. - interrupted: What was running at the moment of the interrupt. - interrupted_tool_call_id: ``ToolEvent.id`` of the in-flight - tool call, if any. - interrupted_model_event_id: ``ModelEvent.uuid`` of the - in-flight model call, if any. """ transcript()._event( InterruptEvent( @@ -298,6 +658,6 @@ def init_transcript(transcript: Transcript) -> None: _transcript.set(transcript) -_transcript: ContextVar[Transcript] = ContextVar( - "subtask_transcript", default=Transcript() +_transcript: ContextVar[Transcript | None] = ContextVar( + "subtask_transcript", default=None ) diff --git a/src/inspect_ai/util/_checkpoint/_event_store.py b/src/inspect_ai/util/_checkpoint/_event_store.py new file mode 100644 index 0000000000..92559dbb4e --- /dev/null +++ b/src/inspect_ai/util/_checkpoint/_event_store.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +import json +import logging +import os +import tempfile +from collections.abc import Callable, Iterable, Mapping +from dataclasses import dataclass +from pathlib import Path +from sqlite3 import Connection, connect +from threading import RLock +from typing import TextIO + +from pydantic import JsonValue +from pydantic_core import to_jsonable_python +from shortuuid import uuid + +from inspect_ai._util.hash import mm3_hash +from inspect_ai._util.json import to_json_str_safe +from inspect_ai.event._event import Event +from inspect_ai.event._pool import ( + _msg_hash, + condense_model_event_calls_with_lookup, + condense_model_event_inputs_with_lookup, +) +from inspect_ai.log._condense import ( + attachment_refs_from_value, + condense_event, +) +from inspect_ai.model._chat_message import ChatMessage + +logger = logging.getLogger(__name__) + +CHECKPOINT_EVENT_STORE = "checkpoint_events.sqlite" + + +@dataclass(frozen=True) +class CheckpointEventStoreCounts: + events: int + message_pool: int + call_pool: int + attachments: int + db_bytes: int + + +class CheckpointEventStore: + def __init__(self, path: str | Path, *, reset: bool = False) -> None: + self._path = Path(path) + self._path.parent.mkdir(parents=True, exist_ok=True) + if reset: + self._reset_files() + self._conn = connect(self._path, check_same_thread=False) + self._lock = RLock() + self._pending_message_pos_by_event: dict[str, dict[int, int]] = {} + self._pending_call_pos_by_event: dict[str, dict[int, int]] = {} + self._conn.row_factory = None + self._init_schema(self._conn) + + @property + def path(self) -> Path: + return self._path + + def counts(self) -> CheckpointEventStoreCounts: + with self._lock: + return CheckpointEventStoreCounts( + events=self._count_rows("events"), + message_pool=self._count_rows("message_pool"), + call_pool=self._count_rows("call_pool"), + attachments=self._count_rows("attachments"), + db_bytes=self._path.stat().st_size if self._path.exists() else 0, + ) + + def close(self) -> None: + with self._lock: + self._conn.close() + + def merge_event( + self, event: Event, attachment_lookup: Callable[[str], str | None] + ) -> None: + with self._lock: + if event.uuid is None: + event.uuid = uuid() + logical_id = event.uuid + + with self._conn: + event_attachments: dict[str, str] = {} + event = condense_event(event, event_attachments) + condensed_event = self._condense_event(logical_id, event) + event_json = json.dumps( + to_jsonable_python( + condensed_event, exclude_none=True, fallback=lambda _: None + ), + separators=(",", ":"), + ) + self._upsert_event(logical_id, event_json) + self._insert_attachments(event_attachments) + self._merge_attachment_refs( + self._attachment_refs(event), + lambda ref: event_attachments.get(ref) or attachment_lookup(ref), + ) + + def merge_condensed_event( + self, + logical_id: str, + event: Mapping[str, JsonValue], + attachment_lookup: Callable[[str], str | None], + ) -> None: + event_jsonable = dict(event) + event_jsonable.setdefault("uuid", logical_id) + event_json = json.dumps(event_jsonable, separators=(",", ":")) + with self._lock: + with self._conn: + self._upsert_event(logical_id, event_json) + self._merge_attachment_refs( + attachment_refs_from_value(event_jsonable), attachment_lookup + ) + + def merge_message_pool_entry(self, hash_value: str, json_text: str) -> int: + with self._lock: + with self._conn: + return self._pool_pos("message_pool", hash_value, json_text) + + def merge_message_pool(self, messages: Iterable[ChatMessage]) -> None: + with self._lock: + with self._conn: + for message in messages: + self._message_pos(message) + + def merge_call_pool_entry(self, hash_value: str, json_text: str) -> int: + with self._lock: + with self._conn: + return self._pool_pos("call_pool", hash_value, json_text) + + def merge_call_pool(self, calls: Iterable[JsonValue]) -> None: + with self._lock: + with self._conn: + for call in calls: + self._call_pos(call) + + def _upsert_event(self, logical_id: str, event_json: str) -> None: + row = self._conn.execute( + "SELECT first_seq FROM events WHERE logical_id = ?", + (logical_id,), + ).fetchone() + if row is None: + first_seq = self._next_event_seq() + self._conn.execute( + "INSERT INTO events(logical_id, first_seq, latest_json) VALUES (?, ?, ?)", + (logical_id, first_seq, event_json), + ) + else: + self._conn.execute( + "UPDATE events SET latest_json = ? WHERE logical_id = ?", + (event_json, logical_id), + ) + + def _merge_attachment_refs( + self, refs: Iterable[str], attachment_lookup: Callable[[str], str | None] + ) -> None: + for ref in refs: + content = attachment_lookup(ref) + if content is not None: + self._conn.execute( + "INSERT OR IGNORE INTO attachments(hash, content) VALUES (?, ?)", + (ref, content), + ) + elif not self._has_attachment(ref): + logger.warning( + "Checkpoint event references missing attachment: %s", ref + ) + + def _has_attachment(self, ref: str) -> bool: + row = self._conn.execute( + "SELECT 1 FROM attachments WHERE hash = ?", + (ref,), + ).fetchone() + return row is not None + + def merge_events( + self, events: Iterable[Event], attachments: Mapping[str, str] + ) -> None: + with self._lock: + for event in events: + self.merge_event(event, attachments.get) + + def merge_attachments(self, attachments: Mapping[str, str]) -> None: + if not attachments: + return + with self._lock: + with self._conn: + self._insert_attachments(attachments) + + def _insert_attachments(self, attachments: Mapping[str, str]) -> None: + if not attachments: + return + self._conn.executemany( + "INSERT OR IGNORE INTO attachments(hash, content) VALUES (?, ?)", + attachments.items(), + ) + + def merge_attachment_refs( + self, refs: Iterable[str], attachment_lookup: Callable[[str], str | None] + ) -> None: + with self._lock: + with self._conn: + self._merge_attachment_refs(refs, attachment_lookup) + + @staticmethod + def attachment_refs_from_json(json_text: str) -> set[str]: + return attachment_refs_from_value(json.loads(json_text)) + + def _reset_files(self) -> None: + for path in ( + self._path, + self._path.with_name(f"{self._path.name}-wal"), + self._path.with_name(f"{self._path.name}-shm"), + ): + path.unlink(missing_ok=True) + + def export_snapshot_files( + self, + sample_working_dir: str | Path, + *, + store_json: object, + agent_state: Mapping[str, object] | None, + ) -> None: + sample_dir = Path(sample_working_dir) + with self._lock: + self._write_text_atomic( + sample_dir / "store.json", to_json_str_safe(store_json) + ) + if agent_state is not None: + self._write_text_atomic( + sample_dir / "agent_state.json", to_json_str_safe(agent_state) + ) + with self._conn: + self._write_events_data(sample_dir / "events_data.json") + self._write_json_object_from_rows( + sample_dir / "attachments.json", + self._conn.execute( + "SELECT hash, content FROM attachments ORDER BY hash" + ), + ) + self._write_json_array( + sample_dir / "events.json", + self._conn.execute( + "SELECT latest_json FROM events ORDER BY first_seq" + ), + ) + + def _count_rows(self, table: str) -> int: + row = self._conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() + assert row is not None + return int(row[0]) + + def _next_event_seq(self) -> int: + row = self._conn.execute( + "SELECT COALESCE(MAX(first_seq), -1) + 1 FROM events" + ).fetchone() + assert row is not None + return int(row[0]) + + def _condense_event(self, logical_id: str, event: Event) -> Event: + message_cache = self._pending_message_pos_by_event.get(logical_id) + call_cache = self._pending_call_pos_by_event.get(logical_id) + if event.pending: + message_cache = message_cache if message_cache is not None else {} + call_cache = call_cache if call_cache is not None else {} + self._pending_message_pos_by_event[logical_id] = message_cache + self._pending_call_pos_by_event[logical_id] = call_cache + + event = condense_model_event_inputs_with_lookup( + event, lambda message: self._message_pos(message, message_cache) + ) + event = condense_model_event_calls_with_lookup( + event, lambda call_message: self._call_pos(call_message, call_cache) + ) + if not event.pending: + self._pending_message_pos_by_event.pop(logical_id, None) + self._pending_call_pos_by_event.pop(logical_id, None) + return event + + def _message_pos( + self, message: ChatMessage, cache: dict[int, int] | None = None + ) -> int: + message_id = id(message) + if cache is not None: + cached_pos = cache.get(message_id) + if cached_pos is not None: + return cached_pos + + message_jsonable = to_jsonable_python( + message, + exclude_none=True, + fallback=lambda _: None, + ) + pos = self._pool_pos( + "message_pool", + _msg_hash(message), + json.dumps(message_jsonable, sort_keys=True), + ) + if cache is not None: + cache[message_id] = pos + return pos + + def _call_pos( + self, call_message: JsonValue, cache: dict[int, int] | None = None + ) -> int: + call_message_id = id(call_message) + if cache is not None: + cached_pos = cache.get(call_message_id) + if cached_pos is not None: + return cached_pos + + call_json = json.dumps(call_message, sort_keys=True) + pos = self._pool_pos("call_pool", mm3_hash(call_json), call_json) + if cache is not None: + cache[call_message_id] = pos + return pos + + def _pool_pos(self, table: str, hash_value: str, json_text: str) -> int: + row = self._conn.execute( + f"SELECT pos FROM {table} WHERE hash = ?", + (hash_value,), + ).fetchone() + if row is not None: + return int(row[0]) + + pos_row = self._conn.execute( + f"SELECT COALESCE(MAX(pos), -1) + 1 FROM {table}" + ).fetchone() + assert pos_row is not None + pos = int(pos_row[0]) + self._conn.execute( + f"INSERT INTO {table}(pos, hash, json) VALUES (?, ?, ?)", + (pos, hash_value, json_text), + ) + return pos + + @staticmethod + def _attachment_refs(event: Event) -> set[str]: + return attachment_refs_from_value(event.model_dump(mode="python")) + + @staticmethod + def _write_text_atomic(path: Path, content: str) -> None: + CheckpointEventStore._write_atomic(path, lambda file: file.write(content)) + + @staticmethod + def _write_json_array(path: Path, rows: Iterable[tuple[str]]) -> None: + def write(file: TextIO) -> None: + CheckpointEventStore._write_json_array_to_file(file, rows) + file.write("\n") + + CheckpointEventStore._write_atomic(path, write) + + @staticmethod + def _write_json_object_from_rows( + path: Path, rows: Iterable[tuple[str, str]] + ) -> None: + def write(file: TextIO) -> None: + file.write("{") + first = True + for key, value in rows: + if not first: + file.write(",") + file.write(json.dumps(str(key))) + file.write(":") + file.write(json.dumps(value)) + first = False + file.write("}\n") + + CheckpointEventStore._write_atomic(path, write) + + def _write_events_data(self, path: Path) -> None: + def write(file: TextIO) -> None: + file.write('{"messages":') + self._write_json_array_to_file( + file, + self._conn.execute("SELECT json FROM message_pool ORDER BY pos"), + ) + file.write(',"calls":') + self._write_json_array_to_file( + file, + self._conn.execute("SELECT json FROM call_pool ORDER BY pos"), + ) + file.write("}\n") + + self._write_atomic(path, write) + + @staticmethod + def _write_atomic(path: Path, write: Callable[[TextIO], object]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_name = tempfile.mkstemp( + dir=path.parent, + prefix=f".{path.name}.", + suffix=".tmp", + text=True, + ) + tmp_path = Path(tmp_name) + try: + with open(fd, "w") as file: + write(file) + file.flush() + os.fsync(file.fileno()) + os.replace(tmp_path, path) + except Exception: + tmp_path.unlink(missing_ok=True) + raise + + @staticmethod + def _write_json_array_to_file(file: TextIO, rows: Iterable[tuple[str]]) -> None: + file.write("[") + first = True + for (json_text,) in rows: + if not first: + file.write(",") + file.write(str(json_text)) + first = False + file.write("]") + + @staticmethod + def _init_schema(conn: Connection) -> None: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS events ( + logical_id TEXT PRIMARY KEY, + first_seq INTEGER NOT NULL UNIQUE, + latest_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS message_pool ( + pos INTEGER PRIMARY KEY, + hash TEXT NOT NULL UNIQUE, + json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS call_pool ( + pos INTEGER PRIMARY KEY, + hash TEXT NOT NULL UNIQUE, + json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS attachments ( + hash TEXT PRIMARY KEY, + content TEXT NOT NULL + ); + """ + ) + conn.commit() diff --git a/src/inspect_ai/util/_checkpoint/_layout/host_context.py b/src/inspect_ai/util/_checkpoint/_layout/host_context.py index 32c1fb518e..7cd8b3bc20 100644 --- a/src/inspect_ai/util/_checkpoint/_layout/host_context.py +++ b/src/inspect_ai/util/_checkpoint/_layout/host_context.py @@ -24,13 +24,10 @@ from pathlib import Path from typing import Any -import anyio from pydantic import JsonValue -from pydantic_core import to_jsonable_python from inspect_ai.event._event import Event from inspect_ai.event._validate import validate_chat_messages, validate_events_json -from inspect_ai.log import EventsData from inspect_ai.model._chat_message import ChatMessage EVENTS = "events.json" @@ -55,18 +52,6 @@ class HostContext: ``None`` is returned when the file is absent.""" -async def write(working_dir: str, ctx: HostContext) -> None: - """Write all host-context files to ``working_dir``, overwriting in place.""" - sample_dir = anyio.Path(working_dir) - await (sample_dir / EVENTS).write_text(_json_dump(ctx.condensed_events)) - events_data = EventsData(messages=ctx.msg_pool, calls=ctx.call_pool) - await (sample_dir / EVENTS_DATA).write_text(_json_dump(events_data)) - await (sample_dir / ATTACHMENTS).write_text(_json_dump(ctx.attachments)) - await (sample_dir / STORE).write_text(_json_dump(ctx.store)) - if ctx.agent_state is not None: - await (sample_dir / AGENT_STATE).write_text(_json_dump(ctx.agent_state)) - - def read(working_dir: str) -> HostContext: """Read all host-context files from ``working_dir``. @@ -91,11 +76,3 @@ def read(working_dir: str) -> HostContext: store=store_data, agent_state=agent_state, ) - - -def _json_dump(obj: object) -> str: - """Serialize ``obj`` to JSON, excluding ``None`` fields, with a trailing newline.""" - return ( - json.dumps(to_jsonable_python(obj, exclude_none=True, fallback=lambda _: None)) - + "\n" - ) diff --git a/src/inspect_ai/util/_checkpoint/checkpointer.py b/src/inspect_ai/util/_checkpoint/checkpointer.py index 31041df511..276b7f0dea 100644 --- a/src/inspect_ai/util/_checkpoint/checkpointer.py +++ b/src/inspect_ai/util/_checkpoint/checkpointer.py @@ -134,6 +134,21 @@ class is unambiguous. ... +class CheckpointerSetup(Protocol): + """Per-sample setup object stored on ActiveSample. + + Enters to yield the agent-facing :class:`Checkpointer` and closes any + cached resources at sample teardown. ``close()`` is intentionally here, + not on ``Checkpointer``, so agents don't see lifecycle concerns. + """ + + async def __aenter__(self) -> Checkpointer: ... + + async def __aexit__(self, *exc: object) -> None: ... + + def close(self) -> None: ... + + @contextlib.asynccontextmanager async def checkpointer() -> AsyncIterator[Checkpointer]: """Enter the checkpointer bound to the active sample. diff --git a/src/inspect_ai/util/_checkpoint/checkpointer_factory.py b/src/inspect_ai/util/_checkpoint/checkpointer_factory.py index 4a0e60b94f..9ca31d48cd 100644 --- a/src/inspect_ai/util/_checkpoint/checkpointer_factory.py +++ b/src/inspect_ai/util/_checkpoint/checkpointer_factory.py @@ -1,8 +1,7 @@ import os -from contextlib import AbstractAsyncContextManager from inspect_ai._util.logger import warn_once -from inspect_ai.util._checkpoint.checkpointer import Checkpointer, ResumeCheckpoint +from inspect_ai.util._checkpoint.checkpointer import CheckpointerSetup, ResumeCheckpoint from inspect_ai.util._checkpoint.checkpointer_impl import _CheckpointerSetup, logger from inspect_ai.util._checkpoint.checkpointer_noop import _NoopCheckpointer from inspect_ai.util._checkpoint.config import ResolvedCheckpointConfig @@ -15,7 +14,7 @@ def create_checkpointer( sample_id: int | str, epoch: int, resume_checkpoint: ResumeCheckpoint | None = None, -) -> AbstractAsyncContextManager[Checkpointer]: +) -> CheckpointerSetup: """Build the per-sample checkpointer setup. Returns a :class:`_NoopCheckpointer` when ``config`` is ``None`` or diff --git a/src/inspect_ai/util/_checkpoint/checkpointer_impl.py b/src/inspect_ai/util/_checkpoint/checkpointer_impl.py index 84cbbcce05..531957537e 100644 --- a/src/inspect_ai/util/_checkpoint/checkpointer_impl.py +++ b/src/inspect_ai/util/_checkpoint/checkpointer_impl.py @@ -13,7 +13,13 @@ import contextlib import time -from collections.abc import AsyncIterator, Awaitable, Callable, Mapping, Sequence +from collections.abc import ( + AsyncIterator, + Awaitable, + Callable, + Mapping, + Sequence, +) from contextlib import AbstractAsyncContextManager from datetime import datetime, timezone from functools import partial @@ -22,26 +28,23 @@ from typing import Any, TypeVar import anyio -from pydantic import BaseModel, JsonValue, TypeAdapter +from pydantic import BaseModel, TypeAdapter from inspect_ai._util._async import tg_collect from inspect_ai.event._checkpoint import CheckpointEvent from inspect_ai.event._event import Event -from inspect_ai.log._pool import ( - _build_call_index, - _build_msg_index, - condense_model_event_calls, - condense_model_event_inputs, -) from inspect_ai.log._transcript import transcript -from inspect_ai.model._chat_message import ChatMessage from inspect_ai.solver._task_state import sample_state +from inspect_ai.util._checkpoint._event_store import ( + CHECKPOINT_EVENT_STORE, + CheckpointEventStore, +) from inspect_ai.util._restic import ResticBackupSummary, run_backup from inspect_ai.util._sandbox.context import sandbox from inspect_ai.util._span import span from inspect_ai.util._store import Store, store_jsonable -from ._layout import CheckpointDetails, SnapshotDetails, host_context, write_sidecar +from ._layout import CheckpointDetails, SnapshotDetails, write_sidecar from ._sandbox_restic import egress_sandbox, run_sandbox_backup from ._triggers import CheckpointTriggerKind, create_trigger from .checkpointer import ( @@ -86,6 +89,7 @@ def __init__( self._epoch = epoch self._resume_checkpoint = resume_checkpoint self._cached: _EnteredCheckpointer | None = None + self._reset_event_store_on_next_enter = True async def __aenter__(self) -> Checkpointer: if self._cached is not None: @@ -97,16 +101,24 @@ async def __aenter__(self) -> Checkpointer: epoch=self._epoch, resume_checkpoint=self._resume_checkpoint, ) + reset_event_store = self._reset_event_store_on_next_enter self._cached = _EnteredCheckpointer( config=self._config, hydration=result, resume_checkpoint=self._resume_checkpoint, + reset_event_store=reset_event_store, ) + self._reset_event_store_on_next_enter = False return self._cached async def __aexit__(self, *exc: object) -> None: return None + def close(self) -> None: + if self._cached is not None: + self._cached.close() + self._cached = None + class _EnteredCheckpointer: """Fully-formed agent-facing checkpointer. @@ -123,6 +135,7 @@ def __init__( config: ResolvedCheckpointConfig, hydration: HydrationResult, resume_checkpoint: ResumeCheckpoint | None, + reset_event_store: bool, ) -> None: self._config = config self._sample_checkpoints_dir = hydration.sample_checkpoints_dir @@ -146,30 +159,24 @@ def __init__( # work-between-fires window. Owned across `span_session()`'s # enter/exit and rotated inside `_fire()`. self._current_span_cm: AbstractAsyncContextManager[None] | None = None - # Persisted across fires: each fire processes only the new event slice - # and appends to these accumulators. Safe because checkpoints fire at - # turn boundaries, after which prior events are immutable. - # - # The accumulator + `_events_consumed` exist for performance — the - # next condense uses the prior pool as a starting point rather than - # re-walking the full transcript each fire. Revisit if profiling - # later shows the from-scratch alternative is fine at expected scale. - # - # `_events_consumed` is set lazily by the first `_open_next_span()` - # call to the transcript index where that first `span_begin: - # checkpoint` will land — so pre-first-span setup events (system - # message, sample init chatter) never enter the accumulator, and - # the persisted snapshot contains only checkpoint spans + contents. - # On resume, hydrate seeds the pools and pushes prior span content - # into the transcript; the lazy init then captures the index of the - # new `span_begin checkpoint M+1` so the next fire's slice is just - # the new span. - self._condensed_events: list[Event] = list(hydration.host.condensed_events) - self._msg_pool: list[ChatMessage] = list(hydration.host.msg_pool) - self._msg_index: dict[str, int] = _build_msg_index(self._msg_pool) - self._call_pool: list[JsonValue] = list(hydration.host.call_pool) - self._call_index: dict[str, int] = _build_call_index(self._call_pool) - self._events_consumed: int | None = None + self._event_store = CheckpointEventStore( + Path(self._sample_working_dir) / CHECKPOINT_EVENT_STORE, + reset=reset_event_store, + ) + self._transcript_subscription: Callable[[], None] | None = None + self._closed = False + self._transcript_seeded = False + self._seed_event_store(hydration) + self._ensure_transcript_subscription() + + def close(self) -> None: + if self._closed: + return + if self._transcript_subscription is not None: + self._transcript_subscription() + self._transcript_subscription = None + self._event_store.close() + self._closed = True @property def is_resuming(self) -> bool: @@ -201,11 +208,6 @@ async def _open_next_span(self) -> None: next_id = await anyio.to_thread.run_sync( _scan_next_checkpoint_id, self._sample_checkpoints_dir ) - # First-span lazy init for `_events_consumed`: capture the - # transcript index where the about-to-open `span_begin` will land - # so the persisted snapshot starts at the first checkpoint span. - if self._events_consumed is None: - self._events_consumed = len(transcript().events) cm = span(name=f"checkpoint {next_id}", type="checkpoint") await cm.__aenter__() self._current_span_cm = cm @@ -242,7 +244,7 @@ def track( ) self._on_checkpoint_callbacks[key] = callback if key in self._agent_state: - raw = self._agent_state[key] + raw = self._agent_state.pop(key) if value_type is not None: return TypeAdapter(value_type).validate_python(raw) if isinstance(initial_value, BaseModel): @@ -280,73 +282,76 @@ async def _fire(self, trigger: CheckpointTriggerKind) -> None: # ``SpanEndEvent`` lands in this checkpoint's ``events.json`` — # the persisted snapshot must show the span closing within it. await self._close_current_span() + try: + state = sample_state() + if not state: + raise RuntimeError("Checkpointer must find sample state") + ts = transcript() + await self._write_host_context( + self._sample_working_dir, + ts.resident_events, + ts.attachments, + state.store, + ) - state = sample_state() - if not state: - raise RuntimeError("Checkpointer must find sample state") - ts = transcript() - await self._write_host_context( - self._sample_working_dir, - ts.events, - ts.attachments, - state.store, - ) - - # Host + each sandbox (backup → egress) in parallel. The - # backup-then-egress pair for a given sandbox is sequential - # (egress diffs against what backup just wrote), but the pairs - # are independent across sandboxes and from the host backup. - # `tg_collect` takes thunks (zero-arg callables) so coroutines - # are only created at task-group start time. - sandbox_items = list((self._config.sandbox_paths or {}).items()) - backup_funcs: list[Callable[[], Awaitable[ResticBackupSummary]]] = [ - partial(self._backup_host, next_checkpoint_id), - *[ - partial( - self._backup_and_egress_sandbox, name, paths, next_checkpoint_id - ) - for name, paths in sandbox_items - ], - ] - summaries = await tg_collect(backup_funcs) - host_info = _snapshot_info(summaries[0]) - sandbox_infos = { - name: _snapshot_info(summary) - for (name, _), summary in zip(sandbox_items, summaries[1:]) - } - - # Cycle duration measured up to the sidecar write — the write - # itself is the commit point, so its cost lands on the next - # cycle's clock if anywhere. - duration_ms = int((time.monotonic() - cycle_start) * 1000) - - sidecar = CheckpointDetails( - checkpoint_id=next_checkpoint_id, - trigger=trigger, - turn=self._turn, - created_at=datetime.now(timezone.utc), - duration_ms=duration_ms, - size_bytes=host_info.size_bytes - + sum(s.size_bytes for s in sandbox_infos.values()), - host=host_info, - sandboxes=sandbox_infos, - ) - - await write_sidecar( - sample_checkpoints_dir=self._sample_checkpoints_dir, - sidecar=sidecar, - ) + # Host + each sandbox (backup → egress) in parallel. The + # backup-then-egress pair for a given sandbox is sequential + # (egress diffs against what backup just wrote), but the pairs + # are independent across sandboxes and from the host backup. + # `tg_collect` takes thunks (zero-arg callables) so coroutines + # are only created at task-group start time. + sandbox_items = list((self._config.sandbox_paths or {}).items()) + backup_funcs: list[Callable[[], Awaitable[ResticBackupSummary]]] = [ + partial(self._backup_host, next_checkpoint_id), + *[ + partial( + self._backup_and_egress_sandbox, + name, + paths, + next_checkpoint_id, + ) + for name, paths in sandbox_items + ], + ] + summaries = await tg_collect(backup_funcs) + host_info = _snapshot_info(summaries[0]) + sandbox_infos = { + name: _snapshot_info(summary) + for (name, _), summary in zip(sandbox_items, summaries[1:]) + } + + # Cycle duration measured up to the sidecar write — the write + # itself is the commit point, so its cost lands on the next + # cycle's clock if anywhere. + duration_ms = int((time.monotonic() - cycle_start) * 1000) + + sidecar = CheckpointDetails( + checkpoint_id=next_checkpoint_id, + trigger=trigger, + turn=self._turn, + created_at=datetime.now(timezone.utc), + duration_ms=duration_ms, + size_bytes=host_info.size_bytes + + sum(s.size_bytes for s in sandbox_infos.values()), + host=host_info, + sandboxes=sandbox_infos, + ) - # Emit the CheckpointEvent now that the sidecar is committed. - # By construction the event is NOT in this fire's events.json - # (already written above); it IS captured in the next fire's - # events.json. On resume, hydrate synthesizes the trailing - # event from the latest sidecar (working.md §8a). - transcript()._event(CheckpointEvent.from_details(sidecar)) + await write_sidecar( + sample_checkpoints_dir=self._sample_checkpoints_dir, + sidecar=sidecar, + ) - # Sidecar is committed; open the next `checkpoint N+1` span so - # subsequent agent events nest under it. - await self._open_next_span() + # Emit the CheckpointEvent now that the sidecar is committed. + # By construction the event is NOT in this fire's events.json + # (already written above); it IS captured in the next fire's + # events.json. On resume, hydrate synthesizes the trailing + # event from the latest sidecar (working.md §8a). + transcript()._event(CheckpointEvent.from_details(sidecar)) + finally: + # Reopen even if checkpointing fails after closing the prior span; + # subsequent agent events should stay nested under a checkpoint span. + await self._open_next_span() async def _write_host_context( self, @@ -355,75 +360,65 @@ async def _write_host_context( attachments: Mapping[str, str], store: Store, ) -> None: - """Write the host context across up to five files. - - - ``events.json`` — condensed events; ModelEvent inputs / calls - replaced with refs into the pools below. - - ``events_data.json`` — ``{messages, calls}`` dedup pools. - - ``attachments.json`` — hash → original-content pool that - ``ModelEvent.call`` refs (`attachment://`) point into. - Captured live by ``Transcript._process_event``; serialized - here so the snapshot is self-contained. - - ``store.json`` — Store key/value as a single JSON object. - - ``agent_state.json`` — agent-defined property bag, written - only when the agent registered at least one callback via - :meth:`Checkpointer.track`. Each registered key becomes a - top-level field in the dict. The agent's conversation - messages typically live here (e.g. under the ``"messages"`` - key) — the protocol no longer privileges them as a top-level - file. Presence on disk signals opt-in. - """ - # Pool ModelEvent input + call messages — the big O(N²) redundancy. - # We process only the new event slice each fire and append to the - # accumulators on the session, so total hashing work is O(N) over a - # sample rather than O(N) per fire. Safe because checkpoints fire at - # turn boundaries, after which prior events are immutable. - # Attachments come pre-extracted on the transcript (call payloads - # >100 chars are rewritten to attachment:// refs as events flow in, - # with originals in transcript.attachments) — we persist that pool - # here so resume can resolve the refs. - # `_events_consumed` is set lazily by the first `_open_next_span()`, - # which runs in `span_session().__aenter__()` before any fire can - # happen — so it's guaranteed non-None here. - assert self._events_consumed is not None - # Filter the new slice: persisted events.json contains only events - # inside checkpoint / prior_run spans (inclusive of begin/end) plus - # CheckpointEvents. Stray events that land between checkpoint spans - # (e.g. `sandbox:exec` / `sandbox:read_file` emitted by restic - # operations during the fire's backup phase) stay in the live - # transcript but don't get persisted. See working.md §5. - new = _filter_persisted_events(events[self._events_consumed :]) - if new: - cond, self._msg_index, new_msgs = condense_model_event_inputs( - new, len(self._msg_pool), self._msg_index - ) - self._msg_pool.extend(m for _, m in new_msgs) - cond, self._call_index, new_calls = condense_model_event_calls( - cond, len(self._call_pool), self._call_index - ) - self._call_pool.extend(c for _, c in new_calls) - self._condensed_events.extend(cond) - # Advance regardless of whether the filtered slice was empty: - # this fire's events have been consumed from the live transcript's - # perspective even if none made it into the persisted snapshot. - self._events_consumed = len(events) + """Write the host context snapshot files.""" agent_state = ( {key: cb() for key, cb in self._on_checkpoint_callbacks.items()} if self._on_checkpoint_callbacks else None ) - await host_context.write( + self._event_store.export_snapshot_files( sample_working_dir, - host_context.HostContext( - condensed_events=self._condensed_events, - msg_pool=self._msg_pool, - call_pool=self._call_pool, - attachments=dict(attachments), - store=store_jsonable(store), - agent_state=agent_state, - ), + store_json=store_jsonable(store), + agent_state=agent_state, + ) + + def _seed_event_store(self, hydration: HydrationResult) -> None: + if self._transcript_seeded: + return + ts = transcript() + try: + attachments = ts.attachments + self._event_store.merge_message_pool(hydration.host.msg_pool) + self._event_store.merge_call_pool(hydration.host.call_pool) + seeded_event_ids: set[str] = set() + if hydration.host.condensed_events: + self._event_store.merge_events( + hydration.host.condensed_events, attachments + ) + seeded_event_ids = { + event.uuid + for event in hydration.host.condensed_events + if event.uuid is not None + } + if ts.resident_events_truncated: + history_provider = ts._history_provider + if history_provider is None: + raise RuntimeError( + "Cannot seed checkpoint events from a truncated Transcript. " + "Create the checkpointer before bounded transcript eviction starts." + ) + history_provider.import_checkpoint_events(self._event_store) + else: + for event in ts.resident_events: + if event.uuid in seeded_event_ids: + continue + self._event_store.merge_event(event, attachments.get) + self._event_store.merge_attachments(attachments) + self._transcript_seeded = True + except Exception: + self.close() + raise + + def _ensure_transcript_subscription(self) -> None: + if self._transcript_subscription is not None: + return + self._transcript_subscription = transcript().subscribe( + self._track_transcript_event ) + def _track_transcript_event(self, event: Event) -> None: + self._event_store.merge_event(event, transcript().attachments.get) + async def _backup_host(self, checkpoint_id: int) -> ResticBackupSummary: return await run_backup( self._host_restic, @@ -481,55 +476,3 @@ def _snapshot_info(summary: ResticBackupSummary) -> SnapshotDetails: size_bytes=summary.data_added_packed, duration_ms=int(summary.total_duration * 1000), ) - - -def _filter_persisted_events(events: Sequence[Event]) -> list[Event]: - r"""Keep only events that belong in the persisted ``events.json``. - - Three things pass through: - - - Events inside a ``type="checkpoint"`` span (inclusive of the - span_begin/span_end). - - Events inside a ``type="prior_run"`` span (inclusive). - - ``CheckpointEvent``\ s, even when between checkpoint spans. - - Everything else is dropped — in practice the ``sandbox:exec`` and - ``sandbox:read_file`` events emitted by restic operations during - the fire's backup phase, which land in the live transcript between - ``span_end checkpoint N`` and ``span_begin checkpoint N+1``. They - stay in the live transcript (visible in ``inspect view`` of the - running eval) but don't get persisted. - """ - result: list[Event] = [] - # Track currently-open tracked-span ids. Depth = len(tracked_open_ids). - # We track checkpoint + prior_run spans; everything inside (including - # nested non-tracked spans like bash/tool) is kept. - tracked_open_ids: set[str] = set() - for e in events: - if e.event == "span_begin": - type_ = getattr(e, "type", None) - if type_ in ("checkpoint", "prior_run"): - tracked_open_ids.add(e.id) - result.append(e) - elif tracked_open_ids: - # Nested non-tracked span inside a tracked one — keep. - result.append(e) - # else: stray span_begin outside any tracked span — drop. - elif e.event == "span_end": - id_ = getattr(e, "id", None) - if id_ in tracked_open_ids: - tracked_open_ids.discard(id_) - result.append(e) - elif tracked_open_ids: - # Inner span_end inside a tracked span — keep. - result.append(e) - # else: stray span_end outside any tracked span — drop. - elif e.event == "checkpoint": - # CheckpointEvent — always keep, whether inside a tracked - # span or not (in practice it lands between checkpoint spans). - result.append(e) - elif tracked_open_ids: - # Inside a tracked span — keep. - result.append(e) - # else: stray event outside any tracked span — drop. - return result diff --git a/src/inspect_ai/util/_checkpoint/checkpointer_noop.py b/src/inspect_ai/util/_checkpoint/checkpointer_noop.py index 53cab24c02..b89174011f 100644 --- a/src/inspect_ai/util/_checkpoint/checkpointer_noop.py +++ b/src/inspect_ai/util/_checkpoint/checkpointer_noop.py @@ -32,6 +32,9 @@ async def checkpoint(self) -> None: def span_session(self) -> contextlib.AbstractAsyncContextManager[None]: return contextlib.nullcontext() + def close(self) -> None: + return None + def track( self, key: str, diff --git a/src/inspect_ai/util/_checkpoint/hydrate.py b/src/inspect_ai/util/_checkpoint/hydrate.py index cf1aef4149..b239c1bdb8 100644 --- a/src/inspect_ai/util/_checkpoint/hydrate.py +++ b/src/inspect_ai/util/_checkpoint/hydrate.py @@ -46,6 +46,7 @@ from inspect_ai._util.file import file, local_path from inspect_ai.event._checkpoint import CheckpointEvent from inspect_ai.event._event import Event +from inspect_ai.event._pool import materialize_pooled_events from inspect_ai.event._span import SpanBeginEvent, SpanEndEvent from inspect_ai.log._transcript import transcript from inspect_ai.model._chat_message import ChatMessage @@ -103,6 +104,12 @@ class _HostHydrationResult: call_pool: list[JsonValue] = field(default_factory=list) """From ``events_data.json[calls]`` — seeds the dedup pool.""" + attachments: dict[str, str] = field(default_factory=dict) + """From ``attachments.json`` — pushed into the live transcript on resume.""" + + store: dict[str, Any] = field(default_factory=dict) + """From ``store.json`` — pushed into the live sample state on resume.""" + @dataclass class HydrationResult: @@ -260,12 +267,13 @@ async def _hydrate_host( sample_checkpoints_dir = str(Path(host_repo).parent) result = await anyio.to_thread.run_sync( partial( - _load_and_push_host_state, + _load_host_state, sample_working_dir, sample_checkpoints_dir, latest_committed_id, ) ) + result = _push_host_state(result, sample_checkpoints_dir, latest_committed_id) _debug( f"[hydrate.host] resume done: " f"events={len(result.condensed_events)} " @@ -441,26 +449,12 @@ def _fs_copy_repo( _debug(f"[hydrate.copy] {label} repo: {src} -> {new_repo} ({file_count} files)") -def _load_and_push_host_state( +def _load_host_state( sample_working_dir: str, sample_checkpoints_dir: str, latest_committed_id: int | None, ) -> _HostHydrationResult: - """Read the restored host context and push framework state. - - Loads via :mod:`host_context`, pushes events/attachments/store into - the live ``Transcript`` and ``Store`` (so the agent's continued run - sees the cumulative history), and returns the parts the agent-facing - :class:`_EnteredCheckpointer` needs at construction: - - - ``agent_state`` — for ``track()`` to return persisted values. - - ``condensed_events``, ``msg_pool``, ``call_pool`` — seeds for the - checkpointer's pools so the next fire writes a cumulative snapshot. - - Validates the loaded events against expected resume invariants - (checkpoint span structure, sidecar parity) so a broken resume - fails loudly here rather than cascading into the agent loop. - """ + """Read restored host context and prepare it for loop-thread hydration.""" ctx = host_context.read(local_path(sample_working_dir)) _debug( @@ -489,6 +483,10 @@ def _load_and_push_host_state( f"checkpoint {latest_committed_id}" ) + rehydrated_events = materialize_pooled_events( + rehydrated_events, ctx.msg_pool, ctx.call_pool + ) + # Wrap the most-recent prior session's unwrapped checkpoint spans in # a new `prior_run` span before pushing — every prior # session ends up as a sibling wrap inside this attempt's events. @@ -496,25 +494,41 @@ def _load_and_push_host_state( # mechanics. pushed_events = _wrap_prior_run(rehydrated_events) - # Push framework-owned state into the live transcript + store so the - # agent's continued run appends to (rather than replaces) the prior - # history. Direct mutation of the internal lists bypasses - # ``_process_event`` — the events are already in their condensed, - # attachment-ref form and must not be reprocessed. The persisted - # `condensed_events` is already span-only (the writer's accumulator - # only ever captured events from the first `span_begin: checkpoint` - # onward), and the prior-session wrap we just added is balanced. + return _HostHydrationResult( + agent_state=ctx.agent_state, + condensed_events=pushed_events, + msg_pool=ctx.msg_pool, + call_pool=ctx.call_pool, + attachments=ctx.attachments, + store=ctx.store, + ) + + +def _push_host_state( + result: _HostHydrationResult, + sample_checkpoints_dir: str, + latest_committed_id: int | None, +) -> _HostHydrationResult: + """Push loaded host state into loop-owned Transcript and Store.""" ts = transcript() pre = [_event_label(e) for e in ts._events] - restored = [_event_label(e) for e in pushed_events] + restored = [_event_label(e) for e in result.condensed_events] _debug(f"[hydrate.host] pre-hydration transcript.events (n={len(pre)}): {pre}") _debug(f"[hydrate.host] restored events to push (n={len(restored)}): {restored}") - ts._events.extend(pushed_events) - ts._attachments.update(ctx.attachments) + result.condensed_events = materialize_pooled_events( + result.condensed_events, + result.msg_pool, + result.call_pool, + ) + ts._extend_restored_events( + result.condensed_events, + result.attachments, + notify_subscribers=True, + ) state = sample_state() if state is None: raise RuntimeError("_hydrate_host: no active sample state to populate Store") - for key, value in ctx.store.items(): + for key, value in result.store.items(): state.store.set(key, value) _debug( f"[hydrate.host] pushed: transcript.events={len(ts._events)} " @@ -522,14 +536,10 @@ def _load_and_push_host_state( f"store_keys={len(list(state.store.keys()))}" ) - _validate_resume_state(pushed_events, sample_checkpoints_dir, latest_committed_id) - - return _HostHydrationResult( - agent_state=ctx.agent_state, - condensed_events=pushed_events, - msg_pool=ctx.msg_pool, - call_pool=ctx.call_pool, + _validate_resume_state( + result.condensed_events, sample_checkpoints_dir, latest_committed_id ) + return result def _wrap_prior_run(events: list[Event]) -> list[Event]: diff --git a/tests/_eval/test_retry_error_events.py b/tests/_eval/test_retry_error_events.py new file mode 100644 index 0000000000..839e8e9521 --- /dev/null +++ b/tests/_eval/test_retry_error_events.py @@ -0,0 +1,275 @@ +from datetime import datetime, timezone +from types import TracebackType +from typing import Iterator + +import pytest +from pydantic import JsonValue + +from inspect_ai._eval.task.log import TaskLogger +from inspect_ai._eval.task.run import _eval_retry_error, _sample_transcript_config +from inspect_ai.event import Event, InfoEvent, ModelEvent +from inspect_ai.log import EvalError, Transcript +from inspect_ai.log._recorders.buffer.database import SampleBufferDatabase +from inspect_ai.log._recorders.buffer.history_provider import ( + BufferTranscriptHistoryProvider, +) +from inspect_ai.log._recorders.streaming import eval_retry_error_from_history +from inspect_ai.log._recorders.types import SampleEvent +from inspect_ai.log._transcript import init_transcript +from inspect_ai.model import ChatMessageUser, GenerateConfig, ModelOutput + + +def _model(uuid: str, content: str) -> ModelEvent: + output = ModelOutput.from_content("mockllm/model", content) + return ModelEvent( + uuid=uuid, + timestamp=datetime(2026, 5, 18, tzinfo=timezone.utc), + working_start=0.0, + model="mockllm/model", + input=[ChatMessageUser(content="question")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=output, + ) + + +def _error() -> EvalError: + return EvalError(message="boom", traceback="traceback", traceback_ansi="ansi") + + +def test_transcript_events_since_last_returns_suffix_from_latest_type() -> None: + first_info = InfoEvent(uuid="info-1", data={"note": "before"}) + first_model = _model("model-1", "first") + second_info = InfoEvent(uuid="info-2", data={"note": "middle"}) + second_model = _model("model-2", "second") + tail_info = InfoEvent(uuid="info-3", data={"note": "after"}) + transcript = Transcript( + [first_info, first_model, second_info, second_model, tail_info] + ) + + assert transcript.events_since_last(ModelEvent) == [second_model, tail_info] + + +def test_eval_retry_error_uses_latest_resident_model_event_suffix() -> None: + first_model = _model("model-1", "first") + middle_info = InfoEvent(uuid="info-1", data={"note": "middle"}) + second_model = _model("model-2", "second") + tail_info = InfoEvent(uuid="info-2", data={"note": "after"}) + init_transcript(Transcript([first_model, middle_info, second_model, tail_info])) + + retry = _eval_retry_error(_error()) + + assert retry.events == [second_model, tail_info] + + +def test_eval_retry_error_does_not_claim_evicted_bounded_history() -> None: + first_model = _model("model-1", "first") + middle_info = InfoEvent(uuid="info-1", data={"note": "middle"}) + tail_info = InfoEvent(uuid="info-2", data={"note": "after"}) + bounded = Transcript(bounded=True, resident_tail=2) + init_transcript(bounded) + bounded._event(first_model) + bounded._event(middle_info) + bounded._event(tail_info) + + retry = _eval_retry_error(_error()) + + assert bounded.resident_events_truncated is True + assert retry.events == [] + + +def test_eval_retry_error_uses_buffer_history_when_transcript_truncated( + tmp_path, +) -> None: + first_model = _model("model-1", "first") + middle_info = InfoEvent(uuid="info-1", data={"note": "middle"}) + second_model = _model("model-2", "second") + tail_info = InfoEvent(uuid="info-2", data={"note": "after"}) + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.log_events( + [ + SampleEvent(id="sample", epoch=1, event=first_model), + SampleEvent(id="sample", epoch=1, event=middle_info), + SampleEvent(id="sample", epoch=1, event=second_model), + SampleEvent(id="sample", epoch=1, event=tail_info), + ] + ) + + bounded = Transcript(bounded=True, resident_tail=1) + init_transcript(bounded) + bounded._event(first_model) + bounded._event(middle_info) + bounded._event(second_model) + bounded._event(tail_info) + assert bounded.resident_events_truncated is True + + with db.open_sample_history("sample", 1) as history: + retry = eval_retry_error_from_history(_error(), history) + + assert retry.events is not None + assert [event.uuid for event in retry.events] == ["model-2", "info-2"] + assert isinstance(retry.events[0], ModelEvent) + assert len(retry.events[0].input) == 1 + assert isinstance(retry.events[0].input[0], ChatMessageUser) + assert retry.events[0].input[0].content == "question" + assert retry.events[0].input_refs is None + + +def test_bounded_transcript_events_since_last_uses_buffer_provider(tmp_path) -> None: + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + first_model = ModelEvent( + uuid="model-1", + model="mockllm/model", + input=[ChatMessageUser(content="first")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", "first"), + ) + middle = InfoEvent(uuid="info-1", data="middle") + second_model = ModelEvent( + uuid="model-2", + model="mockllm/model", + input=[ChatMessageUser(content="second")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", "second"), + ) + tail = InfoEvent(uuid="info-2", data="tail") + events: list[Event] = [first_model, middle, second_model, tail] + db.log_events([SampleEvent(id="sample", epoch=0, event=event) for event in events]) + + provider = BufferTranscriptHistoryProvider(db, "sample", 0) + transcript = Transcript(bounded=True, resident_tail=1, history_provider=provider) + for event in events: + transcript._event(event) + + assert transcript.resident_events_truncated is True + assert transcript.events_since_last(ModelEvent) == [second_model, tail] + + +def test_buffer_provider_iter_events_streams_first_event_before_later_rows() -> None: + first = InfoEvent(uuid="info-1", data="first") + + class RaisingHistory: + events_data: dict[str, list[JsonValue]] = {"messages": [], "calls": []} + + def iter_events(self) -> Iterator[Event]: + yield first + raise AssertionError("iter_events should not materialize later rows first") + + class HistoryContext: + def __enter__(self) -> RaisingHistory: + return RaisingHistory() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + return None + + class FakeBufferDb: + def open_sample_history(self, sample_id: str, epoch: int) -> HistoryContext: + assert sample_id == "sample" + assert epoch == 0 + return HistoryContext() + + def sample_event_count(self, sample_id: str, epoch: int) -> int: + return 1 + + provider = BufferTranscriptHistoryProvider(FakeBufferDb(), "sample", 0) # type: ignore[arg-type] + events = provider.iter_events() + + assert next(events) == first + with pytest.raises(AssertionError, match="later rows"): + next(events) + + +def test_eval_retry_error_uses_provider_when_transcript_resident_tail_truncated( + tmp_path, +) -> None: + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + first_model = _model("model-1", "first") + middle = InfoEvent(uuid="info-1", data="middle") + second_model = _model("model-2", "second") + tail = InfoEvent(uuid="info-2", data="tail") + events: list[Event] = [first_model, middle, second_model, tail] + db.log_events([SampleEvent(id="sample", epoch=0, event=event) for event in events]) + + provider = BufferTranscriptHistoryProvider(db, "sample", 0) + transcript = Transcript(bounded=True, resident_tail=1, history_provider=provider) + init_transcript(transcript) + for event in events: + transcript._event(event) + + retry = _eval_retry_error(_error()) + + assert transcript.resident_events_truncated is True + assert transcript.full_history_available is True + assert retry.events is not None + assert [event.uuid for event in retry.events] == ["model-2", "info-2"] + + +def test_eval_retry_error_preserves_epoch_zero(tmp_path) -> None: + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.log_events( + [ + SampleEvent(id="sample", epoch=0, event=_model("model-0", "zero")), + SampleEvent(id="sample", epoch=1, event=_model("model-1", "one")), + ] + ) + logger = _TaskLoggerShim(db) + + retry = _eval_retry_error(_error(), logger, "sample", 0) + + assert retry.events is not None + assert [event.uuid for event in retry.events] == ["model-0"] + + +def test_eval_retry_error_requires_epoch_for_buffer_history(tmp_path) -> None: + logger = _TaskLoggerShim( + SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + ) + + with pytest.raises( + ValueError, match="epoch is required when reading retry events from buffer DB" + ): + _eval_retry_error(_error(), logger, "sample") + + +def test_sample_transcript_config_requires_buffer_for_bounded_mode( + monkeypatch, +) -> None: + monkeypatch.setenv("INSPECT_TRANSCRIPT_BOUNDED", "true") + + bounded, history_provider = _sample_transcript_config( + logger=_TaskLoggerShim(buffer_db=None), sample_id="sample", epoch=0 + ) + + assert bounded is False + assert history_provider is None + + +def test_sample_transcript_config_defaults_to_unbounded_with_buffer( + monkeypatch, tmp_path +) -> None: + monkeypatch.delenv("INSPECT_TRANSCRIPT_BOUNDED", raising=False) + logger = _TaskLoggerShim( + SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + ) + + bounded, history_provider = _sample_transcript_config( + logger=logger, sample_id="sample", epoch=0 + ) + + assert bounded is False + assert history_provider is not None + + +class _TaskLoggerShim(TaskLogger): + def __init__(self, buffer_db: SampleBufferDatabase | None) -> None: + self._buffer_db = buffer_db diff --git a/tests/checkpoint/test_checkpoint_event_store.py b/tests/checkpoint/test_checkpoint_event_store.py new file mode 100644 index 0000000000..989e9e87e3 --- /dev/null +++ b/tests/checkpoint/test_checkpoint_event_store.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +import json +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import cast + +import pytest +from pydantic import JsonValue + +from inspect_ai._util.constants import get_deserializing_context +from inspect_ai.event._info import InfoEvent +from inspect_ai.event._model import ModelEvent +from inspect_ai.log import expand_events +from inspect_ai.model._chat_message import ( + ChatMessage, + ChatMessageAssistant, + ChatMessageSystem, + ChatMessageUser, +) +from inspect_ai.model._generate_config import GenerateConfig +from inspect_ai.model._model_call import ModelCall +from inspect_ai.model._model_output import ModelOutput +from inspect_ai.util._checkpoint._event_store import ( + CHECKPOINT_EVENT_STORE, + CheckpointEventStore, +) + + +def _exported_events(work_dir: Path) -> list[dict[str, object]]: + return json.loads((work_dir / "events.json").read_text()) + + +def _no_attachment(_: str) -> str | None: + return None + + +def _model_event( + messages: list[ChatMessage], call: ModelCall | None = None +) -> ModelEvent: + return ModelEvent( + model="test", + input=messages, + tools=[], + tool_choice="auto", + config=GenerateConfig(), + output=ModelOutput(), + call=call, + ) + + +def test_checkpoint_event_store_initializes_schema(tmp_path: Path) -> None: + store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + + counts = store.counts() + + assert counts.events == 0 + assert counts.message_pool == 0 + assert counts.call_pool == 0 + assert counts.attachments == 0 + assert (tmp_path / CHECKPOINT_EVENT_STORE).exists() + + +def test_checkpoint_event_store_exports_events_in_first_seen_order( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + first = InfoEvent(data="first") + second = InfoEvent(data="second") + + first_id = first.uuid + second_id = second.uuid + + event_store.merge_event(first, attachment_lookup=_no_attachment) + event_store.merge_event(second, attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + events = _exported_events(tmp_path) + assert [event["data"] for event in events] == ["first", "second"] + assert [event["uuid"] for event in events] == [first_id, second_id] + + +def test_checkpoint_event_store_updates_existing_logical_event( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event = InfoEvent(data="first") + + event_id = event.uuid + event_store.merge_event(event, attachment_lookup=_no_attachment) + event.data = "updated" + event_store.merge_event(event, attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + events = _exported_events(tmp_path) + assert len(events) == 1 + assert events[0]["uuid"] == event_id + assert events[0]["data"] == "updated" + assert event_store.counts().events == 1 + + +def test_checkpoint_event_store_accepts_cross_thread_events(tmp_path: Path) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + input_events = [InfoEvent(data=f"from-thread-{index}") for index in range(8)] + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(event_store.merge_event, event, _no_attachment) + for event in input_events + ] + for future in futures: + future.result() + + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + exported_events = _exported_events(tmp_path) + assert {event["data"] for event in exported_events} == { + f"from-thread-{index}" for index in range(8) + } + + +def test_checkpoint_event_store_assigns_uuid_to_uuidless_events(tmp_path: Path) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event = InfoEvent.model_validate( + {"event": "info", "data": "uuidless"}, context=get_deserializing_context() + ) + assert event.uuid is None + + event_store.merge_event(event, attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + assert isinstance(event.uuid, str) + events = _exported_events(tmp_path) + assert len(events) == 1 + assert events[0]["uuid"] == event.uuid + + +def test_checkpoint_event_store_assigns_uuid_for_uuidless_event_updates( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event = InfoEvent.model_validate( + {"event": "info", "data": "pending", "pending": True}, + context=get_deserializing_context(), + ) + assert event.uuid is None + + event_store.merge_event(event, attachment_lookup=_no_attachment) + event.pending = False + event.data = "done" + event_store.merge_event(event, attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + events = _exported_events(tmp_path) + assert len(events) == 1 + assert events[0]["uuid"] == event.uuid + assert events[0]["data"] == "done" + + +def test_checkpoint_event_store_serializes_agent_state_models( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + + event_store.export_snapshot_files( + tmp_path, + store_json={"store_message": ChatMessageSystem(content="store")}, + agent_state={ + "messages": [ + ChatMessageSystem(content="sys"), + ChatMessageUser(content="user"), + ] + }, + ) + + store_json = json.loads((tmp_path / "store.json").read_text()) + agent_state = json.loads((tmp_path / "agent_state.json").read_text()) + + assert store_json["store_message"]["role"] == "system" + assert store_json["store_message"]["content"] == "store" + assert agent_state["messages"][0]["role"] == "system" + assert agent_state["messages"][0]["content"] == "sys" + assert agent_state["messages"][1]["role"] == "user" + assert agent_state["messages"][1]["content"] == "user" + + +def test_checkpoint_event_store_writes_store_and_agent_state( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event_store.export_snapshot_files( + tmp_path, + store_json={"x": 1}, + agent_state={"agent": {"step": 2}}, + ) + + assert json.loads((tmp_path / "store.json").read_text()) == {"x": 1} + assert json.loads((tmp_path / "agent_state.json").read_text()) == { + "agent": {"step": 2} + } + + +def test_checkpoint_event_store_exports_only_referenced_attachments( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event_store.merge_event( + InfoEvent(data={"blob": "attachment://kept"}), + attachment_lookup={"kept": "payload", "unused": "ignore me"}.get, + ) + + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + assert json.loads((tmp_path / "attachments.json").read_text()) == { + "kept": "payload" + } + assert not (tmp_path / "agent_state.json").exists() + + +def test_checkpoint_event_store_warns_for_missing_attachment_ref( + tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + + with caplog.at_level( + logging.WARNING, logger="inspect_ai.util._checkpoint._event_store" + ): + event_store.merge_event( + InfoEvent(data={"blob": "attachment://missing"}), + attachment_lookup=_no_attachment, + ) + + assert "Checkpoint event references missing attachment: missing" in caplog.text + + +def test_checkpoint_event_store_attachment_refs_follow_condense_protocol() -> None: + from inspect_ai.log._condense import ATTACHMENT_PROTOCOL + + refs = CheckpointEventStore.attachment_refs_from_json( + json.dumps({"blob": f"{ATTACHMENT_PROTOCOL}kept"}) + ) + + assert refs == {"kept"} + + +def test_checkpoint_event_store_retains_cumulative_attachments( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + first_event = InfoEvent(data={"blob": "attachment://abc"}) + second_event = InfoEvent(data={"blob": "attachment://def"}) + + event_store.merge_event(first_event, attachment_lookup={"abc": "payload"}.get) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + assert json.loads((tmp_path / "attachments.json").read_text()) == {"abc": "payload"} + + event_store.merge_event(second_event, attachment_lookup={"def": "payload2"}.get) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + assert json.loads((tmp_path / "attachments.json").read_text()) == { + "abc": "payload", + "def": "payload2", + } + + +def test_checkpoint_event_store_retains_attachment_on_event_update( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event = InfoEvent(data={"blob": "attachment://abc"}) + + event_store.merge_event(event, attachment_lookup={"abc": "payload"}.get) + event.data = {"blob": "attachment://abc", "status": "done"} + event_store.merge_event(event, attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + assert json.loads((tmp_path / "attachments.json").read_text()) == {"abc": "payload"} + + +def test_checkpoint_event_store_attaches_raw_model_call_update( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + event = _model_event( + [ChatMessageUser(content="question")], + call=ModelCall.create( + {"messages": [{"role": "user", "content": "short"}]}, None + ), + ) + event_store.merge_event(event, attachment_lookup=_no_attachment) + + raw_payload = "late payload" * 100 + event.call = ModelCall.create( + {"messages": [{"role": "user", "content": raw_payload}]}, None + ) + event_store.merge_event(event, attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + attachments = json.loads((tmp_path / "attachments.json").read_text()) + + assert raw_payload in attachments.values() + events_data = json.loads((tmp_path / "events_data.json").read_text()) + call = events_data["calls"][-1] + assert isinstance(call, dict) + content = call["content"] + assert isinstance(content, str) + assert content.startswith("attachment://") + + +def test_checkpoint_event_store_attaches_raw_model_input( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + raw_payload = "input payload" * 100 + + event_store.merge_event( + _model_event([ChatMessageUser(content=raw_payload)]), + attachment_lookup=_no_attachment, + ) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + attachments = json.loads((tmp_path / "attachments.json").read_text()) + events_data = json.loads((tmp_path / "events_data.json").read_text()) + + assert raw_payload in attachments.values() + content = events_data["messages"][0]["content"] + assert isinstance(content, str) + assert content.startswith("attachment://") + + +def test_checkpoint_event_store_exports_stable_message_pool(tmp_path: Path) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + sys = ChatMessageSystem(content="sys") + user = ChatMessageUser(content="question") + assistant = ChatMessageAssistant(content="answer") + + event_store.merge_event(_model_event([sys, user]), attachment_lookup=_no_attachment) + event_store.merge_event( + _model_event([sys, user, assistant]), attachment_lookup=_no_attachment + ) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + events_data = json.loads((tmp_path / "events_data.json").read_text()) + events = json.loads((tmp_path / "events.json").read_text()) + + assert len(events_data["messages"]) == 3 + assert events[0]["input_refs"] == [[0, 2]] + assert events[1]["input_refs"] == [[0, 3]] + + expanded = expand_events( + (tmp_path / "events.json").read_text(), + (tmp_path / "events_data.json").read_text(), + ) + model_events = [event for event in expanded if isinstance(event, ModelEvent)] + assert [len(event.input) for event in model_events] == [2, 3] + + +def test_checkpoint_event_store_import_pool_entry_returns_canonical_position( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + + first_pos = event_store.merge_message_pool_entry( + "first-hash", ChatMessageUser(content="first").model_dump_json() + ) + duplicate_pos = event_store.merge_message_pool_entry( + "first-hash", ChatMessageUser(content="first").model_dump_json() + ) + second_pos = event_store.merge_message_pool_entry( + "second-hash", ChatMessageUser(content="second").model_dump_json() + ) + + assert first_pos == 0 + assert duplicate_pos == first_pos + assert second_pos == 1 + + +def test_checkpoint_event_store_exports_stable_call_pool(tmp_path: Path) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + request_message = cast( + dict[str, JsonValue], + {"role": "user", "content": "question"}, + ) + + event_store.merge_event( + _model_event( + [], + call=ModelCall( + request={"model": "test", "messages": [request_message]}, + response=None, + ), + ), + attachment_lookup=_no_attachment, + ) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + events_data = json.loads((tmp_path / "events_data.json").read_text()) + events = json.loads((tmp_path / "events.json").read_text()) + + assert events_data["calls"] == [request_message] + assert events[0]["call"]["call_refs"] == [[0, 1]] + assert "messages" not in events[0]["call"]["request"] + + expanded = expand_events( + (tmp_path / "events.json").read_text(), + (tmp_path / "events_data.json").read_text(), + ) + model_events = [event for event in expanded if isinstance(event, ModelEvent)] + assert model_events[0].call is not None + assert model_events[0].call.request["messages"] == [request_message] + + +def test_checkpoint_event_store_reuses_pending_message_positions_on_update( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + message = ChatMessageUser(content="question") + event = _model_event([message]) + event.pending = True + + event_store.merge_event(event, attachment_lookup=_no_attachment) + event.pending = False + event_store.merge_event(event, attachment_lookup=_no_attachment) + + assert event_store.counts().message_pool == 1 + + +def test_checkpoint_event_store_deduplicates_messages_using_pool_hash( + tmp_path: Path, +) -> None: + event_store = CheckpointEventStore(tmp_path / CHECKPOINT_EVENT_STORE) + first = ChatMessageUser(content="same") + second = ChatMessageUser(content="same") + assert first.id != second.id + + event_store.merge_event(_model_event([first]), attachment_lookup=_no_attachment) + event_store.merge_event(_model_event([second]), attachment_lookup=_no_attachment) + event_store.export_snapshot_files(tmp_path, store_json={}, agent_state=None) + + events_data = json.loads((tmp_path / "events_data.json").read_text()) + events = json.loads((tmp_path / "events.json").read_text()) + + assert len(events_data["messages"]) == 1 + assert events[0]["input_refs"] == [[0, 1]] + assert events[1]["input_refs"] == [[0, 1]] diff --git a/tests/checkpoint/test_checkpointer.py b/tests/checkpoint/test_checkpointer.py index 157e71ff51..9777c1e221 100644 --- a/tests/checkpoint/test_checkpointer.py +++ b/tests/checkpoint/test_checkpointer.py @@ -11,7 +11,7 @@ import json import os -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import contextmanager from dataclasses import dataclass, field from datetime import timedelta @@ -19,11 +19,13 @@ from unittest.mock import patch import pytest +from test_helpers.transcript import FakeTranscriptHistoryProvider from inspect_ai.event._event import Event +from inspect_ai.event._info import InfoEvent from inspect_ai.event._model import ModelEvent -from inspect_ai.event._span import SpanBeginEvent, SpanEndEvent -from inspect_ai.log import expand_events +from inspect_ai.log import Transcript, expand_events +from inspect_ai.log._transcript import init_transcript, transcript from inspect_ai.model._chat_message import ( ChatMessage, ChatMessageAssistant, @@ -31,6 +33,7 @@ ChatMessageUser, ) from inspect_ai.model._generate_config import GenerateConfig +from inspect_ai.model._model_call import ModelCall from inspect_ai.model._model_output import ModelOutput from inspect_ai.util._checkpoint import ( Manual, @@ -39,7 +42,9 @@ TurnInterval, checkpointer, ) +from inspect_ai.util._checkpoint._event_store import CheckpointEventStore from inspect_ai.util._checkpoint._triggers import CheckpointTriggerKind +from inspect_ai.util._checkpoint.checkpointer import ResumeCheckpoint from inspect_ai.util._checkpoint.checkpointer_impl import ( _CheckpointerSetup, _EnteredCheckpointer, @@ -88,26 +93,14 @@ class _Dirs: @contextmanager def _patch_sample_runtime(events: list[object]) -> Iterator[None]: - """Patch sample_state() and transcript() for tests that drive _fire. - - `_CheckpointerSetup._fire` reads ``sample_state().store`` and - ``transcript().events`` directly from ContextVars. Tests that - construct `_CheckpointerSetup` outside a real sample run need stand-ins. - - ``events`` is the externally-owned event-collection list — the fake - transcript's ``events`` attribute and ``_event`` callable both wire - to it so tests can inspect emit behavior. - """ + """Patch sample_state() and transcript() for tests that drive _fire.""" from types import SimpleNamespace from inspect_ai.util._store import Store fake_state = SimpleNamespace(store=Store()) - fake_transcript = SimpleNamespace( - events=events, - attachments={}, - _event=events.append, - ) + fake_transcript = Transcript(bounded=False) + fake_transcript.subscribe(events.append) with ( patch( "inspect_ai.util._checkpoint.checkpointer_impl.sample_state", @@ -133,6 +126,13 @@ def dirs(tmp_path: Path) -> Iterator[_Dirs]: yield d +@pytest.fixture(autouse=True) +def _isolate_transcript() -> Iterator[None]: + init_transcript(Transcript(bounded=False)) + yield + init_transcript(Transcript(bounded=False)) + + class _CountingCheckpointer(_EnteredCheckpointer): """Counts fires on top of the real fire path; stubs out restic.""" @@ -164,11 +164,8 @@ def _counting(config: ResolvedCheckpointConfig, dirs: _Dirs) -> _CountingCheckpo config=config, hydration=_fake_hydration(dirs.checkpoints, dirs.working), resume_checkpoint=None, + reset_event_store=True, ) - # Policy tests drive `tick()`/`checkpoint()` without going through - # `span_session()`, so the first-span lazy init for `_events_consumed` - # never fires. Seed it here so `_write_host_context` can slice. - cp._events_consumed = 0 return cp @@ -325,6 +322,44 @@ async def test_fire_emits_checkpoint_event(dirs: _Dirs) -> None: assert first.host.snapshot_id.startswith("fake-snap-") +async def test_fire_includes_events_emitted_before_checkpointer_construction( + dirs: _Dirs, +) -> None: + preexisting = InfoEvent(data="before-checkpointer") + active_transcript = transcript() + active_transcript._event(preexisting) + + with patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=active_transcript, + ): + cp = _counting(ResolvedCheckpointConfig(trigger=Manual()), dirs) + await cp.checkpoint() + + events = json.loads((Path(dirs.working) / "events.json").read_text()) + assert events[0]["uuid"] == preexisting.uuid + assert events[0]["data"] == "before-checkpointer" + + +async def test_fire_reopens_checkpoint_span_after_failure(dirs: _Dirs) -> None: + cp = _counting(ResolvedCheckpointConfig(trigger=Manual()), dirs) + assert cp._current_span_cm is None + await cp._open_next_span() + open_before = cp._current_span_cm + assert open_before is not None + + async def fail_write_host_context(*_args: object) -> None: + raise RuntimeError("write failed") + + with patch.object(cp, "_write_host_context", side_effect=fail_write_host_context): + with pytest.raises(RuntimeError, match="write failed"): + await cp.checkpoint() + + assert cp._current_span_cm is not None + assert cp._current_span_cm is not open_before + await cp._close_current_span() + + def test_synthesize_trailing_checkpoint_event(tmp_path: Path) -> None: """Hydrate reconstructs the trailing CheckpointEvent from a sidecar.""" from datetime import datetime, timezone @@ -516,24 +551,6 @@ async def test_fire_writes_sample_json_and_sidecars( # === _write_host_context: condensed events round-trip ======================= -def _wrap_in_checkpoint_span(checkpoint_id: int, events: list[Event]) -> list[Event]: - """Wrap a list of events in `span_begin/span_end` of type "checkpoint". - - `_write_host_context`'s filter (`_filter_persisted_events`) keeps only - events inside checkpoint / prior_run spans + CheckpointEvents. Tests - that drive `_write_host_context` directly with raw events need to - bracket them so they survive the filter. - """ - span_id = f"test-ckpt-{checkpoint_id}" - return [ - SpanBeginEvent( - id=span_id, name=f"checkpoint {checkpoint_id}", type="checkpoint" - ), - *events, - SpanEndEvent(id=span_id), - ] - - async def test_write_host_context_condenses_and_round_trips(tmp_path: Path) -> None: """Pooled ModelEvent inputs round-trip via expand_events; pool < total slots.""" msg_sys: ChatMessage = ChatMessageSystem(content="sys") @@ -554,16 +571,12 @@ def _model_event(input_msgs: list[ChatMessage]) -> ModelEvent: ) # Each ModelEvent carries the full prior history — 2 + 4 + 5 = 11 input - # slots across 5 unique messages. Wrapped in a checkpoint span so the - # write_host_context filter keeps them. - events: list[Event] = _wrap_in_checkpoint_span( - 1, - [ - _model_event([msg_sys, msg_u1]), - _model_event([msg_sys, msg_u1, msg_a1, msg_u2]), - _model_event(messages), - ], - ) + # slots across 5 unique messages. + events: list[Event] = [ + _model_event([msg_sys, msg_u1]), + _model_event([msg_sys, msg_u1, msg_a1, msg_u2]), + _model_event(messages), + ] work = tmp_path / "work" work.mkdir() @@ -571,9 +584,11 @@ def _model_event(input_msgs: list[ChatMessage]) -> ModelEvent: config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), hydration=_fake_hydration("/tmp/cp-test/ckpts", "/tmp/cp-test/work"), resume_checkpoint=None, + reset_event_store=True, ) - cp._events_consumed = 0 - await cp._write_host_context(str(work), events, {}, Store()) + for event in events: + cp._track_transcript_event(event) + await cp._write_host_context(str(work), [], {}, Store()) assert (work / "events.json").is_file() assert (work / "events_data.json").is_file() @@ -599,13 +614,10 @@ def _make_cp(**kwargs: object) -> _EnteredCheckpointer: "config": ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), "hydration": _fake_hydration(str(base / "ckpts"), str(base / "work")), "resume_checkpoint": None, + "reset_event_store": True, } defaults.update(kwargs) cp = _EnteredCheckpointer(**defaults) # type: ignore[arg-type] - # In real use, `_events_consumed` is set lazily by the first - # `_open_next_span()` call. Tests bypass that by driving - # `_write_host_context` directly, so seed the precondition here. - cp._events_consumed = 0 return cp @@ -623,21 +635,29 @@ def test_track_duplicate_key_raises(tmp_path: Path) -> None: cp.track("attempt_count", lambda: 2, 0) -async def test_track_single_key_writes_file(tmp_path: Path) -> None: - """Registered callback's return value lands in agent_state.json.""" +async def _write_agent_state( + tmp_path: Path, cp: _EnteredCheckpointer +) -> dict[str, object] | None: work = tmp_path / "work" work.mkdir() + await cp._write_host_context(str(work), [], {}, Store()) + agent_state = work / "agent_state.json" + if not agent_state.exists(): + return None + return json.loads(agent_state.read_text()) + + +async def test_track_single_key_writes_file(tmp_path: Path) -> None: + """Registered callback's return value lands in agent_state.json.""" cp = _make_cp() value = 3 cp.track("attempt_count", lambda: value, 0) - await cp._write_host_context(str(work), [], {}, Store()) - assert json.loads((work / "agent_state.json").read_text()) == {"attempt_count": 3} + + assert await _write_agent_state(tmp_path, cp) == {"attempt_count": 3} async def test_track_messages_via_track(tmp_path: Path) -> None: """Messages persist via `track('messages', ...)` — Pydantic model lists serialize.""" - work = tmp_path / "work" - work.mkdir() cp = _make_cp() messages: list[ChatMessage] = [ ChatMessageSystem(content="sys"), @@ -649,23 +669,23 @@ async def test_track_messages_via_track(tmp_path: Path) -> None: messages, value_type=list[ChatMessage], ) - await cp._write_host_context(str(work), [], {}, Store()) - state = json.loads((work / "agent_state.json").read_text()) + state = await _write_agent_state(tmp_path, cp) + assert state is not None assert "messages" in state - assert [m["role"] for m in state["messages"]] == ["system", "user"] - assert [m["content"] for m in state["messages"]] == ["sys", "hi"] + messages_state = state["messages"] + assert isinstance(messages_state, list) + assert [m["role"] for m in messages_state] == ["system", "user"] + assert [m["content"] for m in messages_state] == ["sys", "hi"] async def test_track_multiple_keys_merge(tmp_path: Path) -> None: """Multiple registered keys merge into one top-level dict.""" - work = tmp_path / "work" - work.mkdir() cp = _make_cp() cp.track("attempt_count", lambda: 3, 0) cp.track("phase", lambda: "explore", "") - await cp._write_host_context(str(work), [], {}, Store()) - assert json.loads((work / "agent_state.json").read_text()) == { + + assert await _write_agent_state(tmp_path, cp) == { "attempt_count": 3, "phase": "explore", } @@ -673,11 +693,9 @@ async def test_track_multiple_keys_merge(tmp_path: Path) -> None: async def test_track_not_registered_no_file(tmp_path: Path) -> None: """Without any callback, agent_state.json is not written.""" - work = tmp_path / "work" - work.mkdir() cp = _make_cp() - await cp._write_host_context(str(work), [], {}, Store()) - assert not (work / "agent_state.json").exists() + + assert await _write_agent_state(tmp_path, cp) is None def test_track_noop_session() -> None: @@ -692,8 +710,6 @@ def test_is_resuming_noop_false() -> None: def test_is_resuming_reflects_resume_checkpoint() -> None: - from inspect_ai.util._checkpoint.checkpointer import ResumeCheckpoint - assert _make_cp().is_resuming is False assert ( _make_cp( @@ -786,8 +802,9 @@ async def test_setup_aenter_defers_io_setup(tmp_path: Path) -> None: init_sandbox.assert_awaited_once() -async def test_write_host_context_persists_attachments(tmp_path: Path) -> None: - """transcript.attachments survives the checkpoint as attachments.json.""" +async def test_write_host_context_exports_event_store_attachments( + tmp_path: Path, +) -> None: attachments = {"abc123": "data:image/png;base64,iVBORw0", "def456": "long-text"} work = tmp_path / "work" @@ -796,9 +813,10 @@ async def test_write_host_context_persists_attachments(tmp_path: Path) -> None: config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), hydration=_fake_hydration("/tmp/cp-test/ckpts", "/tmp/cp-test/work"), resume_checkpoint=None, + reset_event_store=True, ) - cp._events_consumed = 0 - await cp._write_host_context(str(work), [], attachments, Store()) + cp._event_store.merge_attachments(attachments) + await cp._write_host_context(str(work), [], {}, Store()) assert json.loads((work / "attachments.json").read_text()) == attachments @@ -820,20 +838,12 @@ def _model_event(input_msgs: list[ChatMessage]) -> ModelEvent: output=ModelOutput(), ) - # Each fire's events wrapped in a checkpoint span so the - # write_host_context filter keeps them. - fire1_events: list[Event] = _wrap_in_checkpoint_span( - 1, - [ - _model_event([msg_sys, msg_u1]), - _model_event([msg_sys, msg_u1, msg_a1]), - ], - ) - # Fire 2 cumulatively contains fire 1's events + a new checkpoint span - # with one more model event. The two prior events stay condensed-as-is. + fire1_events: list[Event] = [ + _model_event([msg_sys, msg_u1]), + _model_event([msg_sys, msg_u1, msg_a1]), + ] fire2_events: list[Event] = [ - *fire1_events, - *_wrap_in_checkpoint_span(2, [_model_event([msg_sys, msg_u1, msg_a1, msg_u2])]), + _model_event([msg_sys, msg_u1, msg_a1, msg_u2]), ] work = tmp_path / "work" @@ -842,28 +852,26 @@ def _model_event(input_msgs: list[ChatMessage]) -> ModelEvent: config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), hydration=_fake_hydration("/tmp/cp-test/ckpts", "/tmp/cp-test/work"), resume_checkpoint=None, + reset_event_store=True, ) - cp._events_consumed = 0 - - await cp._write_host_context(str(work), fire1_events, {}, Store()) + for event in fire1_events: + cp._track_transcript_event(event) + await cp._write_host_context(str(work), [], {}, Store()) pool_after_1 = json.loads((work / "events_data.json").read_text())["messages"] events_after_1 = json.loads((work / "events.json").read_text()) assert len(pool_after_1) == 3 # sys, u1, a1 - # Fire 1's persisted events: [span_begin_1, model_1, model_2, span_end_1]. - assert len(events_after_1) == 4 - assert cp._events_consumed == len(fire1_events) + assert len(events_after_1) == 2 - await cp._write_host_context(str(work), fire2_events, {}, Store()) + for event in fire2_events: + cp._track_transcript_event(event) + await cp._write_host_context(str(work), [], {}, Store()) pool_after_2 = json.loads((work / "events_data.json").read_text())["messages"] events_after_2 = json.loads((work / "events.json").read_text()) # Append-only: pool grew by exactly one (u2); first 3 entries unchanged. assert pool_after_2[:3] == pool_after_1 assert len(pool_after_2) == 4 - # Events grew by 3 (span_begin_2 + model_3 + span_end_2 = checkpoint 2's - # wrap); the first 4 are byte-identical to fire 1's persisted output. - assert events_after_2[:4] == events_after_1 - assert len(events_after_2) == 7 - assert cp._events_consumed == len(fire2_events) + assert events_after_2[:2] == events_after_1 + assert len(events_after_2) == 3 # Full round-trip still works on the cumulative output. expanded = expand_events( @@ -872,3 +880,527 @@ def _model_event(input_msgs: list[ChatMessage]) -> ModelEvent: ) model_events = [e for e in expanded if isinstance(e, ModelEvent)] assert [len(e.input) for e in model_events] == [2, 3, 4] + + +def test_track_consumes_hydrated_agent_state() -> None: + hydration = _fake_hydration("/tmp/cp-test/ckpts", "/tmp/cp-test/work") + hydration.host.agent_state = {"phase": "resume", "other": "kept"} + cp = _make_cp(hydration=hydration) + + assert cp.track("phase", lambda: "fresh", "fresh") == "resume" + + assert "phase" not in cp._agent_state + assert cp._agent_state == {"other": "kept"} + + +def test_seed_event_store_uses_history_provider_for_truncated_transcript( + tmp_path: Path, +) -> None: + provider_events = [ + InfoEvent(uuid=f"provider-event-{index}", data=f"from-provider-{index}") + for index in range(3) + ] + provider = FakeTranscriptHistoryProvider(provider_events) + fake_transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=provider, + ) + for event in provider_events: + fake_transcript._event(event) + assert fake_transcript.resident_events_truncated is True + + with patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=fake_transcript, + ): + cp = _make_cp( + hydration=_fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "work")), + ) + + work = tmp_path / "snapshot" + work.mkdir() + cp._event_store.export_snapshot_files(work, store_json={}, agent_state=None) + cp.close() + + events = json.loads((work / "events.json").read_text()) + assert [event["data"] for event in events] == [ + "from-provider-0", + "from-provider-1", + "from-provider-2", + ] + + +def test_checkpointer_closes_store_when_transcript_seed_fails(tmp_path: Path) -> None: + fake_transcript = Transcript(bounded=True, resident_tail=0) + fake_transcript._event(InfoEvent(data="evicted")) + assert fake_transcript.resident_events_truncated is True + + with patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=fake_transcript, + ): + with pytest.raises(RuntimeError, match="Cannot seed checkpoint events"): + _make_cp( + hydration=_fake_hydration("/tmp/cp-test/ckpts", str(tmp_path / "work")), + ) + + +async def test_checkpointer_setup_resets_event_store_after_seed_failure( + tmp_path: Path, +) -> None: + work = tmp_path / "work" + checkpoints = tmp_path / "ckpts" + work.mkdir() + checkpoints.mkdir() + fake_transcript = Transcript(bounded=False) + fake_transcript._event(InfoEvent(uuid="seeded", data="seeded")) + hydration = _fake_hydration(str(checkpoints), str(work)) + calls = 0 + + def fail_once_merge_event(self: CheckpointEventStore, *args: object) -> None: + nonlocal calls + calls += 1 + if calls == 1: + original_merge_event(self, *args) # type: ignore[arg-type] + raise RuntimeError("seed failed") + original_merge_event(self, *args) # type: ignore[arg-type] + + original_merge_event = CheckpointEventStore.merge_event + setup = _CheckpointerSetup( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + log_location=str(tmp_path / "t.eval"), + sample_id="s", + epoch=0, + ) + + with ( + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.hydrate", + return_value=hydration, + ), + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=fake_transcript, + ), + patch.object(CheckpointEventStore, "merge_event", fail_once_merge_event), + ): + with pytest.raises(RuntimeError, match="seed failed"): + await setup.__aenter__() + + cp = await setup.__aenter__() + assert isinstance(cp, _EnteredCheckpointer) + snapshot = tmp_path / "snapshot" + snapshot.mkdir() + await cp._write_host_context(str(snapshot), [], {}, Store()) + setup.close() + + events = json.loads((snapshot / "events.json").read_text()) + assert [event["data"] for event in events] == ["seeded"] + + +async def test_fire_retains_attachment_from_evicted_event( + tmp_path: Path, +) -> None: + transcript = Transcript(bounded=True, resident_tail=1, log_model_api=True) + cp = _make_cp( + hydration=_fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "work")) + ) + payload = "evicted payload" * 100 + evicted = ModelEvent( + uuid="evicted", + model="mockllm/model", + input=[ChatMessageUser(content="question")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", "answer"), + ) + evicted.call = ModelCall.create( + {"messages": [{"role": "user", "content": payload}]}, None + ) + + with patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=transcript, + ): + transcript.subscribe(cp._track_transcript_event) + transcript._event(evicted) + transcript._event(InfoEvent(data="resident")) + + assert transcript.resident_events == [transcript.last_event] + assert transcript.attachments == {} + + work = tmp_path / "snapshot" + work.mkdir() + await cp._write_host_context(str(work), [], {}, Store()) + cp.close() + + attachments = json.loads((work / "attachments.json").read_text()) + assert payload in attachments.values() + + +async def test_checkpointer_setup_close_unsubscribes_and_closes_store( + tmp_path: Path, +) -> None: + callbacks: list[Callable[[Event], None]] = [] + + def subscribe(callback: Callable[[Event], None]) -> Callable[[], None]: + callbacks.append(callback) + + def unsubscribe() -> None: + callbacks.remove(callback) + + return unsubscribe + + from types import SimpleNamespace + + fake_transcript = SimpleNamespace( + attachments={}, + resident_events=[], + resident_events_truncated=False, + full_history_available=True, + subscribe=subscribe, + ) + + setup = _CheckpointerSetup( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + log_location=str(tmp_path / "t.eval"), + sample_id="s", + epoch=0, + ) + hydration = _fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "work")) + Path(hydration.sample_checkpoints_dir).mkdir(parents=True) + Path(hydration.sample_working_dir).mkdir(parents=True) + + with ( + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.hydrate", + return_value=hydration, + ), + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=fake_transcript, + ), + ): + async with setup as cp: + assert isinstance(cp, _EnteredCheckpointer) + assert len(callbacks) == 1 + assert len(callbacks) == 1 + + async with setup as cp2: + assert cp2 is cp + assert len(callbacks) == 1 + + setup.close() + assert callbacks == [] + assert setup._cached is None + setup.close() + + +async def test_checkpointer_setup_reconstruct_preserves_event_store( + tmp_path: Path, +) -> None: + from types import SimpleNamespace + + callbacks: list[Callable[[Event], None]] = [] + + def subscribe(callback: Callable[[Event], None]) -> Callable[[], None]: + callbacks.append(callback) + + def unsubscribe() -> None: + callbacks.remove(callback) + + return unsubscribe + + fake_transcript = SimpleNamespace( + attachments={}, + resident_events=[], + resident_events_truncated=False, + full_history_available=True, + subscribe=subscribe, + ) + hydration = _fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "work")) + Path(hydration.sample_checkpoints_dir).mkdir(parents=True) + Path(hydration.sample_working_dir).mkdir(parents=True) + setup = _CheckpointerSetup( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + log_location=str(tmp_path / "t.eval"), + sample_id="s", + epoch=0, + ) + + with ( + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.hydrate", + return_value=hydration, + ), + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=fake_transcript, + ), + ): + async with setup as first: + assert isinstance(first, _EnteredCheckpointer) + first._track_transcript_event(InfoEvent(data="first")) + + async with setup as second: + assert isinstance(second, _EnteredCheckpointer) + second._track_transcript_event(InfoEvent(data="second")) + work = tmp_path / "snapshot" + work.mkdir() + await second._write_host_context(str(work), [], {}, Store()) + + events = json.loads((work / "events.json").read_text()) + assert [event["data"] for event in events] == ["first", "second"] + + +async def test_resume_resets_restored_event_store_before_seeding( + tmp_path: Path, +) -> None: + from types import SimpleNamespace + + callbacks: list[Callable[[Event], None]] = [] + + def subscribe(callback: Callable[[Event], None]) -> Callable[[], None]: + callbacks.append(callback) + return lambda: callbacks.remove(callback) + + work = tmp_path / "work" + work.mkdir() + restored_store = CheckpointEventStore(work / "checkpoint_events.sqlite", reset=True) + restored_store.merge_event(InfoEvent(data="orphan"), lambda _: None) + restored_store.close() + + hydration = _fake_hydration(str(tmp_path / "ckpts"), str(work)) + hydration.host.condensed_events = [InfoEvent(data="committed")] + fake_transcript = SimpleNamespace( + attachments={}, + resident_events=[], + resident_events_truncated=False, + full_history_available=True, + subscribe=subscribe, + ) + setup = _CheckpointerSetup( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + log_location=str(tmp_path / "t.eval"), + sample_id="s", + epoch=0, + resume_checkpoint=ResumeCheckpoint( + sample_checkpoints_dir=str(tmp_path / "prior-ckpts") + ), + ) + + with ( + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.hydrate", + return_value=hydration, + ), + patch( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + return_value=fake_transcript, + ), + ): + cp = await setup.__aenter__() + assert isinstance(cp, _EnteredCheckpointer) + cp._track_transcript_event(InfoEvent(data="new")) + snapshot = tmp_path / "snapshot" + snapshot.mkdir() + await cp._write_host_context(str(snapshot), [], {}, Store()) + setup.close() + + events = json.loads((snapshot / "events.json").read_text()) + assert [event["data"] for event in events] == ["committed", "new"] + + +def test_resume_seed_skips_restored_resident_events( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + restored = InfoEvent(uuid="restored", data="committed") + fake_transcript = Transcript([restored], bounded=False) + hydration = _fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "work")) + hydration.host.condensed_events = [restored] + + monkeypatch.setattr( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + lambda: fake_transcript, + ) + + cp = _EnteredCheckpointer( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + hydration=hydration, + resume_checkpoint=ResumeCheckpoint( + sample_checkpoints_dir=str(tmp_path / "old") + ), + reset_event_store=True, + ) + snapshot = tmp_path / "snapshot" + snapshot.mkdir() + cp._event_store.export_snapshot_files(snapshot, store_json={}, agent_state=None) + cp.close() + + events = json.loads((snapshot / "events.json").read_text()) + assert [event["uuid"] for event in events] == ["restored"] + assert [event["data"] for event in events] == ["committed"] + + +async def test_resume_seed_preserves_message_pool_positions(tmp_path: Path) -> None: + msg_sys: ChatMessage = ChatMessageSystem(content="sys") + msg_u1: ChatMessage = ChatMessageUser(content="q1") + msg_a1: ChatMessage = ChatMessageAssistant(content="a1") + msg_u2: ChatMessage = ChatMessageUser(content="q2") + + def _model_event(input_msgs: list[ChatMessage]) -> ModelEvent: + return ModelEvent( + model="test", + input=input_msgs, + tools=[], + tool_choice="auto", + config=GenerateConfig(), + output=ModelOutput(), + ) + + first_work = tmp_path / "first" + first_work.mkdir() + first_cp = _EnteredCheckpointer( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + hydration=_fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "state1")), + resume_checkpoint=None, + reset_event_store=True, + ) + first_cp._track_transcript_event(_model_event([msg_sys, msg_u1, msg_a1])) + await first_cp._write_host_context(str(first_work), [], {}, Store()) + + from inspect_ai.util._checkpoint._layout import host_context + + first_context = host_context.read(str(first_work)) + assert len(first_context.condensed_events) == 1 + + hydration = _fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "state2")) + hydration.host.condensed_events = first_context.condensed_events + hydration.host.msg_pool = first_context.msg_pool + hydration.host.call_pool = first_context.call_pool + resumed_cp = _EnteredCheckpointer( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + hydration=hydration, + resume_checkpoint=None, + reset_event_store=True, + ) + + resumed_work = tmp_path / "resumed" + resumed_work.mkdir() + resumed_cp._track_transcript_event(_model_event([msg_u2])) + await resumed_cp._write_host_context(str(resumed_work), [], {}, Store()) + + resumed_data = json.loads((resumed_work / "events_data.json").read_text()) + assert len(resumed_data["messages"]) == 4 + expanded = expand_events( + (resumed_work / "events.json").read_text(), + (resumed_work / "events_data.json").read_text(), + ) + model_events = [event for event in expanded if isinstance(event, ModelEvent)] + assert [len(event.input) for event in model_events] == [3, 1] + assert [message.content for message in model_events[0].input] == ["sys", "q1", "a1"] + assert [message.content for message in model_events[1].input] == ["q2"] + + +async def test_resume_materializes_pooled_model_event_and_seeds_transcript( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + msg_sys: ChatMessage = ChatMessageSystem(content="sys") + msg_user: ChatMessage = ChatMessageUser(content="question") + call_request = {"messages": [{"role": "user", "content": "question"}]} + restored = ModelEvent( + uuid="restored-model", + model="test", + input=[msg_sys, msg_user], + tools=[], + tool_choice="auto", + config=GenerateConfig(), + output=ModelOutput(), + call=ModelCall.create(call_request, None), + ) + + first_work = tmp_path / "first" + first_work.mkdir() + first_cp = _EnteredCheckpointer( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + hydration=_fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "state1")), + resume_checkpoint=None, + reset_event_store=True, + ) + first_cp._track_transcript_event(restored) + await first_cp._write_host_context(str(first_work), [], {}, Store()) + first_cp.close() + + from inspect_ai.util._checkpoint._layout import host_context + + first_context = host_context.read(str(first_work)) + assert len(first_context.condensed_events) == 1 + condensed = first_context.condensed_events[0] + assert isinstance(condensed, ModelEvent) + assert condensed.input_refs is not None + assert condensed.call is not None + assert condensed.call.call_refs is not None + + fake_transcript = Transcript(bounded=True, resident_tail=10, log_model_api=True) + seeded_events: list[Event] = [] + fake_transcript.subscribe(seeded_events.append) + monkeypatch.setattr( + "inspect_ai.util._checkpoint.hydrate.transcript", + lambda: fake_transcript, + ) + monkeypatch.setattr( + "inspect_ai.util._checkpoint.checkpointer_impl.transcript", + lambda: fake_transcript, + ) + monkeypatch.setattr( + "inspect_ai.util._checkpoint.hydrate.sample_state", + lambda: type("FakeState", (), {"store": Store()})(), + ) + + hydration = _fake_hydration(str(tmp_path / "ckpts"), str(tmp_path / "state2")) + hydration.host.condensed_events = first_context.condensed_events + hydration.host.msg_pool = first_context.msg_pool + hydration.host.call_pool = first_context.call_pool + + from inspect_ai.util._checkpoint.hydrate import _push_host_state + + _push_host_state(hydration.host, str(tmp_path / "ckpts"), None) + assert seeded_events == fake_transcript.resident_events + seeded_model = seeded_events[0] + assert isinstance(seeded_model, ModelEvent) + assert seeded_model.input_refs is None + assert [message.content for message in seeded_model.input] == ["sys", "question"] + assert seeded_model.call is not None + assert seeded_model.call.call_refs is None + assert seeded_model.call.request == call_request + + resumed_cp = _EnteredCheckpointer( + config=ResolvedCheckpointConfig(trigger=TurnInterval(every=1)), + hydration=hydration, + resume_checkpoint=ResumeCheckpoint( + sample_checkpoints_dir=str(tmp_path / "prior") + ), + reset_event_store=True, + ) + + resumed_work = tmp_path / "resumed" + resumed_work.mkdir() + resumed_cp._track_transcript_event(InfoEvent(uuid="after-resume", data="new")) + await resumed_cp._write_host_context(str(resumed_work), [], {}, Store()) + resumed_cp.close() + + expanded = expand_events( + (resumed_work / "events.json").read_text(), + (resumed_work / "events_data.json").read_text(), + ) + assert [event.uuid for event in expanded] == ["restored-model", "after-resume"] + expanded_model = expanded[0] + assert isinstance(expanded_model, ModelEvent) + assert [message.content for message in expanded_model.input] == [ + "sys", + "question", + ] + assert expanded_model.call is not None + assert expanded_model.call.request == call_request diff --git a/tests/display/test_textual_transcript.py b/tests/display/test_textual_transcript.py new file mode 100644 index 0000000000..f930cbfe52 --- /dev/null +++ b/tests/display/test_textual_transcript.py @@ -0,0 +1,99 @@ +from collections.abc import AsyncIterator, Iterator, Mapping, Sequence +from contextlib import asynccontextmanager +from typing import cast + +import pytest + +from inspect_ai._display.textual.widgets.transcript import TranscriptView +from inspect_ai.dataset import Sample +from inspect_ai.event import Event, InfoEvent +from inspect_ai.event._sample_init import SampleInitEvent +from inspect_ai.log import Transcript +from inspect_ai.log._samples import ActiveSample +from inspect_ai.util._checkpoint._event_store import CheckpointEventStore + + +class _RaisingHistoryProvider: + def __init__(self) -> None: + self.event_count_calls = 0 + + @property + def event_count(self) -> int: + self.event_count_calls += 1 + return 3 + + def iter_events(self) -> Iterator[Event]: + raise AssertionError("textual transcript view should use resident events") + + def events(self) -> Sequence[Event]: + raise AssertionError("textual transcript view should use resident events") + + def recent_events(self, n: int | None = None) -> Sequence[Event]: + raise AssertionError("textual transcript view should use resident events") + + def events_from(self, start: int) -> Sequence[Event]: + raise AssertionError("textual transcript view should use resident events") + + def events_since_last(self, event_type: type[Event]) -> list[Event]: + raise AssertionError("textual transcript view should use resident events") + + def attachments(self) -> Mapping[str, str]: + return {} + + def attachment(self, hash: str) -> str | None: + return None + + def import_checkpoint_events(self, event_store: CheckpointEventStore) -> int: + raise AssertionError("textual transcript view should use resident events") + + +class _Sample: + id = "sample" + + def __init__(self, transcript: Transcript) -> None: + self.transcript = transcript + + +@pytest.mark.anyio +async def test_textual_transcript_view_uses_resident_events(monkeypatch) -> None: + provider = _RaisingHistoryProvider() + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=provider, + ) + sample_init = SampleInitEvent(sample=Sample(input="input", id="sample"), state={}) + transcript._event(sample_init) + transcript._event(InfoEvent(data="evicted")) + transcript._event(InfoEvent(data="resident")) + + async def remove_children(self: TranscriptView) -> None: + pass + + async def mount_all(self: TranscriptView, widgets: object) -> None: + pass + + @asynccontextmanager + async def batch(self: TranscriptView) -> AsyncIterator[None]: + yield + + def scroll_end(self: TranscriptView, animate: bool = False) -> None: + pass + + def widgets_for_events( + self: TranscriptView, events: Sequence[Event] + ) -> list[object]: + return [] + + monkeypatch.setattr(TranscriptView, "remove_children", remove_children) + monkeypatch.setattr(TranscriptView, "mount_all", mount_all) + monkeypatch.setattr(TranscriptView, "_widgets_for_events", widgets_for_events) + monkeypatch.setattr(TranscriptView, "batch", batch) + monkeypatch.setattr(TranscriptView, "scroll_end", scroll_end) + + view = TranscriptView() + view._active = True + + await view.sync_sample(cast(ActiveSample, _Sample(transcript))) + + assert provider.event_count_calls == 0 diff --git a/tests/log/test_transcript_bounded.py b/tests/log/test_transcript_bounded.py new file mode 100644 index 0000000000..14c3a30be5 --- /dev/null +++ b/tests/log/test_transcript_bounded.py @@ -0,0 +1,684 @@ +import contextvars +import logging +from typing import Sequence + +import pytest +from test_helpers.transcript import FakeTranscriptHistoryProvider + +from inspect_ai.dataset._dataset import Sample +from inspect_ai.event._event import Event +from inspect_ai.event._info import InfoEvent +from inspect_ai.event._model import ModelEvent +from inspect_ai.event._sample_init import SampleInitEvent +from inspect_ai.log._transcript import ( + Transcript, + transcript, + transcript_bounded_enabled, +) +from inspect_ai.model import GenerateConfig +from inspect_ai.model._chat_message import ChatMessageUser +from inspect_ai.model._model_call import ModelCall +from inspect_ai.model._model_output import ModelOutput + + +class _RaisingEventCountProvider(FakeTranscriptHistoryProvider): + @property + def event_count(self) -> int: + raise AssertionError("Transcript.event_count should be in-memory") + + +class _SliceOnlyProvider(FakeTranscriptHistoryProvider): + @property + def event_count(self) -> int: + raise AssertionError("Transcript slice should not read event_count") + + def recent_events(self, n: int | None = None) -> Sequence[Event]: + raise AssertionError("Transcript positive slice should not read recent_events") + + +class _NoIterProvider(FakeTranscriptHistoryProvider): + def iter_events(self): + raise AssertionError( + "Transcript membership should not iterate provider history" + ) + + def recent_events(self, n: int | None = None) -> Sequence[Event]: + raise AssertionError("Transcript resident tail should not read provider") + + +class _CountingIterProvider(FakeTranscriptHistoryProvider): + iterated: int = 0 + + def iter_events(self): + for event in self._events: + self.iterated += 1 + yield event + + +def _data(events): + return [event.data for event in events] + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + ("true", True), + ("1", True), + ("false", False), + (" FALSE ", False), + ("Off", False), + ], +) +def test_transcript_bounded_env_escape_hatch( + monkeypatch: pytest.MonkeyPatch, value: str | None, expected: bool +) -> None: + monkeypatch.delenv("INSPECT_TRANSCRIPT_BOUNDED", raising=False) + if value is not None: + monkeypatch.setenv("INSPECT_TRANSCRIPT_BOUNDED", value) + assert transcript_bounded_enabled() is expected + + +def test_transcript_context_default_is_lazy_and_isolated() -> None: + first_context = contextvars.Context() + second_context = contextvars.Context() + + first = first_context.run(transcript) + second = second_context.run(transcript) + + assert first is first_context.run(transcript) + assert second is second_context.run(transcript) + assert first is not second + + +def test_bounded_transcript_assigns_keys_to_uuidless_events() -> None: + first = InfoEvent.model_validate( + {"event": "info", "data": "first"}, context={"deserializing": True} + ) + second = InfoEvent.model_validate( + {"event": "info", "data": "second"}, context={"deserializing": True} + ) + transcript = Transcript(bounded=True, resident_tail=1) + + transcript._event(first) + transcript._event(second) + + assert first.uuid is not None + assert second.uuid is not None + assert first.uuid != second.uuid + assert _data(transcript.events) == ["second"] + + +def test_bounded_transcript_evictable_queue_stays_bounded() -> None: + transcript = Transcript(bounded=True, resident_tail=3) + + for index in range(20): + transcript._event(InfoEvent(data=index)) + + assert _data(transcript.resident_events) == [17, 18, 19] + assert len(transcript._evictable_event_ids) == 3 + + +def test_bounded_transcript_does_not_prune_when_no_event_is_evicted( + monkeypatch: pytest.MonkeyPatch, +) -> None: + transcript = Transcript(bounded=True, resident_tail=3) + prune_calls = 0 + original_prune = transcript._prune_pin_state + + def count_prune() -> None: + nonlocal prune_calls + prune_calls += 1 + original_prune() + + monkeypatch.setattr(transcript, "_prune_pin_state", count_prune) + + transcript._event(InfoEvent(data=1)) + transcript._event(InfoEvent(data=2)) + transcript._event(InfoEvent(data=3)) + + assert prune_calls == 0 + + +def test_bounded_transcript_evicts_to_resident_tail(): + transcript = Transcript(bounded=True, resident_tail=3) + + for data in range(5): + transcript._event(InfoEvent(data=data)) + + assert transcript.event_count == 5 + assert transcript.resident_events_truncated is True + assert _data(transcript.events) == [2, 3, 4] + assert _data(transcript.recent_events(2)) == [3, 4] + assert transcript.recent_events(0) == [] + assert transcript.last_event is not None + assert transcript.last_event.data == 4 + + +def test_bounded_transcript_events_uses_provider_for_full_history() -> None: + full_history: list[Event] = [ + InfoEvent(data=0), + InfoEvent(data=1), + InfoEvent(data=2), + ] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=FakeTranscriptHistoryProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert transcript.resident_events_truncated is True + assert transcript.resident_events_truncated is True + assert transcript.full_history_available is True + assert _data(transcript.resident_events) == [2] + assert _data(transcript.events) == [0, 1, 2] + assert len(transcript.events) == 3 + assert transcript.events[-1] is transcript.last_event + assert _data(transcript.events[1:]) == [1, 2] + + +def test_bounded_transcript_event_count_is_in_memory_with_provider() -> None: + full_history: list[Event] = [InfoEvent(data=0), InfoEvent(data=1)] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=_RaisingEventCountProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert transcript.resident_events_truncated is True + assert transcript.event_count == 2 + + +def test_provider_backed_events_len_uses_in_memory_event_count() -> None: + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=_RaisingEventCountProvider([InfoEvent(data=1)]), + ) + transcript._event(InfoEvent(data=1)) + transcript._event(InfoEvent(data=2)) + + assert len(transcript.events) == 2 + + +def test_full_history_available_distinguishes_provider_from_resident_truncation() -> ( + None +): + provider_backed = Transcript( + bounded=True, + resident_tail=1, + history_provider=FakeTranscriptHistoryProvider( + [InfoEvent(data=0), InfoEvent(data=1)] + ), + ) + provider_backed._event(InfoEvent(data=0)) + provider_backed._event(InfoEvent(data=1)) + + no_provider = Transcript(bounded=True, resident_tail=1) + no_provider._event(InfoEvent(data=0)) + no_provider._event(InfoEvent(data=1)) + + assert provider_backed.resident_events_truncated is True + assert provider_backed.full_history_available is True + assert no_provider.resident_events_truncated is True + assert no_provider.full_history_available is False + + +def test_provider_backed_events_supports_score_suffix_slice() -> None: + full_history: list[Event] = [ + InfoEvent(data=0), + InfoEvent(data=1), + InfoEvent(data=2), + ] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=FakeTranscriptHistoryProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + existing_sample_events = [full_history[0]] + suffix = transcript.events[len(existing_sample_events) :] + + assert _data(suffix) == [1, 2] + + +def test_provider_backed_events_suffix_slice_uses_single_provider_operation() -> None: + full_history: list[Event] = [ + InfoEvent(data=0), + InfoEvent(data=1), + InfoEvent(data=2), + ] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=_SliceOnlyProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert _data(transcript.events[1:]) == [1, 2] + + +def test_provider_backed_events_membership_checks_resident_events_only() -> None: + evicted = InfoEvent(data="evicted") + resident = InfoEvent(data="resident") + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=_NoIterProvider([evicted, resident]), + ) + + transcript._event(evicted) + transcript._event(resident) + + assert resident in transcript.events + assert evicted not in transcript.events + + +def test_provider_backed_positive_index_streams_until_match() -> None: + full_history: list[Event] = [ + InfoEvent(data=0), + InfoEvent(data=1), + InfoEvent(data=2), + ] + provider = _CountingIterProvider(full_history) + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=provider, + ) + + for event in full_history: + transcript._event(event) + + event = transcript.events[1] + + assert isinstance(event, InfoEvent) + assert event.data == 1 + assert provider.iterated == 2 + + +def test_bounded_transcript_recent_events_uses_provider_when_resident_tail_insufficient() -> ( + None +): + full_history: list[Event] = [ + InfoEvent(data=0), + InfoEvent(data=1), + InfoEvent(data=2), + ] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=FakeTranscriptHistoryProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert _data(transcript.recent_events(2)) == [1, 2] + assert transcript.recent_events(0) == [] + assert _data(transcript.recent_events()) == [0, 1, 2] + + +def test_provider_backed_resident_tail_shortcuts_avoid_provider() -> None: + full_history: list[Event] = [ + InfoEvent(data=0), + InfoEvent(data=1), + InfoEvent(data=2), + ] + transcript = Transcript( + bounded=True, + resident_tail=2, + history_provider=_NoIterProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert transcript.resident_events_truncated is True + assert _data(transcript.recent_events(1)) == [2] + assert transcript.events[-1] is full_history[-1] + assert transcript.events[-2:] == full_history[-2:] + + +def test_bounded_transcript_recent_events_uses_provider_with_pinned_gap() -> None: + sample_init = SampleInitEvent( + sample=Sample(input="input", id="sample"), + state={}, + ) + full_history: list[Event] = [sample_init, InfoEvent(data=1), InfoEvent(data=2)] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=FakeTranscriptHistoryProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert transcript.resident_events == [sample_init, full_history[-1]] + assert _data(transcript.recent_events(2)) == [1, 2] + + +def test_bounded_transcript_events_negative_index_uses_provider_with_pinned_tail() -> ( + None +): + sample_init = SampleInitEvent( + sample=Sample(input="input", id="sample"), + state={}, + ) + tail = InfoEvent(data="tail") + full_history: list[Event] = [sample_init, tail] + transcript = Transcript( + bounded=True, + resident_tail=0, + history_provider=FakeTranscriptHistoryProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert transcript.resident_events == [sample_init] + assert transcript.last_event is sample_init + assert transcript.events[-1] is tail + assert list(transcript.events)[-1] is tail + + +def test_bounded_transcript_events_since_last_uses_provider_after_eviction() -> None: + first_model = ModelEvent( + model="mockllm/model", + input=[ChatMessageUser(content="first")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", "first"), + ) + middle = InfoEvent(data="middle") + second_model = ModelEvent( + model="mockllm/model", + input=[ChatMessageUser(content="second")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", "second"), + ) + tail = InfoEvent(data="tail") + full_history: list[Event] = [first_model, middle, second_model, tail] + transcript = Transcript( + bounded=True, + resident_tail=1, + history_provider=FakeTranscriptHistoryProvider(full_history), + ) + + for event in full_history: + transcript._event(event) + + assert transcript.events_since_last(ModelEvent) == [second_model, tail] + + +def test_events_since_last_raises_when_transcript_truncated() -> None: + transcript = Transcript(bounded=True, resident_tail=1) + transcript._event(InfoEvent(data="first")) + transcript._event(InfoEvent(data="second")) + + with pytest.raises(RuntimeError, match="Full transcript history is not available"): + transcript.events_since_last(ModelEvent) + + +def test_seeded_transcript_defaults_to_unbounded(): + transcript = Transcript([InfoEvent(data=1)], resident_tail=0) + + transcript._event(InfoEvent(data=2)) + + assert transcript.event_count == 2 + assert transcript.resident_events_truncated is False + assert _data(transcript.events) == [1, 2] + + +def test_sample_init_event_is_pinned_in_bounded_transcript(): + transcript = Transcript(bounded=True, resident_tail=1) + sample_init = SampleInitEvent( + sample=Sample(input="input", id="sample"), + state={}, + ) + + transcript._event(sample_init) + transcript._event(InfoEvent(data=1)) + transcript._event(InfoEvent(data=2)) + + assert transcript.event_count == 3 + assert transcript.resident_events_truncated is True + assert transcript.events == [sample_init, transcript.last_event] + + +def test_pending_event_is_pinned_in_bounded_transcript(): + transcript = Transcript(bounded=True, resident_tail=1) + pending = InfoEvent(data="pending", pending=True) + + transcript._event(pending) + transcript._event(InfoEvent(data=1)) + transcript._event(InfoEvent(data=2)) + + assert transcript.event_count == 3 + assert transcript.resident_events_truncated is True + assert transcript.events == [pending, transcript.last_event] + + +def test_completed_pending_event_is_evictable_on_update(): + transcript = Transcript(bounded=True, resident_tail=1) + pending = InfoEvent(data="pending", pending=True) + + transcript._event(pending) + transcript._event(InfoEvent(data=1)) + pending.pending = False + transcript._event_updated(pending) + + assert transcript.event_count == 2 + assert transcript.resident_events_truncated is True + assert _data(transcript.events) == [1] + + +def test_transcript_subscribe_receives_events_and_updates() -> None: + transcript = Transcript() + received: list[Event] = [] + unsubscribe = transcript.subscribe(received.append) + event = InfoEvent(data="first") + + transcript._event(event) + event.data = "updated" + transcript._event_updated(event) + unsubscribe() + transcript._event(InfoEvent(data="after")) + + assert received == [event, event] + + +def test_transcript_subscriber_exception_does_not_skip_processing( + caplog: pytest.LogCaptureFixture, +) -> None: + transcript = Transcript(bounded=True, resident_tail=0, log_model_api=True) + received: list[Event] = [] + + def bad_subscriber(event: Event) -> None: + raise RuntimeError("subscriber failed") + + transcript.subscribe(bad_subscriber) + transcript.subscribe(received.append) + event = _model_event_with_call_payload("event-1", "large payload" * 100) + + transcript_logger = logging.getLogger("inspect_ai.log._transcript") + transcript_logger.addHandler(caplog.handler) + original_propagate = transcript_logger.propagate + transcript_logger.propagate = False + try: + with caplog.at_level("WARNING", logger="inspect_ai.log._transcript"): + transcript._event(event) + finally: + transcript_logger.propagate = original_propagate + transcript_logger.removeHandler(caplog.handler) + + assert received == [event] + assert "Transcript subscriber failed" in caplog.text + assert event.call is not None + messages = event.call.request["messages"] + assert isinstance(messages, list) + message = messages[0] + assert isinstance(message, dict) + content = message["content"] + assert isinstance(content, str) + assert content.startswith("attachment://") + assert transcript.events == [] + assert transcript.resident_events_truncated is True + assert transcript.attachments == {} + + +def test_bounded_transcript_evicts_unreferenced_attachments() -> None: + transcript = Transcript(bounded=True, resident_tail=1, log_model_api=True) + first = _model_event_with_call_payload("event-1", "first large payload" * 100) + second = _model_event_with_call_payload("event-2", "second large payload" * 100) + + transcript._event(first) + first_attachments = set(transcript.attachments) + assert first_attachments + + transcript._event(second) + + assert not first_attachments.intersection(transcript.attachments) + assert transcript.attachments + + +def test_bounded_transcript_update_rebuilds_attachment_refs() -> None: + transcript = Transcript(bounded=True, resident_tail=1, log_model_api=True) + event = _model_event_with_call_payload("event-1", "first payload" * 100) + + transcript._event(event) + first_attachments = set(transcript.attachments) + event.call = ModelCall.create( + {"messages": [{"role": "user", "content": "second payload" * 100}]}, None + ) + transcript._event_updated(event) + + assert not first_attachments.intersection(transcript.attachments) + assert transcript.attachments + + +def test_bounded_transcript_accepts_non_json_metadata() -> None: + transcript = Transcript(bounded=True) + + transcript._event(InfoEvent(data="ok", metadata={"x": object()})) + + assert transcript.last_event is not None + assert isinstance(transcript.last_event, InfoEvent) + assert transcript.last_event.data == "ok" + + +def test_bounded_transcript_update_of_evicted_event_does_not_retain_attachments() -> ( + None +): + transcript = Transcript(bounded=True, resident_tail=1, log_model_api=True) + evicted = _model_event_with_call_payload("event-1", "evicted payload" * 100) + resident = InfoEvent(data="resident") + + transcript._event(evicted) + transcript._event(resident) + assert transcript.events == [resident] + + for index in range(3): + evicted.call = ModelCall.create( + {"messages": [{"role": "user", "content": f"late payload {index}" * 100}]}, + None, + ) + transcript._event_updated(evicted) + + assert transcript.events == [resident] + assert evicted.call is not None + messages = evicted.call.request["messages"] + assert isinstance(messages, list) + message = messages[0] + assert isinstance(message, dict) + content = message["content"] + assert isinstance(content, str) + assert not content.startswith("attachment://") + assert transcript.attachments == {} + + +def _model_event_with_call_payload(uuid: str, payload: str) -> ModelEvent: + event = ModelEvent( + uuid=uuid, + model="mockllm/model", + input=[ChatMessageUser(content="question")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", "answer"), + ) + event.call = ModelCall.create( + {"messages": [{"role": "user", "content": payload}]}, None + ) + return event + + +def test_bounded_transcript_external_mutation_keeps_original_attachment_ref() -> None: + transcript = Transcript(bounded=True, resident_tail=1, log_model_api=True) + first = _model_event_with_call_payload("event-1", "first payload" * 100) + second = InfoEvent(data="second") + + transcript._event(first) + first_attachments = dict(transcript.attachments) + assert first_attachments + + first.call = ModelCall.create( + {"messages": [{"role": "user", "content": "mutated payload" * 100}]}, None + ) + transcript._event(second) + + assert transcript.events == [second] + assert not any(hash in transcript.attachments for hash in first_attachments) + + +def test_restored_events_update_bounded_bookkeeping_and_evict() -> None: + transcript = Transcript(bounded=True, resident_tail=2) + events = [InfoEvent(data=i) for i in range(5)] + + transcript._extend_restored_events(events, {}) + + assert transcript.event_count == 5 + assert transcript.resident_events_truncated is True + assert _data(transcript.resident_events) == [3, 4] + assert _data(transcript.recent_events(2)) == [3, 4] + + +def test_restored_events_reject_duplicate_uuid() -> None: + transcript = Transcript(bounded=True, resident_tail=10) + first = InfoEvent(data="first", uuid="same") + duplicate = InfoEvent(data="duplicate", uuid="same") + + transcript._extend_restored_events([first], {}) + + with pytest.raises(ValueError, match="Duplicate event uuid"): + transcript._extend_restored_events([duplicate], {}) + + +def test_bounded_transcript_allows_duplicate_uuid_after_eviction() -> None: + transcript = Transcript(bounded=True, resident_tail=1) + first = InfoEvent(data="first", uuid="same") + second = InfoEvent(data="second", uuid="other") + duplicate = InfoEvent(data="duplicate", uuid="same") + + transcript._event(first) + transcript._event(second) + assert _data(transcript.resident_events) == ["second"] + + transcript._event(duplicate) + + assert _data(transcript.resident_events) == ["duplicate"] diff --git a/tests/test_helpers/transcript.py b/tests/test_helpers/transcript.py new file mode 100644 index 0000000000..cf3f3d01c9 --- /dev/null +++ b/tests/test_helpers/transcript.py @@ -0,0 +1,52 @@ +from collections.abc import Iterator, Sequence + +from inspect_ai._util.list import find_last_match +from inspect_ai.event._event import Event +from inspect_ai.util._checkpoint._event_store import CheckpointEventStore + + +class FakeTranscriptHistoryProvider: + def __init__( + self, events: Sequence[Event], attachments: dict[str, str] | None = None + ) -> None: + self._events = list(events) + self._attachments = dict(attachments or {}) + + @property + def event_count(self) -> int: + return len(self._events) + + def events(self) -> Sequence[Event]: + return list(self._events) + + def iter_events(self) -> Iterator[Event]: + return iter(self._events) + + def recent_events(self, n: int | None = None) -> Sequence[Event]: + if n is None: + return list(self._events) + if n <= 0: + return [] + return list(self._events[-n:]) + + def events_from(self, start: int) -> Sequence[Event]: + return list(self._events[start:]) + + def events_since_last(self, event_type: type[Event]) -> list[Event]: + events = list(self._events) + index = find_last_match(events, lambda event: isinstance(event, event_type)) + if index is not None: + return events[index:] + return events + + def attachments(self) -> dict[str, str]: + return dict(self._attachments) + + def attachment(self, hash: str) -> str | None: + return self._attachments.get(hash) + + def import_checkpoint_events(self, event_store: CheckpointEventStore) -> int: + for event in self._events: + event_store.merge_event(event, self._attachments.get) + event_store.merge_attachments(self._attachments) + return len(self._events)