From fa936aabc7ec286cf0c59c17efe9042b63f2d8be Mon Sep 17 00:00:00 2001 From: enesdemirag Date: Fri, 19 Jun 2026 20:44:15 +0300 Subject: [PATCH 1/4] feat(sessions): add BufferedFirestoreSessionService Add a Firestore-backed session service that buffers events in memory and flushes them in a single transaction per session, collapsing the repeated session-doc/state-doc updates and per-event transactions from N to 1. Mirrors the data model of the builtin google.adk.integrations.firestore.FirestoreSessionService (collection hierarchy, app/user/session state scoping, optimistic concurrency via a revision field, idempotent event docs keyed by event.id) and adds: - in-memory per-session buffering with count/interval/explicit/shutdown flush - durable_mode to persist every event immediately - exponential backoff with jitter on retryable errors - start()/stop()/flush() ADK lifecycle hooks Gated behind the new optional [firestore] extra. Unit tests use an in-memory fake AsyncClient (no external services). --- pyproject.toml | 3 + src/google/adk_community/sessions/__init__.py | 3 +- .../sessions/firestore_session_service.py | 720 ++++++++++++++++++ .../test_firestore_session_service.py | 600 +++++++++++++++ 4 files changed, 1325 insertions(+), 1 deletion(-) create mode 100644 src/google/adk_community/sessions/firestore_session_service.py create mode 100644 tests/unittests/sessions/test_firestore_session_service.py diff --git a/pyproject.toml b/pyproject.toml index a03bdcab..3f7ddd53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,9 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. 
documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +firestore = [ + "google-cloud-firestore>=2.11.0, <3.0.0", # For BufferedFirestoreSessionService +] s3 = [ "aioboto3>=13.0.0", # For S3ArtifactService ] diff --git a/src/google/adk_community/sessions/__init__.py b/src/google/adk_community/sessions/__init__.py index 90bf28d7..b46d380c 100644 --- a/src/google/adk_community/sessions/__init__.py +++ b/src/google/adk_community/sessions/__init__.py @@ -14,6 +14,7 @@ """Community session services for ADK.""" +from .firestore_session_service import BufferedFirestoreSessionService from .redis_session_service import RedisSessionService -__all__ = ["RedisSessionService"] +__all__ = ["BufferedFirestoreSessionService", "RedisSessionService"] diff --git a/src/google/adk_community/sessions/firestore_session_service.py b/src/google/adk_community/sessions/firestore_session_service.py new file mode 100644 index 00000000..cce6937e --- /dev/null +++ b/src/google/adk_community/sessions/firestore_session_service.py @@ -0,0 +1,720 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Firestore-backed ADK session service with batched, buffered event writes. 
+ +``BufferedFirestoreSessionService`` mirrors the data model of the builtin +``google.adk.integrations.firestore.FirestoreSessionService`` (same collection +hierarchy, app/user/session state scoping, optimistic concurrency via a +``revision`` field, and idempotent event documents keyed by ``event.id``) but +**owns** the Firestore I/O so it can persist a whole batch of buffered events in +a **single transaction**. + +Collection hierarchy:: + + adk-session/{app}/users/{user}/sessions/{session}/events/{event} + app_states/{app} + user_states/{app}/users/{user} + +Events accumulate in a per-session in-memory buffer and flush when the buffer +reaches ``buffer_max_events``, when ``flush_interval_seconds`` elapses (only if +the background task was started via :meth:`start`), when ``flush_session`` / +``flush_all`` / ``flush`` is called, or when :meth:`stop` runs. Set +``durable_mode=True`` to persist every event immediately (no buffering). + +Batching does not change the event-document count, but it collapses the repeated +session-doc + state-doc updates and per-event transactions from N to 1 (fewer +round-trips and less optimistic-lock contention). On an abrupt process death +before a flush, roughly ``flush_interval_seconds`` of events (or up to +``buffer_max_events - 1`` per session) may be lost; ``stop()`` flushes on +graceful shutdown but cannot protect against crashes. 
+""" + +from __future__ import annotations + +import asyncio +from collections import deque +from collections.abc import Awaitable +from collections.abc import Callable +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from datetime import timezone +import logging +import random +import time +from typing import Any +from typing import Optional +import uuid + +from google.adk.errors.already_exists_error import AlreadyExistsError +from google.adk.events.event import Event +from google.adk.sessions import _session_util +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session +from google.adk.sessions.state import State +from typing_extensions import override + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_ROOT_COLLECTION = "adk-session" +DEFAULT_SESSIONS_COLLECTION = "sessions" +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_APP_STATE_COLLECTION = "app_states" +DEFAULT_USER_STATE_COLLECTION = "user_states" + +_SessionLockKey = tuple[str, str, str] + +# Transient Firestore / gRPC failures worth retrying. Matched by class name to +# avoid a hard dependency on google.api_core being importable everywhere. +_RETRYABLE_ERROR_NAMES = frozenset({ + "DeadlineExceeded", + "ServiceUnavailable", + "Aborted", + "ResourceExhausted", + "InternalServerError", + "Internal", + "Cancelled", + "RetryError", + "TooManyRequests", +}) +_NON_RETRYABLE_TYPES: tuple[type[BaseException], ...] 
= ( + ValueError, + TypeError, + KeyError, + AlreadyExistsError, + PermissionError, +) +_NON_RETRYABLE_ERROR_NAMES = frozenset({ + "PermissionDenied", + "InvalidArgument", + "NotFound", + "Unauthenticated", + "FailedPrecondition", +}) + + +class SessionPersistenceError(RuntimeError): + """Raised when an explicit flush fails to persist after exhausting retries.""" + + +def is_retryable_error(exc: BaseException) -> bool: + """Classifies an error as transient/retryable vs. a permanent caller error.""" + if isinstance(exc, _NON_RETRYABLE_TYPES): + return False + name = type(exc).__name__ + if name in _NON_RETRYABLE_ERROR_NAMES: + return False + if name in _RETRYABLE_ERROR_NAMES: + return True + return False + + +@dataclass +class _SessionBuffer: + """In-memory pending state for a single session.""" + + pending_events: deque[Event] = field(default_factory=deque) + last_flush_monotonic: float = 0.0 + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + flush_in_progress: bool = False + + +class BufferedFirestoreSessionService(BaseSessionService): # type: ignore[misc] + """A Firestore-backed session service with batched, buffered event writes.""" + + def __init__( + self, + client: Any = None, + root_collection: Optional[str] = None, + *, + durable_mode: bool = False, + buffer_max_events: int = 10, + flush_interval_seconds: float = 120.0, + max_retry_attempts: int = 5, + retry_base_delay_seconds: float = 0.5, + clock: Callable[[], float] = time.monotonic, + sleeper: Callable[[float], Awaitable[None]] = asyncio.sleep, + ) -> None: + """Initializes the buffered Firestore session service. + + Args: + client: An optional Firestore ``AsyncClient``. If not provided, a new one + is created (requires ``google-cloud-firestore``). + root_collection: Root collection name. Defaults to ``'adk-session'``. + durable_mode: When True, every event is persisted immediately and no + buffering happens. + buffer_max_events: Flush a session once this many events are buffered. 
+ flush_interval_seconds: Background flush cadence (see :meth:`start`). + max_retry_attempts: Max attempts when a flush hits a retryable error. + retry_base_delay_seconds: Base delay for exponential backoff with jitter. + clock: Monotonic clock, injectable for tests. + sleeper: Async sleep function, injectable for tests. + """ + try: + from google.cloud import firestore + except ImportError as e: + raise ImportError( + "BufferedFirestoreSessionService requires google-cloud-firestore." + " Install it with: pip install google-adk-community[firestore]" + ) from e + + self.client = client if client is not None else firestore.AsyncClient() + self.root_collection = root_collection or DEFAULT_ROOT_COLLECTION + self.sessions_collection = DEFAULT_SESSIONS_COLLECTION + self.events_collection = DEFAULT_EVENTS_COLLECTION + self.app_state_collection = DEFAULT_APP_STATE_COLLECTION + self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + + self._durable_mode = durable_mode + self._buffer_max_events = buffer_max_events + self._flush_interval_seconds = flush_interval_seconds + self._max_retry_attempts = max_retry_attempts + self._retry_base_delay_seconds = retry_base_delay_seconds + self._clock = clock + self._sleeper = sleeper + # Injectable so tests can drive a fake client without the real transactional + # retry wrapper. 
+ self._transactional = firestore.async_transactional + + self._buffers: dict[str, _SessionBuffer] = {} + self._session_refs: dict[str, Session] = {} + self._buffers_guard = asyncio.Lock() + self._task: Optional[asyncio.Task[None]] = None + self._check_interval = max(1.0, min(flush_interval_seconds, 5.0)) + + # -- Firestore refs / helpers --------------------------------------------- + + def _get_sessions_ref(self, app_name: str, user_id: str) -> Any: + return ( + self.client.collection(self.root_collection) + .document(app_name) + .collection("users") + .document(user_id) + .collection(self.sessions_collection) + ) + + def _app_state_ref(self, app_name: str) -> Any: + return self.client.collection(self.app_state_collection).document(app_name) + + def _user_state_ref(self, app_name: str, user_id: str) -> Any: + return ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + + @staticmethod + def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], + ) -> dict[str, Any]: + import copy + + merged = copy.deepcopy(session_state) + for key, value in app_state.items(): + merged[State.APP_PREFIX + key] = value + for key, value in user_state.items(): + merged[State.USER_PREFIX + key] = value + return merged + + async def _read_state(self, ref: Any) -> dict[str, Any]: + doc = await ref.get() + return (doc.to_dict() or {}) if doc.exists else {} + + @staticmethod + def _coerce_timestamp(value: Any) -> float: + if isinstance(value, datetime): + return value.timestamp() + try: + return float(value) + except (ValueError, TypeError): + return 0.0 + + # -- CRUD ------------------------------------------------------------------ + + @override + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """Creates a new session (raises AlreadyExistsError 
on a duplicate id).""" + from google.cloud import firestore + + session_id = session_id or str(uuid.uuid4()) + deltas = _session_util.extract_state_delta(state or {}) + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + app_ref = self._app_state_ref(app_name) + user_ref = self._user_state_ref(app_name, user_id) + session_data = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": deltas["session"], + "createTime": firestore.SERVER_TIMESTAMP, + "updateTime": firestore.SERVER_TIMESTAMP, + "revision": 1, + } + + async def _create_txn(transaction: Any) -> None: + snap = await session_ref.get(transaction=transaction) + if snap.exists: + raise AlreadyExistsError(f"Session {session_id} already exists.") + if deltas["app"]: + app_snap = await app_ref.get(transaction=transaction) + current = app_snap.to_dict() if app_snap.exists else {} + current.update(deltas["app"]) + transaction.set(app_ref, current, merge=True) + if deltas["user"]: + user_snap = await user_ref.get(transaction=transaction) + current = user_snap.to_dict() if user_snap.exists else {} + current.update(deltas["user"]) + transaction.set(user_ref, current, merge=True) + transaction.set(session_ref, session_data) + + await self._transactional(_create_txn)(self.client.transaction()) + + merged = self._merge_state( + await self._read_state(app_ref), + await self._read_state(user_ref), + deltas["session"], + ) + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged, + events=[], + last_update_time=datetime.now(timezone.utc).timestamp(), + ) + session._storage_update_marker = "1" + return session + + @override + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """Gets a session, merging persisted and not-yet-flushed buffered events.""" + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + 
doc = await session_ref.get() + if not doc.exists: + return None + data = doc.to_dict() or {} + + query = session_ref.collection(self.events_collection).order_by("timestamp") + if config: + if config.after_timestamp: + query = query.where( + "timestamp", + ">=", + datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc), + ) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + events: list[Event] = [] + for event_doc in await query.get(): + event_data = event_doc.to_dict() or {} + if "event_data" in event_data: + events.append(Event.model_validate(event_data["event_data"])) + + merged_state = self._merge_state( + await self._read_state(self._app_state_ref(app_name)), + await self._read_state(self._user_state_ref(app_name, user_id)), + data.get("state", {}) or {}, + ) + revision = data.get("revision", 0) + session = Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=merged_state, + events=events, + last_update_time=self._coerce_timestamp(data.get("updateTime")), + ) + session._storage_update_marker = str(revision) if revision > 0 else None + return self._merge_buffered(session) + + @override + async def list_sessions( + self, *, app_name: str, user_id: Optional[str] = None + ) -> ListSessionsResponse: + """Lists sessions for an app (optionally a single user).""" + if user_id: + docs = await ( + self._get_sessions_ref(app_name, user_id) + .where("appName", "==", app_name) + .get() + ) + else: + docs = await ( + self.client.collection_group(self.sessions_collection) + .where("appName", "==", app_name) + .get() + ) + + app_state = await self._read_state(self._app_state_ref(app_name)) + user_states: dict[str, dict[str, Any]] = {} + if user_id: + user_states[user_id] = await self._read_state( + self._user_state_ref(app_name, user_id) + ) + else: + users_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + ) + for u_doc in await users_ref.get(): + 
user_states[u_doc.id] = u_doc.to_dict() or {} + + sessions: list[Session] = [] + for doc in docs: + data = doc.to_dict() + if not data: + continue + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state=self._merge_state( + app_state, + user_states.get(data["userId"], {}), + data.get("state", {}) or {}, + ), + events=[], + last_update_time=0.0, + ) + ) + return ListSessionsResponse(sessions=sessions) + + @override + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Deletes a session, its events, and drops any pending buffer.""" + async with self._buffers_guard: + self._buffers.pop(session_id, None) + self._session_refs.pop(session_id, None) + + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + events_ref = session_ref.collection(self.events_collection) + batch = self.client.batch() + count = 0 + async for event_doc in events_ref.stream(): + batch.delete(event_doc.reference) + count += 1 + if count >= 500: + await batch.commit() + batch = self.client.batch() + count = 0 + if count > 0: + await batch.commit() + await session_ref.delete() + + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Returns the raw (un-prefixed) user-scoped state for an app/user.""" + return dict(await self._read_state(self._user_state_ref(app_name, user_id))) + + # -- buffered append ------------------------------------------------------- + + @override + async def append_event(self, session: Session, event: Event) -> Event: + """Appends an event in memory and buffers (or immediately persists) it.""" + event = await super().append_event(session=session, event=event) + if event.partial: + return event + + buffered = event.model_copy(deep=True) + if self._durable_mode: + await self._persist_batch(session, [buffered]) + return event + + buffer = await self._get_or_create_buffer(session) + async with buffer.lock: 
+ buffer.pending_events.append(buffered) + pending = len(buffer.pending_events) + + if pending >= self._buffer_max_events: + await self._flush(session.id, explicit=False) + return event + + async def flush_session(self, session_id: str) -> None: + """Explicitly flushes a session's buffer, raising on failure.""" + await self._flush(session_id, explicit=True) + + async def flush_all(self) -> None: + """Flushes every buffered session. Failures are logged, events kept.""" + for session_id in list(self._buffers.keys()): + try: + await self._flush(session_id, explicit=False) + except Exception: # noqa: BLE001 - never abort shutdown; already logged + logger.exception("flush_all_session_failed session_id=%s", session_id) + + async def flush(self) -> None: + """ADK lifecycle hook (Runner.close()): flushes all buffered sessions.""" + await self.flush_all() + + async def _flush(self, session_id: str, *, explicit: bool) -> None: + buffer = self._buffers.get(session_id) + if buffer is None: + return + + async with buffer.lock: + if buffer.flush_in_progress: + return # only one flush per session at a time + if not buffer.pending_events: + buffer.last_flush_monotonic = self._clock() + return + buffer.flush_in_progress = True + batch = list(buffer.pending_events) + buffer.pending_events.clear() + buffer.last_flush_monotonic = self._clock() + session = self._session_refs.get(session_id) + + if session is None: # pragma: no cover - defensive + async with buffer.lock: + buffer.pending_events.extendleft(reversed(batch)) + buffer.flush_in_progress = False + return + + try: + await self._persist_with_retry(session, batch, session_id) + except Exception as exc: # noqa: BLE001 - reclassified; never silently dropped + async with buffer.lock: + buffer.pending_events.extendleft(reversed(batch)) + buffer.flush_in_progress = False + if explicit: + raise SessionPersistenceError( + f"Failed to flush session {session_id} after retries." 
+ ) from exc + return + + async with buffer.lock: + buffer.flush_in_progress = False + + async def _persist_with_retry( + self, session: Session, batch: list[Event], session_id: str + ) -> None: + attempt = 0 + while True: + attempt += 1 + try: + await self._persist_batch(session, batch) + return + except Exception as exc: # noqa: BLE001 - retryable vs permanent + if not is_retryable_error(exc) or attempt >= self._max_retry_attempts: + logger.error( + "session_flush_failed session_id=%s events=%s attempt=%s" + " error=%s", + session_id, + len(batch), + attempt, + type(exc).__name__, + ) + raise + delay = self._retry_base_delay_seconds * (2 ** (attempt - 1)) + delay += random.uniform(0.0, self._retry_base_delay_seconds) + await self._sleeper(delay) + + async def _persist_batch(self, session: Session, events: list[Event]) -> None: + """Persists a batch of events for one session in a single transaction.""" + from google.cloud import firestore + + session_ref = self._get_sessions_ref( + session.app_name, session.user_id + ).document(session.id) + app_ref = self._app_state_ref(session.app_name) + user_ref = self._user_state_ref(session.app_name, session.user_id) + + agg: dict[str, dict[str, Any]] = {"app": {}, "user": {}, "session": {}} + for event in events: + delta = ( + event.actions.state_delta + if event.actions and event.actions.state_delta + else {} + ) + scoped = _session_util.extract_state_delta(delta) + agg["app"].update(scoped["app"]) + agg["user"].update(scoped["user"]) + agg["session"].update(scoped["session"]) + has_app, has_user = bool(agg["app"]), bool(agg["user"]) + + async def _append_txn(transaction: Any) -> int: + snap = await session_ref.get(transaction=transaction) + if not snap.exists: + raise ValueError(f"Session {session.id} not found.") + doc = snap.to_dict() or {} + if doc.get("status") == "DELETING": + raise ValueError(f"Session {session.id} is currently being deleted.") + current_revision = doc.get("revision", 0) + marker = getattr(session, 
"_storage_update_marker", None) + if marker is not None and marker != str(current_revision): + raise ValueError( + "The session has been modified in storage since it was loaded." + " Please reload the session before appending more events." + ) + + app_snap = await app_ref.get(transaction=transaction) if has_app else None + user_snap = ( + await user_ref.get(transaction=transaction) if has_user else None + ) + + if has_app: + current = app_snap.to_dict() if app_snap.exists else {} + current.update(agg["app"]) + transaction.set(app_ref, current, merge=True) + if has_user: + current = user_snap.to_dict() if user_snap.exists else {} + current.update(agg["user"]) + transaction.set(user_ref, current, merge=True) + for key, value in agg["session"].items(): + session.state[key] = value + + for event in events: + event_ref = session_ref.collection(self.events_collection).document( + event.id + ) + transaction.set( + event_ref, + { + "event_data": event.model_dump(exclude_none=True, mode="json"), + # The event's own timestamp (not SERVER_TIMESTAMP) so order is + # preserved within a batch that shares a commit time. 
+ "timestamp": datetime.fromtimestamp( + event.timestamp, tz=timezone.utc + ), + "appName": session.app_name, + "userId": session.user_id, + }, + ) + + new_revision = current_revision + 1 + session_only_state = { + k: v + for k, v in session.state.items() + if not k.startswith(State.APP_PREFIX) + and not k.startswith(State.USER_PREFIX) + and not k.startswith(State.TEMP_PREFIX) + } + transaction.update( + session_ref, + { + "state": session_only_state, + "updateTime": firestore.SERVER_TIMESTAMP, + "revision": new_revision, + }, + ) + return new_revision + + new_revision = await self._transactional(_append_txn)( + self.client.transaction() + ) + session._storage_update_marker = str(new_revision) + if events: + session.last_update_time = events[-1].timestamp + + # -- periodic flushing ----------------------------------------------------- + + async def start(self) -> None: + """Starts the background periodic-flush task (idempotent).""" + if self._task is not None and not self._task.done(): + return + self._task = asyncio.create_task(self._periodic_flush_loop()) + + async def stop(self) -> None: + """Stops the background task and performs a final flush (idempotent).""" + task = self._task + self._task = None + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + await self.flush_all() + + async def close(self) -> None: + """Closes the underlying Firestore AsyncClient.""" + closer = getattr(self.client, "close", None) + if closer is not None: + result = closer() + if asyncio.iscoroutine(result): + await result + + async def _periodic_flush_loop(self) -> None: + try: + while True: + await self._sleeper(self._check_interval) + await self._flush_due() + except asyncio.CancelledError: + raise + + async def _flush_due(self) -> list[asyncio.Task[None]]: + now = self._clock() + tasks: list[asyncio.Task[None]] = [] + for session_id, buffer in list(self._buffers.items()): + due = (now - buffer.last_flush_monotonic) >= 
self._flush_interval_seconds + if buffer.pending_events and due: + tasks.append( + asyncio.create_task(self._safe_background_flush(session_id)) + ) + return tasks + + async def _safe_background_flush(self, session_id: str) -> None: + try: + await self._flush(session_id, explicit=False) + except Exception: # noqa: BLE001 - background task must not raise unhandled + logger.exception("background_flush_failed session_id=%s", session_id) + + # -- internal helpers ------------------------------------------------------ + + async def _get_or_create_buffer(self, session: Session) -> _SessionBuffer: + async with self._buffers_guard: + buffer = self._buffers.get(session.id) + if buffer is None: + buffer = _SessionBuffer(last_flush_monotonic=self._clock()) + self._buffers[session.id] = buffer + self._session_refs[session.id] = session + return buffer + + def _merge_buffered(self, session: Session) -> Session: + buffer = self._buffers.get(session.id) + if buffer is None or not buffer.pending_events: + return session + seen = {e.id for e in session.events} + merged = list(session.events) + for event in list(buffer.pending_events): + if event.id not in seen: + merged.append(event) + seen.add(event.id) + merged.sort(key=lambda e: (e.timestamp or 0.0)) + session.events = merged + return session diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 00000000..0a7e2794 --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,600 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BufferedFirestoreSessionService. + +Uses an in-memory fake Firestore AsyncClient (no external services), a +deterministic clock, and a recording sleeper. The service's transactional +wrapper is replaced with an identity (or gated/flaky) runner so the fake +transaction is driven directly. +""" + +import asyncio + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig +from google.genai import types +import pytest + +from google.adk_community.sessions.firestore_session_service import BufferedFirestoreSessionService +from google.adk_community.sessions.firestore_session_service import SessionPersistenceError + +APP = "app" +USER = "user-1" +SID = "session-1" + + +# --- fake Firestore ---------------------------------------------------------- + + +class FakeSnapshot: + + def __init__(self, doc_id, data, ref): + self.id = doc_id + self._data = data + self.reference = ref + + @property + def exists(self): + return self._data is not None + + def to_dict(self): + return dict(self._data) if self._data is not None else None + + +class FakeDoc: + + def __init__(self, doc_id): + self.id = doc_id + self.data = None + self._subcollections = {} + self.reference = self + + async def get(self, transaction=None): + return FakeSnapshot(self.id, self.data, self) + + def collection(self, name): + return self._subcollections.setdefault(name, FakeCollection(name)) + + async def delete(self): + self.data = None + + def set(self, data, merge=False): + 
if merge and isinstance(self.data, dict): + merged = dict(self.data) + merged.update(data) + self.data = merged + else: + self.data = dict(data) + + def update(self, data): + self.data = {**(self.data or {}), **data} + + +def _match(actual, op, value): + if actual is None: + return False + if op == "==": + return actual == value + if op == ">=": + return actual >= value + return False + + +class FakeQuery: + + def __init__(self, docs): + self._docs = docs + self._order = None + self._filters = [] + self._limit_last = None + + def order_by(self, field): + self._order = field + return self + + def where(self, field, op, value): + self._filters.append((field, op, value)) + return self + + def limit_to_last(self, n): + self._limit_last = n + return self + + async def get(self): + rows = [d for d in self._docs if d.data is not None] + for field, op, value in self._filters: + rows = [d for d in rows if _match(d.data.get(field), op, value)] + if self._order: + rows = sorted(rows, key=lambda d: d.data.get(self._order)) + if self._limit_last is not None: + rows = rows[-self._limit_last :] + return [FakeSnapshot(d.id, d.data, d) for d in rows] + + +class FakeCollection: + + def __init__(self, name): + self.name = name + self.docs = {} + + def document(self, doc_id): + if doc_id not in self.docs: + self.docs[doc_id] = FakeDoc(doc_id) + return self.docs[doc_id] + + def order_by(self, field): + return FakeQuery(list(self.docs.values())).order_by(field) + + def where(self, field, op, value): + return FakeQuery(list(self.docs.values())).where(field, op, value) + + async def get(self): + return await FakeQuery(list(self.docs.values())).get() + + async def stream(self): + for d in list(self.docs.values()): + if d.data is not None: + yield FakeSnapshot(d.id, d.data, d) + + +class FakeTransaction: + + def set(self, ref, data, merge=False): + ref.set(data, merge=merge) + + def update(self, ref, data): + ref.update(data) + + +class FakeBatch: + + def __init__(self): + self._ops = [] + 
+ def delete(self, ref): + self._ops.append(ref) + + async def commit(self): + for ref in self._ops: + ref.data = None + self._ops = [] + + +class FakeFirestore: + + def __init__(self): + self.collections = {} + self.transaction_count = 0 + + def collection(self, name): + return self.collections.setdefault(name, FakeCollection(name)) + + def collection_group(self, name): + return FakeQuery(self._gather_group(name)) + + def transaction(self): + self.transaction_count += 1 + return FakeTransaction() + + def batch(self): + return FakeBatch() + + def _gather_group(self, name): + result = [] + + def walk(coll): + for doc in coll.docs.values(): + for sub_name, sub in doc._subcollections.items(): + if sub_name == name: + result.extend(d for d in sub.docs.values() if d.data is not None) + walk(sub) + + for coll in self.collections.values(): + if coll.name == name: + result.extend(d for d in coll.docs.values() if d.data is not None) + walk(coll) + return result + + +# --- helpers ----------------------------------------------------------------- + + +class Clock: + + def __init__(self, start=1000.0): + self.now = start + + def __call__(self): + return self.now + + def advance(self, seconds): + self.now += seconds + + +class RecordingSleeper: + + def __init__(self): + self.delays = [] + + async def __call__(self, delay): + self.delays.append(delay) + + +class Aborted(Exception): + """Name matches the retryable allowlist.""" + + +def _identity_transactional(fn): + + async def run(transaction): + return await fn(transaction) + + return run + + +def _make(**kwargs): + client = FakeFirestore() + clock = Clock() + sleeper = RecordingSleeper() + service = BufferedFirestoreSessionService( + client, clock=clock, sleeper=sleeper, **kwargs + ) + service._transactional = _identity_transactional + return service, client, clock, sleeper + + +def _event(author, text, timestamp, *, state_delta=None): + return Event( + invocation_id=f"inv-{timestamp}", + author=author, + timestamp=timestamp, 
+ content=types.Content( + role="user" if author == "user" else "model", + parts=[types.Part(text=text)], + ), + actions=EventActions(state_delta=state_delta or {}), + ) + + +def _session_doc(client, session_id=SID): + return ( + client.collection("adk-session") + .document(APP) + .collection("users") + .document(USER) + .collection("sessions") + .document(session_id) + ) + + +def _persisted_event_count(client, session_id=SID): + events = _session_doc(client, session_id)._subcollections.get("events") + if events is None: + return 0 + return sum(1 for d in events.docs.values() if d.data is not None) + + +# --- tests ------------------------------------------------------------------- + + +async def test_create_session_writes_metadata(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + assert session.id == SID + doc = _session_doc(client).data + assert doc["appName"] == APP + assert doc["userId"] == USER + assert doc["revision"] == 1 + + +async def test_buffered_append_defers_persistence(): + service, client, *_ = _make(buffer_max_events=10) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + base = client.transaction_count + for i in range(9): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert _persisted_event_count(client) == 0 + assert client.transaction_count == base + + +async def test_flush_persists_whole_batch_in_one_transaction(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(9): + await service.append_event(session, _event("user", f"m{i}", float(i))) + before = client.transaction_count + await service.flush_session(SID) + assert client.transaction_count - before == 1 + assert _persisted_event_count(client) == 9 + assert _session_doc(client).data["revision"] == 2 + + +async def test_reaching_max_events_auto_flushes(): + 
service, client, *_ = _make(buffer_max_events=10) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(10): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert _persisted_event_count(client) == 10 + + +async def test_durable_mode_writes_each_event_immediately(): + service, client, *_ = _make(durable_mode=True) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + base = client.transaction_count + for i in range(3): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert client.transaction_count - base == 3 + assert _persisted_event_count(client) == 3 + assert SID not in service._buffers + + +async def test_periodic_flush_after_interval(): + service, client, clock, _ = _make(flush_interval_seconds=120.0) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(3): + await service.append_event(session, _event("user", f"m{i}", float(i))) + assert _persisted_event_count(client) == 0 + clock.advance(121.0) + await asyncio.gather(*await service._flush_due()) + assert _persisted_event_count(client) == 3 + + +async def test_flush_hook_and_stop_final_flush(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event(session, _event("user", "a", 1.0)) + await service.flush() # ADK Runner.close() hook + assert _persisted_event_count(client) == 1 + + +async def test_get_session_merges_and_orders_without_duplicates(): + service, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event(session, _event("user", "persisted", 1.0)) + await service.flush_session(SID) + await service.append_event(session, _event("user", "buffered", 2.0)) + loaded = await service.get_session(app_name=APP, user_id=USER, session_id=SID) + 
texts = [e.content.parts[0].text for e in loaded.events] + assert texts == ["persisted", "buffered"] + assert len(texts) == len({e.id for e in loaded.events}) + + +async def test_state_delta_scoping(): + service, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event( + session, + _event( + "user", + "a", + 1.0, + state_delta={ + "app:shared": "yes", + "user:goal": "fat loss", + "sessionOnly": "kept", + "temp:scratch": "discard", + }, + ), + ) + await service.flush_session(SID) + loaded = await service.get_session(app_name=APP, user_id=USER, session_id=SID) + assert loaded.state["app:shared"] == "yes" + assert loaded.state["user:goal"] == "fat loss" + assert loaded.state["sessionOnly"] == "kept" + assert "temp:scratch" not in loaded.state + + +async def test_get_user_state(): + service, *_ = _make() + await service.create_session( + app_name=APP, user_id=USER, session_id=SID, state={"user:goal": "lose"} + ) + state = await service.get_user_state(app_name=APP, user_id=USER) + assert state == {"goal": "lose"} + + +async def test_retryable_failures_backoff_then_succeed(): + service, client, _, sleeper = _make(max_retry_attempts=5) + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + errors = [Aborted(), Aborted()] + + def flaky(fn): + + async def run(transaction): + if errors: + raise errors.pop(0) + return await fn(transaction) + + return run + + service._transactional = flaky + await service.append_event(session, _event("user", "a", 1.0)) + await service.flush_session(SID) + assert _persisted_event_count(client) == 1 + assert len(sleeper.delays) == 2 + assert sleeper.delays[1] > sleeper.delays[0] + + +async def test_permanent_failure_not_retried(): + service, _, _, sleeper = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + + def boom(fn): + + async def run(transaction): + raise 
ValueError("permission denied") + + return run + + service._transactional = boom + await service.append_event(session, _event("user", "a", 1.0)) + with pytest.raises(SessionPersistenceError): + await service.flush_session(SID) + assert sleeper.delays == [] + + +async def test_events_appended_during_flush_not_lost(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + gate = asyncio.Event() + entered = asyncio.Event() + + def gated(fn): + + async def run(transaction): + entered.set() + await gate.wait() + return await fn(transaction) + + return run + + await service.append_event(session, _event("user", "a", 1.0)) + await service.append_event(session, _event("user", "b", 2.0)) + service._transactional = gated + + flush_task = asyncio.create_task(service.flush_session(SID)) + await entered.wait() + await service.append_event(session, _event("user", "c", 3.0)) + gate.set() + await flush_task + + pending = service._buffers[SID].pending_events + assert [e.content.parts[0].text for e in pending] == ["c"] + assert _persisted_event_count(client) == 2 + + +async def test_concurrent_flushes_do_not_duplicate(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + gate = asyncio.Event() + entered = asyncio.Event() + + def gated(fn): + + async def run(transaction): + entered.set() + await gate.wait() + return await fn(transaction) + + return run + + await service.append_event(session, _event("user", "a", 1.0)) + await service.append_event(session, _event("user", "b", 2.0)) + service._transactional = gated + before = client.transaction_count + + t1 = asyncio.create_task(service.flush_session(SID)) + await entered.wait() + t2 = asyncio.create_task(service.flush_session(SID)) + gate.set() + await asyncio.gather(t1, t2) + + assert client.transaction_count - before == 1 + assert _persisted_event_count(client) == 2 + + +async def 
test_delete_session_removes_events_and_buffer(): + service, client, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + await service.append_event(session, _event("user", "a", 1.0)) + await service.flush_session(SID) + await service.delete_session(app_name=APP, user_id=USER, session_id=SID) + assert _session_doc(client).data is None + assert _persisted_event_count(client) == 0 + assert SID not in service._buffers + + +async def test_get_session_not_found_returns_none(): + service, *_ = _make() + result = await service.get_session( + app_name=APP, user_id=USER, session_id="missing" + ) + assert result is None + + +async def test_list_sessions(): + service, *_ = _make() + await service.create_session(app_name=APP, user_id=USER, session_id="s1") + await service.create_session(app_name=APP, user_id=USER, session_id="s2") + per_user = await service.list_sessions(app_name=APP, user_id=USER) + all_users = await service.list_sessions(app_name=APP) + assert {s.id for s in per_user.sessions} == {"s1", "s2"} + assert {s.id for s in all_users.sessions} == {"s1", "s2"} + + +async def test_get_session_with_config_num_recent_events(): + service, *_ = _make() + session = await service.create_session( + app_name=APP, user_id=USER, session_id=SID + ) + for i in range(3): + await service.append_event(session, _event("user", f"m{i}", float(i))) + await service.flush_session(SID) + loaded = await service.get_session( + app_name=APP, + user_id=USER, + session_id=SID, + config=GetSessionConfig(num_recent_events=2), + ) + assert [e.content.parts[0].text for e in loaded.events] == ["m1", "m2"] + + +async def test_create_session_duplicate_raises(): + from google.adk.errors.already_exists_error import AlreadyExistsError + + service, *_ = _make() + await service.create_session(app_name=APP, user_id=USER, session_id=SID) + with pytest.raises(AlreadyExistsError): + await service.create_session(app_name=APP, user_id=USER, session_id=SID) + + 
+async def test_start_stop_cancellation_is_clean(): + service, *_ = _make() + service._sleeper = asyncio.sleep # real sleep so the loop blocks + await service.start() + await service.start() # idempotent + task = service._task + await service.stop() + await service.stop() # idempotent + assert task.cancelled() or task.done() From 761842e32c7be8fc406bebabc976a3cec9bf7c72 Mon Sep 17 00:00:00 2001 From: enesdemirag Date: Sat, 20 Jun 2026 00:15:20 +0300 Subject: [PATCH 2/4] refactor(sessions): clean up and add configurable collection names - Remove unused _SessionLockKey type alias - Move import copy to module level (stdlib, no reason to lazy-import) - Store self._firestore in __init__ to avoid repeated guarded imports inside create_session / _persist_batch - Add sessions_collection, events_collection, app_state_collection, and user_state_collection constructor params (keyword-only, with defaults) so developers can customise the Firestore collection layout without subclassing --- .../sessions/firestore_session_service.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/google/adk_community/sessions/firestore_session_service.py b/src/google/adk_community/sessions/firestore_session_service.py index cce6937e..03720145 100644 --- a/src/google/adk_community/sessions/firestore_session_service.py +++ b/src/google/adk_community/sessions/firestore_session_service.py @@ -47,6 +47,7 @@ from collections import deque from collections.abc import Awaitable from collections.abc import Callable +import copy from dataclasses import dataclass from dataclasses import field from datetime import datetime @@ -76,8 +77,6 @@ DEFAULT_APP_STATE_COLLECTION = "app_states" DEFAULT_USER_STATE_COLLECTION = "user_states" -_SessionLockKey = tuple[str, str, str] - # Transient Firestore / gRPC failures worth retrying. Matched by class name to # avoid a hard dependency on google.api_core being importable everywhere. 
_RETRYABLE_ERROR_NAMES = frozenset({ @@ -141,6 +140,10 @@ def __init__( client: Any = None, root_collection: Optional[str] = None, *, + sessions_collection: str = DEFAULT_SESSIONS_COLLECTION, + events_collection: str = DEFAULT_EVENTS_COLLECTION, + app_state_collection: str = DEFAULT_APP_STATE_COLLECTION, + user_state_collection: str = DEFAULT_USER_STATE_COLLECTION, durable_mode: bool = False, buffer_max_events: int = 10, flush_interval_seconds: float = 120.0, @@ -155,6 +158,14 @@ def __init__( client: An optional Firestore ``AsyncClient``. If not provided, a new one is created (requires ``google-cloud-firestore``). root_collection: Root collection name. Defaults to ``'adk-session'``. + sessions_collection: Subcollection name for sessions. Defaults to + ``'sessions'``. + events_collection: Subcollection name for events. Defaults to + ``'events'``. + app_state_collection: Root collection for app-scoped state. Defaults to + ``'app_states'``. + user_state_collection: Root collection for user-scoped state. Defaults + to ``'user_states'``. durable_mode: When True, every event is persisted immediately and no buffering happens. buffer_max_events: Flush a session once this many events are buffered. 
@@ -172,12 +183,13 @@ def __init__( " Install it with: pip install google-adk-community[firestore]" ) from e + self._firestore = firestore self.client = client if client is not None else firestore.AsyncClient() self.root_collection = root_collection or DEFAULT_ROOT_COLLECTION - self.sessions_collection = DEFAULT_SESSIONS_COLLECTION - self.events_collection = DEFAULT_EVENTS_COLLECTION - self.app_state_collection = DEFAULT_APP_STATE_COLLECTION - self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + self.sessions_collection = sessions_collection + self.events_collection = events_collection + self.app_state_collection = app_state_collection + self.user_state_collection = user_state_collection self._durable_mode = durable_mode self._buffer_max_events = buffer_max_events @@ -224,8 +236,6 @@ def _merge_state( user_state: dict[str, Any], session_state: dict[str, Any], ) -> dict[str, Any]: - import copy - merged = copy.deepcopy(session_state) for key, value in app_state.items(): merged[State.APP_PREFIX + key] = value @@ -258,8 +268,6 @@ async def create_session( session_id: Optional[str] = None, ) -> Session: """Creates a new session (raises AlreadyExistsError on a duplicate id).""" - from google.cloud import firestore - session_id = session_id or str(uuid.uuid4()) deltas = _session_util.extract_state_delta(state or {}) session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) @@ -270,8 +278,8 @@ async def create_session( "appName": app_name, "userId": user_id, "state": deltas["session"], - "createTime": firestore.SERVER_TIMESTAMP, - "updateTime": firestore.SERVER_TIMESTAMP, + "createTime": self._firestore.SERVER_TIMESTAMP, + "updateTime": self._firestore.SERVER_TIMESTAMP, "revision": 1, } @@ -546,8 +554,6 @@ async def _persist_with_retry( async def _persist_batch(self, session: Session, events: list[Event]) -> None: """Persists a batch of events for one session in a single transaction.""" - from google.cloud import firestore - session_ref = 
self._get_sessions_ref( session.app_name, session.user_id ).document(session.id) @@ -628,7 +634,7 @@ async def _append_txn(transaction: Any) -> int: session_ref, { "state": session_only_state, - "updateTime": firestore.SERVER_TIMESTAMP, + "updateTime": self._firestore.SERVER_TIMESTAMP, "revision": new_revision, }, ) From dff032f8e072cee54f1f814ea8ba51172e1411ea Mon Sep 17 00:00:00 2001 From: enesdemirag Date: Sat, 20 Jun 2026 00:25:41 +0300 Subject: [PATCH 3/4] feat(sessions): add flat_layout option for configurable collection hierarchy Add flat_layout=True constructor parameter so developers can store sessions directly in root_collection/{session_id} instead of the default nested root/{app}/users/{user}/sessions/{session_id} path. Useful when the session id already encodes the user (e.g. {phone}-{date}), matching an existing flat Firestore collection layout. list_sessions adds a userId field filter automatically in flat mode. --- .../sessions/firestore_session_service.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/google/adk_community/sessions/firestore_session_service.py b/src/google/adk_community/sessions/firestore_session_service.py index 03720145..cc485416 100644 --- a/src/google/adk_community/sessions/firestore_session_service.py +++ b/src/google/adk_community/sessions/firestore_session_service.py @@ -144,6 +144,7 @@ def __init__( events_collection: str = DEFAULT_EVENTS_COLLECTION, app_state_collection: str = DEFAULT_APP_STATE_COLLECTION, user_state_collection: str = DEFAULT_USER_STATE_COLLECTION, + flat_layout: bool = False, durable_mode: bool = False, buffer_max_events: int = 10, flush_interval_seconds: float = 120.0, @@ -166,6 +167,10 @@ def __init__( ``'app_states'``. user_state_collection: Root collection for user-scoped state. Defaults to ``'user_states'``. + flat_layout: When True, session documents live directly in + ``root_collection/{session_id}`` (no ``{app}/users/{user}/sessions/`` + nesting). 
Useful when the session id already encodes the user (e.g. + ``{phone}-{date}``). Defaults to False. durable_mode: When True, every event is persisted immediately and no buffering happens. buffer_max_events: Flush a session once this many events are buffered. @@ -190,6 +195,9 @@ def __init__( self.events_collection = events_collection self.app_state_collection = app_state_collection self.user_state_collection = user_state_collection + # flat_layout=True: sessions/{session_id} (no {app}/users/{user} nesting) + # flat_layout=False (default): {root}/{app}/users/{user}/{sessions}/{session_id} + self._flat_layout = flat_layout self._durable_mode = durable_mode self._buffer_max_events = buffer_max_events @@ -211,6 +219,8 @@ def __init__( # -- Firestore refs / helpers --------------------------------------------- def _get_sessions_ref(self, app_name: str, user_id: str) -> Any: + if self._flat_layout: + return self.client.collection(self.root_collection) return ( self.client.collection(self.root_collection) .document(app_name) @@ -371,7 +381,14 @@ async def list_sessions( self, *, app_name: str, user_id: Optional[str] = None ) -> ListSessionsResponse: """Lists sessions for an app (optionally a single user).""" - if user_id: + if self._flat_layout: + query = self.client.collection(self.root_collection).where( + "appName", "==", app_name + ) + if user_id: + query = query.where("userId", "==", user_id) + docs = await query.get() + elif user_id: docs = await ( self._get_sessions_ref(app_name, user_id) .where("appName", "==", app_name) From c411073ea1e53e663d56eeb20290dbf6eede146b Mon Sep 17 00:00:00 2001 From: enesdemirag Date: Sat, 20 Jun 2026 00:33:34 +0300 Subject: [PATCH 4/4] style: remove __future__ annotations, clean up comments, align with pyink - Drop `from __future__ import annotations`; use X | None syntax directly (requires Python >=3.10, already in project metadata) - Remove Optional import; all annotations now use built-in union syntax - Remove vague section-header 
comments - Simplify is_retryable_error return (single return name in _RETRYABLE_ERROR_NAMES) - Update firestore extra version range to match ADK's own constraint (>=2.11,<3) --- pyproject.toml | 2 +- .../sessions/firestore_session_service.py | 170 ++++++++---------- 2 files changed, 75 insertions(+), 97 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3f7ddd53..0ddac04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] firestore = [ - "google-cloud-firestore>=2.11.0, <3.0.0", # For BufferedFirestoreSessionService + "google-cloud-firestore>=2.11,<3", ] s3 = [ "aioboto3>=13.0.0", # For S3ArtifactService diff --git a/src/google/adk_community/sessions/firestore_session_service.py b/src/google/adk_community/sessions/firestore_session_service.py index cc485416..ec15d96b 100644 --- a/src/google/adk_community/sessions/firestore_session_service.py +++ b/src/google/adk_community/sessions/firestore_session_service.py @@ -18,31 +18,28 @@ ``google.adk.integrations.firestore.FirestoreSessionService`` (same collection hierarchy, app/user/session state scoping, optimistic concurrency via a ``revision`` field, and idempotent event documents keyed by ``event.id``) but -**owns** the Firestore I/O so it can persist a whole batch of buffered events in -a **single transaction**. +**owns** the Firestore I/O so it can persist a whole batch of buffered events +in a **single transaction**. -Collection hierarchy:: +Collection hierarchy (matches the ADK builtin):: adk-session/{app}/users/{user}/sessions/{session}/events/{event} app_states/{app} user_states/{app}/users/{user} Events accumulate in a per-session in-memory buffer and flush when the buffer -reaches ``buffer_max_events``, when ``flush_interval_seconds`` elapses (the -background task started by :meth:`start`), when ``flush_session`` / ``flush_all`` -/ ``flush`` is called, or when :meth:`stop` runs. 
Set ``durable_mode=True`` to -persist every event immediately (no buffering). - -Batching does not change the event-document count, but it collapses the repeated -session-doc + state-doc updates and per-event transactions from N to 1 (fewer -round-trips and less optimistic-lock contention). On an abrupt process death -before a flush, up to ``flush_interval_seconds`` of events (or -``buffer_max_events - 1`` per session) may be lost; ``stop()`` flushes on -graceful shutdown but cannot protect against crashes. +reaches ``buffer_max_events``, when ``flush_interval_seconds`` elapses (via +the background task started by :meth:`start`), when ``flush_session`` / +``flush_all`` / ``flush`` is called, or when :meth:`stop` runs. Set +``durable_mode=True`` to persist every event immediately (no buffering). + +Batching collapses the repeated session-doc + state-doc updates and per-event +transactions from N to 1 (fewer round-trips, less optimistic-lock contention). +On an abrupt process death before a flush, up to ``flush_interval_seconds`` of +events (or ``buffer_max_events - 1`` per session) may be lost; :meth:`stop` +flushes on graceful shutdown. """ -from __future__ import annotations - import asyncio from collections import deque from collections.abc import Awaitable @@ -56,7 +53,6 @@ import random import time from typing import Any -from typing import Optional import uuid from google.adk.errors.already_exists_error import AlreadyExistsError @@ -77,8 +73,6 @@ DEFAULT_APP_STATE_COLLECTION = "app_states" DEFAULT_USER_STATE_COLLECTION = "user_states" -# Transient Firestore / gRPC failures worth retrying. Matched by class name to -# avoid a hard dependency on google.api_core being importable everywhere. 
_RETRYABLE_ERROR_NAMES = frozenset({ "DeadlineExceeded", "ServiceUnavailable", @@ -117,9 +111,7 @@ def is_retryable_error(exc: BaseException) -> bool: name = type(exc).__name__ if name in _NON_RETRYABLE_ERROR_NAMES: return False - if name in _RETRYABLE_ERROR_NAMES: - return True - return False + return name in _RETRYABLE_ERROR_NAMES @dataclass @@ -138,7 +130,7 @@ class BufferedFirestoreSessionService(BaseSessionService): # type: ignore[misc] def __init__( self, client: Any = None, - root_collection: Optional[str] = None, + root_collection: str | None = None, *, sessions_collection: str = DEFAULT_SESSIONS_COLLECTION, events_collection: str = DEFAULT_EVENTS_COLLECTION, @@ -156,26 +148,25 @@ def __init__( """Initializes the buffered Firestore session service. Args: - client: An optional Firestore ``AsyncClient``. If not provided, a new one - is created (requires ``google-cloud-firestore``). + client: An optional Firestore ``AsyncClient``. If not provided, a new + one is created (requires ``google-cloud-firestore``). root_collection: Root collection name. Defaults to ``'adk-session'``. - sessions_collection: Subcollection name for sessions. Defaults to + sessions_collection: Sessions subcollection name. Defaults to ``'sessions'``. - events_collection: Subcollection name for events. Defaults to - ``'events'``. - app_state_collection: Root collection for app-scoped state. Defaults to + events_collection: Events subcollection name. Defaults to ``'events'``. + app_state_collection: Collection for app-scoped state. Defaults to ``'app_states'``. - user_state_collection: Root collection for user-scoped state. Defaults - to ``'user_states'``. - flat_layout: When True, session documents live directly in - ``root_collection/{session_id}`` (no ``{app}/users/{user}/sessions/`` - nesting). Useful when the session id already encodes the user (e.g. - ``{phone}-{date}``). Defaults to False. - durable_mode: When True, every event is persisted immediately and no - buffering happens. 
- buffer_max_events: Flush a session once this many events are buffered. + user_state_collection: Collection for user-scoped state. Defaults to + ``'user_states'``. + flat_layout: When ``True``, session documents are stored directly at + ``root_collection/{session_id}`` instead of the default nested ADK + path. Useful when the session id already encodes the user (e.g. + ``{phone}-{date}``) or to match an existing flat collection. + durable_mode: When ``True``, every event is persisted immediately (no + buffering). Equivalent to the builtin service behaviour. + buffer_max_events: Flush when this many events are buffered per session. flush_interval_seconds: Background flush cadence (see :meth:`start`). - max_retry_attempts: Max attempts when a flush hits a retryable error. + max_retry_attempts: Max attempts on a retryable Firestore error. retry_base_delay_seconds: Base delay for exponential backoff with jitter. clock: Monotonic clock, injectable for tests. sleeper: Async sleep function, injectable for tests. @@ -195,10 +186,7 @@ def __init__( self.events_collection = events_collection self.app_state_collection = app_state_collection self.user_state_collection = user_state_collection - # flat_layout=True: sessions/{session_id} (no {app}/users/{user} nesting) - # flat_layout=False (default): {root}/{app}/users/{user}/{sessions}/{session_id} self._flat_layout = flat_layout - self._durable_mode = durable_mode self._buffer_max_events = buffer_max_events self._flush_interval_seconds = flush_interval_seconds @@ -206,18 +194,14 @@ def __init__( self._retry_base_delay_seconds = retry_base_delay_seconds self._clock = clock self._sleeper = sleeper - # Injectable so tests can drive a fake client without the real transactional - # retry wrapper. 
self._transactional = firestore.async_transactional self._buffers: dict[str, _SessionBuffer] = {} self._session_refs: dict[str, Session] = {} self._buffers_guard = asyncio.Lock() - self._task: Optional[asyncio.Task[None]] = None + self._task: asyncio.Task[None] | None = None self._check_interval = max(1.0, min(flush_interval_seconds, 5.0)) - # -- Firestore refs / helpers --------------------------------------------- - def _get_sessions_ref(self, app_name: str, user_id: str) -> Any: if self._flat_layout: return self.client.collection(self.root_collection) @@ -266,16 +250,14 @@ def _coerce_timestamp(value: Any) -> float: except (ValueError, TypeError): return 0.0 - # -- CRUD ------------------------------------------------------------------ - @override async def create_session( self, *, app_name: str, user_id: str, - state: Optional[dict[str, Any]] = None, - session_id: Optional[str] = None, + state: dict[str, Any] | None = None, + session_id: str | None = None, ) -> Session: """Creates a new session (raises AlreadyExistsError on a duplicate id).""" session_id = session_id or str(uuid.uuid4()) @@ -334,8 +316,8 @@ async def get_session( app_name: str, user_id: str, session_id: str, - config: Optional[GetSessionConfig] = None, - ) -> Optional[Session]: + config: GetSessionConfig | None = None, + ) -> Session | None: """Gets a session, merging persisted and not-yet-flushed buffered events.""" session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) doc = await session_ref.get() @@ -378,7 +360,7 @@ async def get_session( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str | None = None ) -> ListSessionsResponse: """Lists sessions for an app (optionally a single user).""" if self._flat_layout: @@ -468,8 +450,6 @@ async def get_user_state( """Returns the raw (un-prefixed) user-scoped state for an app/user.""" return dict(await self._read_state(self._user_state_ref(app_name, 
user_id))) - # -- buffered append ------------------------------------------------------- - @override async def append_event(self, session: Session, event: Event) -> Event: """Appends an event in memory and buffers (or immediately persists) it.""" @@ -500,13 +480,39 @@ async def flush_all(self) -> None: for session_id in list(self._buffers.keys()): try: await self._flush(session_id, explicit=False) - except Exception: # noqa: BLE001 - never abort shutdown; already logged + except Exception: # noqa: BLE001 logger.exception("flush_all_session_failed session_id=%s", session_id) async def flush(self) -> None: """ADK lifecycle hook (Runner.close()): flushes all buffered sessions.""" await self.flush_all() + async def start(self) -> None: + """Starts the background periodic-flush task (idempotent).""" + if self._task is not None and not self._task.done(): + return + self._task = asyncio.create_task(self._periodic_flush_loop()) + + async def stop(self) -> None: + """Stops the background task and performs a final flush (idempotent).""" + task = self._task + self._task = None + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + await self.flush_all() + + async def close(self) -> None: + """Closes the underlying Firestore AsyncClient.""" + closer = getattr(self.client, "close", None) + if closer is not None: + result = closer() + if asyncio.iscoroutine(result): + await result + async def _flush(self, session_id: str, *, explicit: bool) -> None: buffer = self._buffers.get(session_id) if buffer is None: @@ -514,7 +520,7 @@ async def _flush(self, session_id: str, *, explicit: bool) -> None: async with buffer.lock: if buffer.flush_in_progress: - return # only one flush per session at a time + return if not buffer.pending_events: buffer.last_flush_monotonic = self._clock() return @@ -524,7 +530,7 @@ async def _flush(self, session_id: str, *, explicit: bool) -> None: buffer.last_flush_monotonic = self._clock() session = 
self._session_refs.get(session_id) - if session is None: # pragma: no cover - defensive + if session is None: # pragma: no cover async with buffer.lock: buffer.pending_events.extendleft(reversed(batch)) buffer.flush_in_progress = False @@ -532,7 +538,7 @@ async def _flush(self, session_id: str, *, explicit: bool) -> None: try: await self._persist_with_retry(session, batch, session_id) - except Exception as exc: # noqa: BLE001 - reclassified; never silently dropped + except Exception as exc: # noqa: BLE001 async with buffer.lock: buffer.pending_events.extendleft(reversed(batch)) buffer.flush_in_progress = False @@ -554,7 +560,7 @@ async def _persist_with_retry( try: await self._persist_batch(session, batch) return - except Exception as exc: # noqa: BLE001 - retryable vs permanent + except Exception as exc: # noqa: BLE001 if not is_retryable_error(exc) or attempt >= self._max_retry_attempts: logger.error( "session_flush_failed session_id=%s events=%s attempt=%s" @@ -625,12 +631,11 @@ async def _append_txn(transaction: Any) -> int: event_ref = session_ref.collection(self.events_collection).document( event.id ) + # Use event's own timestamp so intra-batch order survives a shared commit time. transaction.set( event_ref, { "event_data": event.model_dump(exclude_none=True, mode="json"), - # The event's own timestamp (not SERVER_TIMESTAMP) so order is - # preserved within a batch that shares a commit time. 
"timestamp": datetime.fromtimestamp( event.timestamp, tz=timezone.utc ), @@ -664,34 +669,6 @@ async def _append_txn(transaction: Any) -> int: if events: session.last_update_time = events[-1].timestamp - # -- periodic flushing ----------------------------------------------------- - - async def start(self) -> None: - """Starts the background periodic-flush task (idempotent).""" - if self._task is not None and not self._task.done(): - return - self._task = asyncio.create_task(self._periodic_flush_loop()) - - async def stop(self) -> None: - """Stops the background task and performs a final flush (idempotent).""" - task = self._task - self._task = None - if task is not None: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - await self.flush_all() - - async def close(self) -> None: - """Closes the underlying Firestore AsyncClient.""" - closer = getattr(self.client, "close", None) - if closer is not None: - result = closer() - if asyncio.iscoroutine(result): - await result - async def _periodic_flush_loop(self) -> None: try: while True: @@ -704,8 +681,11 @@ async def _flush_due(self) -> list[asyncio.Task[None]]: now = self._clock() tasks: list[asyncio.Task[None]] = [] for session_id, buffer in list(self._buffers.items()): - due = (now - buffer.last_flush_monotonic) >= self._flush_interval_seconds - if buffer.pending_events and due: + if ( + buffer.pending_events + and (now - buffer.last_flush_monotonic) + >= self._flush_interval_seconds + ): tasks.append( asyncio.create_task(self._safe_background_flush(session_id)) ) @@ -714,11 +694,9 @@ async def _flush_due(self) -> list[asyncio.Task[None]]: async def _safe_background_flush(self, session_id: str) -> None: try: await self._flush(session_id, explicit=False) - except Exception: # noqa: BLE001 - background task must not raise unhandled + except Exception: # noqa: BLE001 logger.exception("background_flush_failed session_id=%s", session_id) - # -- internal helpers 
------------------------------------------------------ - async def _get_or_create_buffer(self, session: Session) -> _SessionBuffer: async with self._buffers_guard: buffer = self._buffers.get(session.id)