diff --git a/pyproject.toml b/pyproject.toml index a03bdcab..0ddac04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,9 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +firestore = [ + "google-cloud-firestore>=2.11,<3", +] s3 = [ "aioboto3>=13.0.0", # For S3ArtifactService ] diff --git a/src/google/adk_community/sessions/__init__.py b/src/google/adk_community/sessions/__init__.py index 90bf28d7..b46d380c 100644 --- a/src/google/adk_community/sessions/__init__.py +++ b/src/google/adk_community/sessions/__init__.py @@ -14,6 +14,7 @@ """Community session services for ADK.""" +from .firestore_session_service import BufferedFirestoreSessionService from .redis_session_service import RedisSessionService -__all__ = ["RedisSessionService"] +__all__ = ["BufferedFirestoreSessionService", "RedisSessionService"] diff --git a/src/google/adk_community/sessions/firestore_session_service.py b/src/google/adk_community/sessions/firestore_session_service.py new file mode 100644 index 00000000..ec15d96b --- /dev/null +++ b/src/google/adk_community/sessions/firestore_session_service.py @@ -0,0 +1,721 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Firestore-backed ADK session service with batched, buffered event writes. + +``BufferedFirestoreSessionService`` mirrors the data model of the builtin +``google.adk.integrations.firestore.FirestoreSessionService`` (same collection +hierarchy, app/user/session state scoping, optimistic concurrency via a +``revision`` field, and idempotent event documents keyed by ``event.id``) but +**owns** the Firestore I/O so it can persist a whole batch of buffered events +in a **single transaction**. + +Collection hierarchy (matches the ADK builtin):: + + adk-session/{app}/users/{user}/sessions/{session}/events/{event} + app_states/{app} + user_states/{app}/users/{user} + +Events accumulate in a per-session in-memory buffer and flush when the buffer +reaches ``buffer_max_events``, when ``flush_interval_seconds`` elapses (via +the background task started by :meth:`start`), when ``flush_session`` / +``flush_all`` / ``flush`` is called, or when :meth:`stop` runs. Set +``durable_mode=True`` to persist every event immediately (no buffering). + +Batching collapses the repeated session-doc + state-doc updates and per-event +transactions from N to 1 (fewer round-trips, less optimistic-lock contention). +On an abrupt process death before a flush, up to ``flush_interval_seconds`` of +events (or ``buffer_max_events - 1`` per session) may be lost; :meth:`stop` +flushes on graceful shutdown. +""" + +import asyncio +from collections import deque +from collections.abc import Awaitable +from collections.abc import Callable +import copy +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from datetime import timezone +import logging +import random +import time +from typing import Any +import uuid + +from google.adk.errors.already_exists_error import AlreadyExistsError +from google.adk.events.event import Event +from google.adk.sessions import _session_util +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session +from google.adk.sessions.state import State +from typing_extensions import override + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_ROOT_COLLECTION = "adk-session" +DEFAULT_SESSIONS_COLLECTION = "sessions" +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_APP_STATE_COLLECTION = "app_states" +DEFAULT_USER_STATE_COLLECTION = "user_states" + +_RETRYABLE_ERROR_NAMES = frozenset({ + "DeadlineExceeded", + "ServiceUnavailable", + "Aborted", + "ResourceExhausted", + "InternalServerError", + "Internal", + "Cancelled", + "RetryError", + "TooManyRequests", +}) +_NON_RETRYABLE_TYPES: tuple[type[BaseException], ...] = ( + ValueError, + TypeError, + KeyError, + AlreadyExistsError, + PermissionError, +) +_NON_RETRYABLE_ERROR_NAMES = frozenset({ + "PermissionDenied", + "InvalidArgument", + "NotFound", + "Unauthenticated", + "FailedPrecondition", +}) + + +class SessionPersistenceError(RuntimeError): + """Raised when an explicit flush fails to persist after exhausting retries.""" + + +def is_retryable_error(exc: BaseException) -> bool: + """Classifies an error as transient/retryable vs. a permanent caller error.""" + if isinstance(exc, _NON_RETRYABLE_TYPES): + return False + name = type(exc).__name__ + if name in _NON_RETRYABLE_ERROR_NAMES: + return False + return name in _RETRYABLE_ERROR_NAMES + + +@dataclass +class _SessionBuffer: + """In-memory pending state for a single session.""" + + pending_events: deque[Event] = field(default_factory=deque) + last_flush_monotonic: float = 0.0 + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + flush_in_progress: bool = False + + +class BufferedFirestoreSessionService(BaseSessionService): # type: ignore[misc] + """A Firestore-backed session service with batched, buffered event writes.""" + + def __init__( + self, + client: Any = None, + root_collection: str | None = None, + *, + sessions_collection: str = DEFAULT_SESSIONS_COLLECTION, + events_collection: str = DEFAULT_EVENTS_COLLECTION, + app_state_collection: str = DEFAULT_APP_STATE_COLLECTION, + user_state_collection: str = DEFAULT_USER_STATE_COLLECTION, + flat_layout: bool = False, + durable_mode: bool = False, + buffer_max_events: int = 10, + flush_interval_seconds: float = 120.0, + max_retry_attempts: int = 5, + retry_base_delay_seconds: float = 0.5, + clock: Callable[[], float] = time.monotonic, + sleeper: Callable[[float], Awaitable[None]] = asyncio.sleep, + ) -> None: + """Initializes the buffered Firestore session service. + + Args: + client: An optional Firestore ``AsyncClient``. If not provided, a new + one is created (requires ``google-cloud-firestore``). + root_collection: Root collection name. Defaults to ``'adk-session'``. + sessions_collection: Sessions subcollection name. Defaults to + ``'sessions'``. + events_collection: Events subcollection name. Defaults to ``'events'``. + app_state_collection: Collection for app-scoped state. Defaults to + ``'app_states'``. + user_state_collection: Collection for user-scoped state. Defaults to + ``'user_states'``. + flat_layout: When ``True``, session documents are stored directly at + ``root_collection/{session_id}`` instead of the default nested ADK + path. Useful when the session id already encodes the user (e.g. + ``{phone}-{date}``) or to match an existing flat collection. + durable_mode: When ``True``, every event is persisted immediately (no + buffering). Equivalent to the builtin service behaviour. + buffer_max_events: Flush when this many events are buffered per session. + flush_interval_seconds: Background flush cadence (see :meth:`start`). + max_retry_attempts: Max attempts on a retryable Firestore error. + retry_base_delay_seconds: Base delay for exponential backoff with jitter. + clock: Monotonic clock, injectable for tests. + sleeper: Async sleep function, injectable for tests. + """ + try: + from google.cloud import firestore + except ImportError as e: + raise ImportError( + "BufferedFirestoreSessionService requires google-cloud-firestore." + " Install it with: pip install google-adk-community[firestore]" + ) from e + + self._firestore = firestore + self.client = client if client is not None else firestore.AsyncClient() + self.root_collection = root_collection or DEFAULT_ROOT_COLLECTION + self.sessions_collection = sessions_collection + self.events_collection = events_collection + self.app_state_collection = app_state_collection + self.user_state_collection = user_state_collection + self._flat_layout = flat_layout + self._durable_mode = durable_mode + self._buffer_max_events = buffer_max_events + self._flush_interval_seconds = flush_interval_seconds + self._max_retry_attempts = max_retry_attempts + self._retry_base_delay_seconds = retry_base_delay_seconds + self._clock = clock + self._sleeper = sleeper + self._transactional = firestore.async_transactional + + self._buffers: dict[str, _SessionBuffer] = {} + self._session_refs: dict[str, Session] = {} + self._buffers_guard = asyncio.Lock() + self._task: asyncio.Task[None] | None = None + self._check_interval = max(1.0, min(flush_interval_seconds, 5.0)) + + def _get_sessions_ref(self, app_name: str, user_id: str) -> Any: + if self._flat_layout: + return self.client.collection(self.root_collection) + return ( + self.client.collection(self.root_collection) + .document(app_name) + .collection("users") + .document(user_id) + .collection(self.sessions_collection) + ) + + def _app_state_ref(self, app_name: str) -> Any: + return self.client.collection(self.app_state_collection).document(app_name) + + def _user_state_ref(self, app_name: str, user_id: str) -> Any: + return ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + + @staticmethod + def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], + ) -> dict[str, Any]: + merged = copy.deepcopy(session_state) + for key, value in app_state.items(): + merged[State.APP_PREFIX + key] = value + for key, value in user_state.items(): + merged[State.USER_PREFIX + key] = value + return merged + + async def _read_state(self, ref: Any) -> dict[str, Any]: + doc = await ref.get() + return (doc.to_dict() or {}) if doc.exists else {} + + @staticmethod + def _coerce_timestamp(value: Any) -> float: + if isinstance(value, datetime): + return value.timestamp() + try: + return float(value) + except (ValueError, TypeError): + return 0.0 + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: dict[str, Any] | None = None, + session_id: str | None = None, + ) -> Session: + """Creates a new session (raises AlreadyExistsError on a duplicate id).""" + session_id = session_id or str(uuid.uuid4()) + deltas = _session_util.extract_state_delta(state or {}) + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + app_ref = self._app_state_ref(app_name) + user_ref = self._user_state_ref(app_name, user_id) + session_data = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": deltas["session"], + "createTime": self._firestore.SERVER_TIMESTAMP, + "updateTime": self._firestore.SERVER_TIMESTAMP, + "revision": 1, + } + + async def _create_txn(transaction: Any) -> None: + snap = await session_ref.get(transaction=transaction) + if snap.exists: + raise AlreadyExistsError(f"Session {session_id} already exists.") + if deltas["app"]: + app_snap = await app_ref.get(transaction=transaction) + current = app_snap.to_dict() if app_snap.exists else {} + current.update(deltas["app"]) + transaction.set(app_ref, current, merge=True) + if deltas["user"]: + user_snap = await user_ref.get(transaction=transaction) + current = user_snap.to_dict() if user_snap.exists else {} + current.update(deltas["user"]) + transaction.set(user_ref, current, merge=True) + transaction.set(session_ref, session_data) + + await self._transactional(_create_txn)(self.client.transaction()) + + merged = self._merge_state( + await self._read_state(app_ref), + await self._read_state(user_ref), + deltas["session"], + ) + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged, + events=[], + last_update_time=datetime.now(timezone.utc).timestamp(), + ) + session._storage_update_marker = "1" + return session + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: GetSessionConfig | None = None, + ) -> Session | None: + """Gets a session, merging persisted and not-yet-flushed buffered events.""" + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + doc = await session_ref.get() + if not doc.exists: + return None + data = doc.to_dict() or {} + + query = session_ref.collection(self.events_collection).order_by("timestamp") + if config: + if config.after_timestamp: + query = query.where( + "timestamp", + ">=", + datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc), + ) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + events: list[Event] = [] + for event_doc in await query.get(): + event_data = event_doc.to_dict() or {} + if "event_data" in event_data: + events.append(Event.model_validate(event_data["event_data"])) + + merged_state = self._merge_state( + await self._read_state(self._app_state_ref(app_name)), + await self._read_state(self._user_state_ref(app_name, user_id)), + data.get("state", {}) or {}, + ) + revision = data.get("revision", 0) + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged_state, + events=events, + last_update_time=self._coerce_timestamp(data.get("updateTime")), + ) + session._storage_update_marker = str(revision) if revision > 0 else None + return self._merge_buffered(session) + + @override + async def list_sessions( + self, *, app_name: str, user_id: str | None = None + ) -> ListSessionsResponse: + """Lists sessions for an app (optionally a single user).""" + if self._flat_layout: + query = self.client.collection(self.root_collection).where( + "appName", "==", app_name + ) + if user_id: + query = query.where("userId", "==", user_id) + docs = await query.get() + elif user_id: + docs = await ( + self._get_sessions_ref(app_name, user_id) + .where("appName", "==", app_name) + .get() + ) + else: + docs = await ( + self.client.collection_group(self.sessions_collection) + .where("appName", "==", app_name) + .get() + ) + + app_state = await self._read_state(self._app_state_ref(app_name)) + user_states: dict[str, dict[str, Any]] = {} + if user_id: + user_states[user_id] = await self._read_state( + self._user_state_ref(app_name, user_id) + ) + else: + users_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + ) + for u_doc in await users_ref.get(): + user_states[u_doc.id] = u_doc.to_dict() or {} + + sessions: list[Session] = [] + for doc in docs: + data = doc.to_dict() + if not data: + continue + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state=self._merge_state( + app_state, + user_states.get(data["userId"], {}), + data.get("state", {}) or {}, + ), + events=[], + last_update_time=0.0, + ) + ) + return ListSessionsResponse(sessions=sessions) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Deletes a session, its events, and drops any pending buffer.""" + async with self._buffers_guard: + self._buffers.pop(session_id, None) + self._session_refs.pop(session_id, None) + + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + events_ref = session_ref.collection(self.events_collection) + batch = self.client.batch() + count = 0 + async for event_doc in events_ref.stream(): + batch.delete(event_doc.reference) + count += 1 + if count >= 500: + await batch.commit() + batch = self.client.batch() + count = 0 + if count > 0: + await batch.commit() + await session_ref.delete() + + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Returns the raw (un-prefixed) user-scoped state for an app/user.""" + return dict(await self._read_state(self._user_state_ref(app_name, user_id))) + + @override + async def append_event(self, session: Session, event: Event) -> Event: + """Appends an event in memory and buffers (or immediately persists) it.""" + event = await super().append_event(session=session, event=event) + if event.partial: + return event + + buffered = event.model_copy(deep=True) + if self._durable_mode: + await self._persist_batch(session, [buffered]) + return event + + buffer = await self._get_or_create_buffer(session) + async with buffer.lock: + buffer.pending_events.append(buffered) + pending = len(buffer.pending_events) + + if pending >= self._buffer_max_events: + await self._flush(session.id, explicit=False) + return event + + async def flush_session(self, session_id: str) -> None: + """Explicitly flushes a session's buffer, raising on failure.""" + await self._flush(session_id, explicit=True) + + async def flush_all(self) -> None: + """Flushes every buffered session. Failures are logged, events kept.""" + for session_id in list(self._buffers.keys()): + try: + await self._flush(session_id, explicit=False) + except Exception: # noqa: BLE001 + logger.exception("flush_all_session_failed session_id=%s", session_id) + + async def flush(self) -> None: + """ADK lifecycle hook (Runner.close()): flushes all buffered sessions.""" + await self.flush_all() + + async def start(self) -> None: + """Starts the background periodic-flush task (idempotent).""" + if self._task is not None and not self._task.done(): + return + self._task = asyncio.create_task(self._periodic_flush_loop()) + + async def stop(self) -> None: + """Stops the background task and performs a final flush (idempotent).""" + task = self._task + self._task = None + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + await self.flush_all() + + async def close(self) -> None: + """Closes the underlying Firestore AsyncClient.""" + closer = getattr(self.client, "close", None) + if closer is not None: + result = closer() + if asyncio.iscoroutine(result): + await result + + async def _flush(self, session_id: str, *, explicit: bool) -> None: + buffer = self._buffers.get(session_id) + if buffer is None: + return + + async with buffer.lock: + if buffer.flush_in_progress: + return + if not buffer.pending_events: + buffer.last_flush_monotonic = self._clock() + return + buffer.flush_in_progress = True + batch = list(buffer.pending_events) + buffer.pending_events.clear() + buffer.last_flush_monotonic = self._clock() + session = self._session_refs.get(session_id) + + if session is None: # pragma: no cover + async with buffer.lock: + buffer.pending_events.extendleft(reversed(batch)) + buffer.flush_in_progress = False + return + + try: + await self._persist_with_retry(session, batch, session_id) + except Exception as exc: # noqa: BLE001 + async with buffer.lock: + buffer.pending_events.extendleft(reversed(batch)) + buffer.flush_in_progress = False + if explicit: + raise SessionPersistenceError( + f"Failed to flush session {session_id} after retries." + ) from exc + return + + async with buffer.lock: + buffer.flush_in_progress = False + + async def _persist_with_retry( + self, session: Session, batch: list[Event], session_id: str + ) -> None: + attempt = 0 + while True: + attempt += 1 + try: + await self._persist_batch(session, batch) + return + except Exception as exc: # noqa: BLE001 + if not is_retryable_error(exc) or attempt >= self._max_retry_attempts: + logger.error( + "session_flush_failed session_id=%s events=%s attempt=%s" + " error=%s", + session_id, + len(batch), + attempt, + type(exc).__name__, + ) + raise + delay = self._retry_base_delay_seconds * (2 ** (attempt - 1)) + delay += random.uniform(0.0, self._retry_base_delay_seconds) + await self._sleeper(delay) + + async def _persist_batch(self, session: Session, events: list[Event]) -> None: + """Persists a batch of events for one session in a single transaction.""" + session_ref = self._get_sessions_ref( + session.app_name, session.user_id + ).document(session.id) + app_ref = self._app_state_ref(session.app_name) + user_ref = self._user_state_ref(session.app_name, session.user_id) + + agg: dict[str, dict[str, Any]] = {"app": {}, "user": {}, "session": {}} + for event in events: + delta = ( + event.actions.state_delta + if event.actions and event.actions.state_delta + else {} + ) + scoped = _session_util.extract_state_delta(delta) + agg["app"].update(scoped["app"]) + agg["user"].update(scoped["user"]) + agg["session"].update(scoped["session"]) + has_app, has_user = bool(agg["app"]), bool(agg["user"]) + + async def _append_txn(transaction: Any) -> int: + snap = await session_ref.get(transaction=transaction) + if not snap.exists: + raise ValueError(f"Session {session.id} not found.") + doc = snap.to_dict() or {} + if doc.get("status") == "DELETING": + raise ValueError(f"Session {session.id} is currently being deleted.") + current_revision = doc.get("revision", 0) + marker = getattr(session, "_storage_update_marker", None) + if marker is not None and marker != str(current_revision): + raise ValueError( + "The session has been modified in storage since it was loaded." + " Please reload the session before appending more events." + ) + + app_snap = await app_ref.get(transaction=transaction) if has_app else None + user_snap = ( + await user_ref.get(transaction=transaction) if has_user else None + ) + + if has_app: + current = app_snap.to_dict() if app_snap.exists else {} + current.update(agg["app"]) + transaction.set(app_ref, current, merge=True) + if has_user: + current = user_snap.to_dict() if user_snap.exists else {} + current.update(agg["user"]) + transaction.set(user_ref, current, merge=True) + for key, value in agg["session"].items(): + session.state[key] = value + + for event in events: + event_ref = session_ref.collection(self.events_collection).document( + event.id + ) + # Use event's own timestamp so intra-batch order survives a shared commit time. + transaction.set( + event_ref, + { + "event_data": event.model_dump(exclude_none=True, mode="json"), + "timestamp": datetime.fromtimestamp( + event.timestamp, tz=timezone.utc + ), + "appName": session.app_name, + "userId": session.user_id, + }, + ) + + new_revision = current_revision + 1 + session_only_state = { + k: v + for k, v in session.state.items() + if not k.startswith(State.APP_PREFIX) + and not k.startswith(State.USER_PREFIX) + and not k.startswith(State.TEMP_PREFIX) + } + transaction.update( + session_ref, + { + "state": session_only_state, + "updateTime": self._firestore.SERVER_TIMESTAMP, + "revision": new_revision, + }, + ) + return new_revision + + new_revision = await self._transactional(_append_txn)( + self.client.transaction() + ) + session._storage_update_marker = str(new_revision) + if events: + session.last_update_time = events[-1].timestamp + + async def _periodic_flush_loop(self) -> None: + try: + while True: + await self._sleeper(self._check_interval) + await self._flush_due() + except asyncio.CancelledError: + raise + + async def _flush_due(self) -> list[asyncio.Task[None]]: + now = self._clock() + tasks: list[asyncio.Task[None]] = [] + for session_id, buffer in list(self._buffers.items()): + if ( + buffer.pending_events + and (now - buffer.last_flush_monotonic) + >= self._flush_interval_seconds + ): + tasks.append( + asyncio.create_task(self._safe_background_flush(session_id)) + ) + return tasks + + async def _safe_background_flush(self, session_id: str) -> None: + try: + await self._flush(session_id, explicit=False) + except Exception: # noqa: BLE001 + logger.exception("background_flush_failed session_id=%s", session_id) + + async def _get_or_create_buffer(self, session: Session) -> _SessionBuffer: + async with self._buffers_guard: + buffer = self._buffers.get(session.id) + if buffer is None: + buffer = _SessionBuffer(last_flush_monotonic=self._clock()) + self._buffers[session.id] = buffer + self._session_refs[session.id] = session + return buffer + + def _merge_buffered(self, session: Session) -> Session: + buffer = self._buffers.get(session.id) + if buffer is None or not buffer.pending_events: + return session + seen = {e.id for e in session.events} + merged = list(session.events) + for event in list(buffer.pending_events): + if event.id not in seen: + merged.append(event) + seen.add(event.id) + merged.sort(key=lambda e: (e.timestamp or 0.0)) + session.events = merged + return session diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 00000000..0a7e2794 --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,600 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BufferedFirestoreSessionService. + +Uses an in-memory fake Firestore AsyncClient (no external services), a +deterministic clock, and a recording sleeper. The service's transactional +wrapper is replaced with an identity (or gated/flaky) runner so the fake +transaction is driven directly. +""" + +import asyncio + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig +from google.genai import types +import pytest + +from google.adk_community.sessions.firestore_session_service import BufferedFirestoreSessionService +from google.adk_community.sessions.firestore_session_service import SessionPersistenceError + +APP = "app" +USER = "user-1" +SID = "session-1" + + +# --- fake Firestore ---------------------------------------------------------- + + +class FakeSnapshot: + + def __init__(self, doc_id, data, ref): + self.id = doc_id + self._data = data + self.reference = ref + + @property + def exists(self): + return self._data is not None + + def to_dict(self): + return dict(self._data) if self._data is not None else None + + +class FakeDoc: + + def __init__(self, doc_id): + self.id = doc_id + self.data = None + self._subcollections = {} + self.reference = self + + async def get(self, transaction=None): + return FakeSnapshot(self.id, self.data, self) + + def collection(self, name): + return self._subcollections.setdefault(name, FakeCollection(name)) + + async def delete(self): + self.data = None + + def set(self, data, merge=False): + if merge and isinstance(self.data, dict): + merged = dict(self.data) + merged.update(data) + self.data = merged + else: + self.data = dict(data) + + def update(self, data): + self.data = {**(self.data or {}), **data} + + +def _match(actual, op, value): + if actual is None: + return False + if op == "==": + return actual == value + if op == ">=": + return actual >= value + return False + + +class FakeQuery: + + def __init__(self, docs): + self._docs = docs + self._order = None + self._filters = [] + self._limit_last = None + + def order_by(self, field): + self._order = field + return self + + def where(self, field, op, value): + self._filters.append((field, op, value)) + return self + + def limit_to_last(self, n): + self._limit_last = n + return self + + async def get(self): + rows = [d for d in self._docs if d.data is not None] + for field, op, value in self._filters: + rows = [d for d in rows if _match(d.data.get(field), op, value)] + if self._order: + rows = sorted(rows, key=lambda d: d.data.get(self._order)) + if self._limit_last is not None: + rows = rows[-self._limit_last :] + return [FakeSnapshot(d.id, d.data, d) for d in rows] + + +class FakeCollection: + + def __init__(self, name): + self.name = name + self.docs = {} + + def document(self, doc_id): + if doc_id not in self.docs: + self.docs[doc_id] = FakeDoc(doc_id) + return self.docs[doc_id] + + def order_by(self, field): + return FakeQuery(list(self.docs.values())).order_by(field) + + def where(self, field, op, value): + return FakeQuery(list(self.docs.values())).where(field, op, value) + + async def get(self): + return await FakeQuery(list(self.docs.values())).get() + + async def stream(self): + for d in list(self.docs.values()): + if d.data is not None: + yield FakeSnapshot(d.id, d.data, d) + + +class FakeTransaction: + + def set(self, ref, data, merge=False): + ref.set(data, merge=merge) + + def update(self, ref, data): + ref.update(data) + + +class FakeBatch: + + def __init__(self): + self._ops = [] + + def delete(self, ref): + self._ops.append(ref) + + async def commit(self): + for ref in self._ops: + ref.data = None + self._ops = [] + + +class FakeFirestore: + + def __init__(self): + self.collections = {} + self.transaction_count = 0 + + def collection(self, name): + return self.collections.setdefault(name, FakeCollection(name)) + + def collection_group(self, name): + return FakeQuery(self._gather_group(name)) + + def transaction(self): + self.transaction_count += 1 + return FakeTransaction() + + def batch(self): + return FakeBatch() + + def _gather_group(self, name): + result = [] + + def walk(coll): + for doc in coll.docs.values(): + for sub_name, sub in doc._subcollections.items(): + if sub_name == name: + result.extend(d for d in sub.docs.values() if d.data is not None) + walk(sub) + + for coll in self.collections.values(): + if coll.name == name: + result.extend(d for d in coll.docs.values() if d.data is not None) + walk(coll) + return result + + +# --- helpers ----------------------------------------------------------------- + + +class Clock: + + def __init__(self, start=1000.0): + self.now = start + + def __call__(self): + return self.now + + def advance(self, seconds): + self.now += seconds + + +class RecordingSleeper: + + def __init__(self): + self.delays = [] + + async def __call__(self, delay): + self.delays.append(delay) + + +class Aborted(Exception): + """Name matches the retryable allowlist.""" + + +def _identity_transactional(fn): + + async def run(transaction): + return await fn(transaction) + + return run + + +def _make(**kwargs): + client = FakeFirestore() + clock = Clock() + sleeper = RecordingSleeper() + service = BufferedFirestoreSessionService( + client, clock=clock, sleeper=sleeper, **kwargs + ) + service._transactional = _identity_transactional + return service, client, clock, sleeper + + +def _event(author, text, timestamp, *, state_delta=None): + return Event( + invocation_id=f"inv-{timestamp}", + author=author, + timestamp=timestamp, + content=types.Content( + role="user" if author == "user" else "model", + parts=[types.Part(text=text)], + ), + actions=EventActions(state_delta=state_delta or {}), + ) + + +def _session_doc(client, session_id=SID): + return ( + client.collection("adk-session") + .document(APP) + .collection("users") + .document(USER) + .collection("sessions") + .document(session_id) + ) + + +def _persisted_event_count(client, session_id=SID): + events = _session_doc(client, session_id)._subcollections.get("events") + if events is None: + return 0 + return sum(1 for d in events.docs.values() if d.data is not None) + + +# --- tests ------------------------------------------------------------------- + + +async def test_create_session_writes_metadata(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + assert session.id == SID + doc = _session_doc(client).data + assert doc["appName"] == APP + assert doc["userId"] == USER + assert doc["revision"] == 1 + + +async def test_buffered_append_defers_persistence(): + service, client, *_ = _make(buffer_max_events=10) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + base = client.transaction_count + for i in range(9): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert _persisted_event_count(client) == 0 + assert client.transaction_count == base + + +async def test_flush_persists_whole_batch_in_one_transaction(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(9): + await service.append_event(session, _event("user", f"m{i}", float(i))) + before = client.transaction_count + await service.flush_session(SID) + assert client.transaction_count - before == 1 + assert _persisted_event_count(client) == 9 + assert _session_doc(client).data["revision"] == 2 + + +async def test_reaching_max_events_auto_flushes(): + service, client, *_ = _make(buffer_max_events=10) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(10): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert _persisted_event_count(client) == 10 + + +async def test_durable_mode_writes_each_event_immediately(): + service, client, *_ = _make(durable_mode=True) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + base = client.transaction_count + for i in range(3): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert client.transaction_count - base == 3 + assert _persisted_event_count(client) == 3 + assert SID not in service._buffers + + +async def test_periodic_flush_after_interval(): + service, client, clock, _ = _make(flush_interval_seconds=120.0) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(3): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert _persisted_event_count(client) == 0 + clock.advance(121.0) + await asyncio.gather(*await service._flush_due()) + assert _persisted_event_count(client) == 3 + + +async def test_flush_hook_and_stop_final_flush(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event(session, _event("user", "a", 1.0)) + await service.flush() # ADK Runner.close() hook + assert _persisted_event_count(client) == 1 + + +async def test_get_session_merges_and_orders_without_duplicates(): + service, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event(session, _event("user", "persisted", 1.0)) + await service.flush_session(SID) + await service.append_event(session, _event("user", "buffered", 2.0)) + loaded = await service.get_session(app_name=APP, user_id=USER, session_id=SID) + texts = [e.content.parts[0].text for e in loaded.events] + assert texts == ["persisted", "buffered"] + assert len(texts) == len({e.id for e in loaded.events}) + + +async def test_state_delta_scoping(): + service, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event( + session, + _event( + "user", + "a", + 1.0, + state_delta={ + "app:shared": "yes", + "user:goal": "fat loss", + "sessionOnly": "kept", + "temp:scratch": "discard", + }, + ), + ) + await service.flush_session(SID) + loaded = await service.get_session(app_name=APP, user_id=USER, session_id=SID) + assert loaded.state["app:shared"] == "yes" + assert loaded.state["user:goal"] == "fat loss" + assert loaded.state["sessionOnly"] == "kept" + assert "temp:scratch" not in loaded.state + + +async def test_get_user_state(): + service, *_ = _make() + await service.create_session( + app_name=APP, user_id=USER, session_id=SID, state={"user:goal": "lose"} + ) + state = await service.get_user_state(app_name=APP, user_id=USER) + assert state == {"goal": "lose"} + + +async def test_retryable_failures_backoff_then_succeed(): + service, client, _, sleeper = _make(max_retry_attempts=5) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + errors = [Aborted(), Aborted()] + + def flaky(fn): + + async def run(transaction): + if errors: + raise errors.pop(0) + return await fn(transaction) + + return run + + service._transactional = flaky + await service.append_event(session, _event("user", "a", 1.0)) + await service.flush_session(SID) + assert _persisted_event_count(client) == 1 + assert len(sleeper.delays) == 2 + assert sleeper.delays[1] > sleeper.delays[0] + + +async def test_permanent_failure_not_retried(): + service, _, _, sleeper = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + + def boom(fn): + + async def run(transaction): + raise ValueError("permission denied") + + return run + + service._transactional = boom + await service.append_event(session, _event("user", "a", 1.0)) + with pytest.raises(SessionPersistenceError): + await service.flush_session(SID) + assert sleeper.delays == [] + + +async def test_events_appended_during_flush_not_lost(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + gate = asyncio.Event() + entered = asyncio.Event() + + def gated(fn): + + async def run(transaction): + entered.set() + await gate.wait() + return await fn(transaction) + + return run + + await service.append_event(session, _event("user", "a", 1.0)) + await service.append_event(session, _event("user", "b", 2.0)) + service._transactional = gated + + flush_task = asyncio.create_task(service.flush_session(SID)) + await entered.wait() + await service.append_event(session, _event("user", "c", 3.0)) + gate.set() + await flush_task + + pending = service._buffers[SID].pending_events + assert [e.content.parts[0].text for e in pending] == ["c"] + assert _persisted_event_count(client) == 2 + + +async def test_concurrent_flushes_do_not_duplicate(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + gate = asyncio.Event() + entered = asyncio.Event() + + def gated(fn): + + async def run(transaction): + entered.set() + await gate.wait() + return await fn(transaction) + + return run + + await service.append_event(session, _event("user", "a", 1.0)) + await service.append_event(session, _event("user", "b", 2.0)) + service._transactional = gated + before = client.transaction_count + + t1 = asyncio.create_task(service.flush_session(SID)) + await entered.wait() + t2 = asyncio.create_task(service.flush_session(SID)) + gate.set() + await asyncio.gather(t1, t2) + + assert client.transaction_count - before == 1 + assert _persisted_event_count(client) == 2 + + +async def test_delete_session_removes_events_and_buffer(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event(session, _event("user", "a", 1.0)) + await service.flush_session(SID) + await service.delete_session(app_name=APP, user_id=USER, session_id=SID) + assert _session_doc(client).data is None + assert _persisted_event_count(client) == 0 + assert SID not in service._buffers + + +async def test_get_session_not_found_returns_none(): + service, *_ = _make() + result = await service.get_session( + app_name=APP, user_id=USER, session_id="missing" + ) + assert result is None + + +async def test_list_sessions(): + service, *_ = _make() + await service.create_session(app_name=APP, user_id=USER, session_id="s1") + await service.create_session(app_name=APP, user_id=USER, session_id="s2") + per_user = await service.list_sessions(app_name=APP, user_id=USER) + all_users = await service.list_sessions(app_name=APP) + assert {s.id for s in per_user.sessions} == {"s1", "s2"} + assert {s.id for s in all_users.sessions} == {"s1", "s2"} + + +async def test_get_session_with_config_num_recent_events(): + service, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(3): + await service.append_event(session, _event("user", f"m{i}", float(i))) + await service.flush_session(SID) + loaded = await service.get_session( + app_name=APP, + user_id=USER, + session_id=SID, + config=GetSessionConfig(num_recent_events=2), + ) + assert [e.content.parts[0].text for e in loaded.events] == ["m1", "m2"] + + +async def test_create_session_duplicate_raises(): + from google.adk.errors.already_exists_error import AlreadyExistsError + + service, *_ = _make() + await service.create_session(app_name=APP, user_id=USER, session_id=SID) + with pytest.raises(AlreadyExistsError): + await service.create_session(app_name=APP, user_id=USER, session_id=SID) + + +async def test_start_stop_cancellation_is_clean(): + service, *_ = _make() + service._sleeper = asyncio.sleep # real sleep so the loop blocks + await service.start() + await service.start() # idempotent + task = service._task + await service.stop() + await service.stop() # idempotent + assert task.cancelled() or task.done()