Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/inspect_ai/_display/textual/widgets/transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def sync_sample(self, sample: ActiveSample | None) -> None:
# if we have either a new sample or a new event count then proceed
if (
sample.id != self._sample_id
or len(sample.transcript.events) != self._sample_events
or sample.transcript.event_count != self._sample_events
):
# update (scrolling to end if we are already close to it)
new_sample = sample.id != self._sample_id
Expand All @@ -92,14 +92,14 @@ async def sync_sample(self, sample: ActiveSample | None) -> None:
async with self.batch():
await self.remove_children()
await self.mount_all(
self._widgets_for_events(sample.transcript.events)
self._widgets_for_events(sample.transcript.resident_events)
)
if scroll_to_end:
self.scroll_end(animate=not new_sample)

# set members
self._sample_id = sample.id
self._sample_events = len(sample.transcript.events)
self._sample_events = sample.transcript.event_count

# if we aren't active then save as a pending sample
else:
Expand Down
2 changes: 1 addition & 1 deletion src/inspect_ai/_eval/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ async def _run_score_task(
init_subtask_store(state.store)

# load a copy of the current sample events into the transcript
init_transcript(Transcript([*sample.events], log_model_api=False))
init_transcript(Transcript([*sample.events], log_model_api=False, bounded=False))

if state.scores is None:
state.scores = {}
Expand Down
63 changes: 53 additions & 10 deletions src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,22 @@
EvalSampleSummary,
eval_error,
)
from inspect_ai.log._recorders.streaming import materialize_streaming_sample
from inspect_ai.log._recorders.buffer.history_provider import (
BufferTranscriptHistoryProvider,
)
from inspect_ai.log._recorders.streaming import (
eval_retry_error_from_history,
materialize_streaming_sample,
)
from inspect_ai.log._samples import (
active_sample,
)
from inspect_ai.log._transcript import (
Transcript,
TranscriptHistoryProvider,
init_transcript,
transcript,
transcript_bounded_enabled,
)
from inspect_ai.model import (
GenerateConfig,
Expand Down Expand Up @@ -174,6 +182,18 @@
SAMPLE_TOTAL_PROGRESS_UNITS = 1


def _sample_transcript_config(
logger: TaskLogger | None, sample_id: str | int, epoch: int
) -> tuple[bool, TranscriptHistoryProvider | None]:
if logger is not None and logger.buffer_db is not None:
return (
transcript_bounded_enabled(),
BufferTranscriptHistoryProvider(logger.buffer_db, sample_id, epoch),
)
else:
return False, None


@dataclass
class TaskRunOptions:
task: Task
Expand Down Expand Up @@ -881,7 +901,15 @@ def on_sample_event(event: Event) -> None:
init_sample_model_usage()
init_sample_role_usage()
set_sample_state(state)
sample_transcript = Transcript(log_model_api=log_model_api)
sample_transcript_bounded, history_provider = _sample_transcript_config(
logger, sample_id, state.epoch
)
sample_transcript = Transcript(
log_model_api=log_model_api,
bounded=sample_transcript_bounded,
resident_tail=100,
history_provider=history_provider,
)
init_transcript(sample_transcript)
init_subtask_store(state.store)
sample_transcript.subscribe(on_sample_event)
Expand Down Expand Up @@ -1430,6 +1458,8 @@ def make_eval_sample(include_events: bool = True) -> EvalSample:
):
await emit_attempt_end(will_retry=True)

retry_error = _eval_retry_error(error, logger, state.sample_id, state.epoch)

# remove any buffered sample events
if logger is not None:
logger.remove_sample(state.sample_id, state.epoch)
Expand Down Expand Up @@ -1463,7 +1493,7 @@ def make_eval_sample(include_events: bool = True) -> EvalSample:
retry_on_error=retry_on_error - 1,
score_on_error=score_on_error,
# forward on error that caused retry
error_retries=copy(error_retries) + [_eval_retry_error(error)],
error_retries=copy(error_retries) + [retry_error],
time_limit=time_limit,
working_limit=working_limit,
semaphore=semaphore,
Expand Down Expand Up @@ -1745,16 +1775,29 @@ def init_sample_assistant_internal() -> None:
pass


def _eval_retry_error(error: EvalError) -> EvalRetryError:
def _eval_retry_error(
error: EvalError,
logger: TaskLogger | None = None,
sample_id: str | int | None = None,
epoch: int | None = None,
) -> EvalRetryError:
"""Create retry error with events from the most recent ModelEvent onward."""
from inspect_ai.event._model import ModelEvent

events = transcript().events
recent_events = list(events)
for i in range(len(events) - 1, -1, -1):
if isinstance(events[i], ModelEvent):
recent_events = list(events[i:])
break
if logger is not None and logger.buffer_db is not None and sample_id is not None:
if epoch is None:
raise ValueError(
"epoch is required when reading retry events from buffer DB"
)
with logger.buffer_db.open_sample_history(sample_id, epoch) as history:
return eval_retry_error_from_history(error, history)

sample_transcript = transcript()
recent_events = (
sample_transcript.events_since_last(ModelEvent)
if sample_transcript.full_history_available
else []
)
return EvalRetryError(
message=error.message,
traceback=error.traceback,
Expand Down
6 changes: 3 additions & 3 deletions src/inspect_ai/agent/_acp/transport_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,9 @@ def _noop_unsubscribe() -> None:
def snapshot(self) -> Sequence[Any]:
if self._captured is None:
return []
# list() the events sequence so callers iterating concurrently
# with new ``_event`` appends don't see size changes mid-iteration.
return list(self._captured.events)[self._attach_index :]
# Slice the transcript view directly so bounded/provider-backed
# transcripts can read only the suffix since attach.
return self._captured.events[self._attach_index :]


class _CancelSnapshot:
Expand Down
83 changes: 82 additions & 1 deletion src/inspect_ai/log/_recorders/buffer/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from logging import getLogger
from pathlib import Path
from sqlite3 import Connection, OperationalError
from typing import Callable, Iterable, Iterator, Literal
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Literal, cast

import psutil
from pydantic import BaseModel, JsonValue
Expand Down Expand Up @@ -67,6 +67,9 @@
logger = getLogger(__name__)
SYNC_CLEANUP_TIMEOUT = 30

if TYPE_CHECKING:
from inspect_ai.util._checkpoint._event_store import CheckpointEventStore


class TaskData(BaseModel):
version: int
Expand Down Expand Up @@ -488,6 +491,84 @@ def sample_attachment(self, id: str | int, epoch: int, hash: str) -> str | None:
conn.rollback()
raise

def import_checkpoint_events(
self, id: str | int, epoch: int, event_store: "CheckpointEventStore"
) -> int:
seed_count = 0
with self._acquire_sample_read_lease(id, epoch):
with self._get_connection() as conn:
conn.execute("BEGIN IMMEDIATE")
try:
pool_attachment_refs: set[str] = set()
message_pos_map: dict[int, int] = {}
for pos, message_entry in enumerate(
self._get_message_pool(conn, id, epoch)
):
message_pos_map[pos] = event_store.merge_message_pool_entry(
message_entry.msg_id, message_entry.data
)
pool_attachment_refs.update(
event_store.attachment_refs_from_json(message_entry.data)
)
call_pos_map: dict[int, int] = {}
for pos, call_entry in enumerate(
self._get_call_pool(conn, id, epoch)
):
call_pos_map[pos] = event_store.merge_call_pool_entry(
call_entry.hash, call_entry.data
)
pool_attachment_refs.update(
event_store.attachment_refs_from_json(call_entry.data)
)

def attachment_lookup(hash: str) -> str | None:
row = conn.execute(
"""
SELECT content FROM attachments
WHERE sample_id = ? AND sample_epoch = ? AND hash = ?
""",
[str(id), epoch, hash],
).fetchone()
return None if row is None else str(row["content"])

event_store.merge_attachment_refs(
pool_attachment_refs, attachment_lookup
)
for row in self._get_events(conn, id, epoch, latest_only=True):
event_store.merge_condensed_event(
row.event_id,
self._remap_pool_refs(
row.event, message_pos_map, call_pos_map
),
attachment_lookup,
)
seed_count += 1
conn.commit()
return seed_count
except Exception:
conn.rollback()
raise

@staticmethod
def _remap_pool_refs(
event: JsonData, message_pos_map: dict[int, int], call_pos_map: dict[int, int]
) -> JsonData:
remapped = dict(event)
input_refs = remapped.get("input_refs")
if isinstance(input_refs, list):
remapped["input_refs"] = cast(
JsonValue, _remap_refs(input_refs, message_pos_map)
)
call = remapped.get("call")
if isinstance(call, dict):
call_refs = call.get("call_refs")
if isinstance(call_refs, list):
remapped["call"] = {
**call,
"call_refs": cast(JsonValue, _remap_refs(call_refs, call_pos_map)),
}
return remapped

@contextmanager
def open_sample_history_tail(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/inspect_ai/log/_recorders/buffer/filestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ def get_sample_data(
def sample_event_count(self, id: str | int, epoch: int) -> int:
raise NotImplementedError("Sample history is only available for buffer DBs")

@override
def import_checkpoint_events(
self, id: str | int, epoch: int, event_store: object
) -> int:
raise NotImplementedError("Sample history is only available for buffer DBs")

@override
def open_sample_history_tail(
self,
Expand Down
119 changes: 119 additions & 0 deletions src/inspect_ai/log/_recorders/buffer/history_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping, Sequence
from typing import TYPE_CHECKING

from inspect_ai.event._event import Event
from inspect_ai.event._pool import (
materialize_pooled_events,
resolve_model_event_calls,
resolve_model_event_inputs,
)
from inspect_ai.event._validate import validate_events

if TYPE_CHECKING:
from inspect_ai.log._recorders.buffer.database import SampleBufferDatabase
from inspect_ai.log._recorders.buffer.history import SampleHistory
from inspect_ai.log._transcript import TranscriptHistoryProvider
from inspect_ai.util._checkpoint._event_store import CheckpointEventStore


class BufferTranscriptHistoryProvider:
def __init__(
self,
buffer_db: SampleBufferDatabase,
sample_id: str | int,
epoch: int,
) -> None:
self._buffer_db = buffer_db
self._sample_id = sample_id
self._epoch = epoch

@property
def event_count(self) -> int:
return self._buffer_db.sample_event_count(self._sample_id, self._epoch)

def events(self) -> Sequence[Event]:
return self._events()

def iter_events(self) -> Iterator[Event]:
return self._iter_events()

def recent_events(self, n: int | None = None) -> Sequence[Event]:
if n is None:
return self._events()
if n <= 0:
return []
with self._buffer_db.open_sample_history_tail(
self._sample_id, self._epoch, n
) as history:
return _materialize_events(history)

def events_from(self, start: int) -> Sequence[Event]:
if start <= 0:
return self._events()
with self._buffer_db.open_sample_history_from(
self._sample_id, self._epoch, start
) as history:
return _materialize_events(history)

def events_since_last(self, event_type: type[Event]) -> list[Event]:
suffix: list[Event] = []
with self._buffer_db.open_sample_history(
self._sample_id, self._epoch
) as history:
for event in _iter_materialized_events(history):
if isinstance(event, event_type):
suffix = [event]
else:
suffix.append(event)
return suffix

def attachments(self) -> Mapping[str, str]:
with self._buffer_db.open_sample_history(
self._sample_id, self._epoch
) as history:
return dict(history.attachments)

def attachment(self, hash: str) -> str | None:
return self._buffer_db.sample_attachment(self._sample_id, self._epoch, hash)

def import_checkpoint_events(self, event_store: "CheckpointEventStore") -> int:
return self._buffer_db.import_checkpoint_events(
self._sample_id, self._epoch, event_store
)

def _events(self) -> list[Event]:
with self._buffer_db.open_sample_history(
self._sample_id, self._epoch
) as history:
return _materialize_events(history)

def _iter_events(self) -> Iterator[Event]:
with self._buffer_db.open_sample_history(
self._sample_id, self._epoch
) as history:
yield from _iter_materialized_events(history)


def _materialize_events(history: SampleHistory) -> list[Event]:
return materialize_pooled_events(
history.iter_events(),
history.events_data["messages"],
history.events_data["calls"],
)


def _iter_materialized_events(history: SampleHistory) -> Iterator[Event]:
message_pool = history.events_data["messages"]
call_pool = history.events_data["calls"]
for raw_event in history.iter_events():
event = validate_events([raw_event])[0]
event = resolve_model_event_inputs([event], message_pool)[0]
yield resolve_model_event_calls([event], call_pool)[0]


if TYPE_CHECKING:
_buffer_transcript_history_provider: type[TranscriptHistoryProvider] = (
BufferTranscriptHistoryProvider
)
Loading