From 9db6db55904b8008b63d560c378e376c2bda4d9c Mon Sep 17 00:00:00 2001 From: Alpha Du Date: Mon, 11 May 2026 09:08:01 +0000 Subject: [PATCH 1/6] refactor to simplify the firestore session service implementation --- .../firestore/firestore_session_service.py | 186 +++++++++--------- .../test_firestore_session_service.py | 10 +- 2 files changed, 104 insertions(+), 92 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 83b97c33c2..098eafed95 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -15,13 +15,11 @@ from __future__ import annotations import asyncio -from contextlib import asynccontextmanager from datetime import datetime from datetime import timezone import logging import os from typing import Any -from typing import AsyncIterator from typing import cast from typing import Optional from typing import TYPE_CHECKING @@ -50,6 +48,40 @@ DEFAULT_USER_STATE_COLLECTION = "user_states" +class _SessionLockContext: + """Async context manager for serializing event appends for the same session.""" + + def __init__( + self, service: FirestoreSessionService, lock_key: _SessionLockKey + ): + self.service = service + self.lock_key = lock_key + self.lock: Optional[asyncio.Lock] = None + + async def __aenter__(self) -> None: + async with self.service._session_locks_guard: + lock = self.service._session_locks.get(self.lock_key, asyncio.Lock()) + self.service._session_locks[self.lock_key] = lock + self.service._session_lock_ref_count[self.lock_key] = ( + self.service._session_lock_ref_count.get(self.lock_key, 0) + 1 + ) + self.lock = lock + await self.lock.acquire() + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self.lock: + self.lock.release() + async with self.service._session_locks_guard: + remaining = ( + self.service._session_lock_ref_count.get(self.lock_key, 0) - 1 + ) + if remaining <= 0 and not self.lock.locked(): + self.service._session_lock_ref_count.pop(self.lock_key, None) + self.service._session_locks.pop(self.lock_key, None) + else: + self.service._session_lock_ref_count[self.lock_key] = remaining + + class FirestoreSessionService(BaseSessionService): # type: ignore[misc] """Session service that uses Google Cloud Firestore as the backend. @@ -110,30 +142,11 @@ def __init__( self.app_state_collection = DEFAULT_APP_STATE_COLLECTION self.user_state_collection = DEFAULT_USER_STATE_COLLECTION - @asynccontextmanager - async def _with_session_lock( + def _with_session_lock( self, *, app_name: str, user_id: str, session_id: str - ) -> AsyncIterator[None]: + ) -> _SessionLockContext: """Serializes event appends for the same session within this process.""" - lock_key = (app_name, user_id, session_id) - async with self._session_locks_guard: - lock = self._session_locks.get(lock_key, asyncio.Lock()) - self._session_locks[lock_key] = lock - self._session_lock_ref_count[lock_key] = ( - self._session_lock_ref_count.get(lock_key, 0) + 1 - ) - - try: - async with lock: - yield - finally: - async with self._session_locks_guard: - remaining = self._session_lock_ref_count.get(lock_key, 0) - 1 - if remaining <= 0 and not lock.locked(): - self._session_lock_ref_count.pop(lock_key, None) - self._session_locks.pop(lock_key, None) - else: - self._session_lock_ref_count[lock_key] = remaining + return _SessionLockContext(self, (app_name, user_id, session_id)) @staticmethod def _merge_state( @@ -210,7 +223,9 @@ async def create_session( } @firestore.async_transactional # type: ignore[untyped-decorator] - async def _create_txn(transaction: firestore.AsyncTransaction) -> None: + async def _create_txn( + transaction: firestore.AsyncTransaction, + ) -> tuple[dict[str, Any], dict[str, Any]]: # 1. Reads snap = await session_ref.get(transaction=transaction) if snap.exists: @@ -218,45 +233,30 @@ async def _create_txn(transaction: firestore.AsyncTransaction) -> None: raise AlreadyExistsError(f"Session {session_id} already exists.") - app_snap = ( - await app_ref.get(transaction=transaction) - if app_state_delta - else None + app_snap = await app_ref.get(transaction=transaction) + user_snap = await user_ref.get(transaction=transaction) + + current_app: dict[str, Any] = ( + (app_snap.to_dict() or {}) if app_snap.exists else {} ) - user_snap = ( - await user_ref.get(transaction=transaction) - if user_state_delta - else None + current_user: dict[str, Any] = ( + (user_snap.to_dict() or {}) if user_snap.exists else {} ) # 2. Writes if app_state_delta: - current_app = ( - app_snap.to_dict() if (app_snap and app_snap.exists) else {} - ) current_app.update(app_state_delta) transaction.set(app_ref, current_app, merge=True) if user_state_delta: - current_user = ( - user_snap.to_dict() if (user_snap and user_snap.exists) else {} - ) current_user.update(user_state_delta) transaction.set(user_ref, current_user, merge=True) transaction.set(session_ref, session_data) + return current_app, current_user transaction_obj = self.client.transaction() - await _create_txn(transaction_obj) - - storage_app_doc = await app_ref.get() - storage_app_state = ( - storage_app_doc.to_dict() if storage_app_doc.exists else {} - ) - storage_user_doc = await user_ref.get() - storage_user_state = ( - storage_user_doc.to_dict() if storage_user_doc.exists else {} - ) + storage_app_state, storage_user_state = await _create_txn(transaction_obj) merged_state = self._merge_state( storage_app_state, storage_user_state, session_state @@ -294,7 +294,7 @@ async def get_session( if not data: return None - # Fetch events + # Fetch events and shared state concurrently events_ref = session_ref.collection(self.events_collection) query = events_ref.order_by("timestamp") @@ -305,18 +305,6 @@ async def get_session( if config.num_recent_events: query = query.limit_to_last(config.num_recent_events) - events_docs = await query.get() - events = [] - for event_doc in events_docs: - event_data = event_doc.to_dict() - if event_data and "event_data" in event_data: - ed = event_data["event_data"] - events.append(Event.model_validate(ed)) - - # Let's continue getting session. - session_state = data.get("state", {}) - - # Fetch shared state app_ref = self.client.collection(self.app_state_collection).document( app_name ) @@ -326,9 +314,22 @@ async def get_session( .collection("users") .document(user_id) ) - app_doc = await app_ref.get() + + events_docs, app_doc, user_doc = await asyncio.gather( + query.get(), + app_ref.get(), + user_ref.get(), + ) + + events = [] + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + ed = event_data["event_data"] + events.append(Event.model_validate(ed)) + + session_state = data.get("state", {}) app_state = app_doc.to_dict() if app_doc.exists else {} - user_doc = await user_ref.get() user_state = user_doc.to_dict() if user_doc.exists else {} merged_state = self._merge_state(app_state, user_state, session_state) @@ -374,6 +375,8 @@ async def list_sessions( ) docs = await query.get() + sessions_data = [d for doc in docs if (d := doc.to_dict())] + # Fetch shared state once app_ref = self.client.collection(self.app_state_collection).document( app_name @@ -393,34 +396,37 @@ async def list_sessions( if user_doc.exists: user_states_map[user_id] = user_doc.to_dict() else: - users_ref = ( - self.client.collection(self.user_state_collection) - .document(app_name) - .collection("users") - ) - users_docs = await users_ref.get() - for u_doc in users_docs: - user_states_map[u_doc.id] = u_doc.to_dict() + unique_user_ids = {s["userId"] for s in sessions_data if "userId" in s} + if unique_user_ids: + users_coll = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + ) + user_docs = await asyncio.gather( + *[users_coll.document(uid).get() for uid in unique_user_ids] + ) + for u_doc in user_docs: + if u_doc.exists: + user_states_map[u_doc.id] = u_doc.to_dict() sessions = [] - for doc in docs: - data = doc.to_dict() - if data: - u_id = data["userId"] - s_state = data.get("state", {}) - u_state = user_states_map.get(u_id, {}) - merged = self._merge_state(app_state, u_state, s_state) - - sessions.append( - Session( - id=data["id"], - app_name=data["appName"], - user_id=data["userId"], - state=merged, - events=[], - last_update_time=0.0, - ) - ) + for data in sessions_data: + u_id = data["userId"] + s_state = data.get("state", {}) + u_state = user_states_map.get(u_id, {}) + merged = self._merge_state(app_state, u_state, s_state) + + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state=merged, + events=[], + last_update_time=0.0, + ) + ) return ListSessionsResponse(sessions=sessions) diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 1445bfe0ef..a7c057baac 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -494,12 +494,15 @@ def collection_side_effect(name): user_doc = mock.MagicMock() user_doc.id = "user1" + user_doc.exists = True user_doc.to_dict.return_value = {"user_key": "user_val"} user_app_doc = mock.MagicMock() user_state_coll.document.return_value = user_app_doc users_coll = mock.MagicMock() user_app_doc.collection.return_value = users_coll - users_coll.get = mock.AsyncMock(return_value=[user_doc]) + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) response = await service.list_sessions(app_name=app_name) @@ -553,12 +556,15 @@ def collection_side_effect(name): user_doc = mock.MagicMock() user_doc.id = "user1" + user_doc.exists = True user_doc.to_dict.return_value = {"user_key": "user_val"} user_app_doc = mock.MagicMock() user_state_coll.document.return_value = user_app_doc users_coll = mock.MagicMock() user_app_doc.collection.return_value = users_coll - users_coll.get = mock.AsyncMock(return_value=[user_doc]) + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) response = await service.list_sessions(app_name=app_name) From 89ef90eb35d4f0854af7e49b5beaf0078a7e9b5d Mon Sep 17 00:00:00 2001 From: Alpha Du Date: Mon, 11 May 2026 09:08:22 +0000 Subject: [PATCH 2/6] fix side-effects inside transaction callbacks In append_event --- .../adk/integrations/firestore/firestore_session_service.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 098eafed95..81206a6d96 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -548,9 +548,6 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: current_user.update(user_updates) transaction.set(user_ref, current_user, merge=True) - for k, v in session_updates.items(): - session.state[k] = v - new_revision = current_revision + 1 session_only_state = { k: v @@ -558,6 +555,8 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: if not k.startswith(State.APP_PREFIX) and not k.startswith(State.USER_PREFIX) } + for k, v in session_updates.items(): + session_only_state[k] = v transaction.update( session_ref, { From ef16890e482c44570623829bc8a6eae626a487d6 Mon Sep 17 00:00:00 2001 From: Alpha Du Date: Mon, 11 May 2026 09:26:19 +0000 Subject: [PATCH 3/6] resort and update imports positions --- .../firestore/firestore_session_service.py | 41 ++++++------------- .../test_firestore_session_service.py | 36 +++------------- 2 files changed, 18 insertions(+), 59 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 81206a6d96..6f627aa6c0 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import copy from datetime import datetime from datetime import timezone import logging @@ -22,16 +23,10 @@ from typing import Any from typing import cast from typing import Optional -from typing import TYPE_CHECKING - -_SessionLockKey = tuple[str, str, str] - -if TYPE_CHECKING: - from google.cloud import firestore - -from pydantic import BaseModel +from ...errors.already_exists_error import AlreadyExistsError from ...events.event import Event +from ...platform import uuid as platform_uuid from ...sessions import _session_util from ...sessions.base_session_service import BaseSessionService from ...sessions.base_session_service import GetSessionConfig @@ -39,6 +34,16 @@ from ...sessions.session import Session from ...sessions.state import State +try: + from google.cloud import firestore +except ImportError as e: + raise ImportError( + "FirestoreSessionService requires google-cloud-firestore. " + "Install it with: pip install google-cloud-firestore" + ) from e + +_SessionLockKey = tuple[str, str, str] + logger = logging.getLogger("google_adk." + __name__) DEFAULT_ROOT_COLLECTION = "adk-session" @@ -118,14 +123,6 @@ def __init__( root_collection: The root collection name. Defaults to 'adk-session' or the value of ADK_FIRESTORE_ROOT_COLLECTION env var. """ - try: - from google.cloud import firestore - except ImportError as e: - raise ImportError( - "FirestoreSessionService requires google-cloud-firestore. " - "Install it with: pip install google-cloud-firestore" - ) from e - self.client = client or firestore.AsyncClient() self.root_collection = ( root_collection @@ -155,8 +152,6 @@ def _merge_state( session_state: dict[str, Any], ) -> dict[str, Any]: """Merge app, user, and session states into a single state dictionary.""" - import copy - merged_state = copy.deepcopy(session_state) for key, value in app_state.items(): merged_state[State.APP_PREFIX + key] = value @@ -184,11 +179,7 @@ async def create_session( session_id: Optional[str] = None, ) -> Session: """Creates a new session in Firestore.""" - from google.cloud import firestore - if not session_id: - from ...platform import uuid as platform_uuid - session_id = platform_uuid.new_uuid() initial_state = state or {} @@ -229,8 +220,6 @@ async def _create_txn( # 1. Reads snap = await session_ref.get(transaction=transaction) if snap.exists: - from ...errors.already_exists_error import AlreadyExistsError - raise AlreadyExistsError(f"Session {session_id} already exists.") app_snap = await app_ref.get(transaction=transaction) @@ -434,8 +423,6 @@ async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: """Deletes a session and its events from Firestore.""" - from google.cloud import firestore - session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) @firestore.async_transactional # type: ignore[untyped-decorator] @@ -470,8 +457,6 @@ async def _mark_deleting_txn( async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session in Firestore.""" - from google.cloud import firestore - if event.partial: return event diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index a7c057baac..154d2be67c 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -16,8 +16,13 @@ from unittest import mock +from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.events.event import Event +from google.adk.events.event import EventActions from google.adk.integrations.firestore.firestore_session_service import FirestoreSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.session import Session +from google.cloud import firestore import pytest @@ -68,21 +73,6 @@ def mock_firestore_client(): return client -def test_init_missing_dependency(): - import builtins - - original_import = builtins.__import__ - - def mock_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "google.cloud" and "firestore" in fromlist: - raise ImportError("Mocked import error") - return original_import(name, globals, locals, fromlist, level) - - with mock.patch("builtins.__import__", side_effect=mock_import): - with pytest.raises(ImportError, match="requires google-cloud-firestore"): - FirestoreSessionService() - - @pytest.mark.asyncio async def test_create_session(mock_firestore_client): @@ -108,8 +98,6 @@ async def test_create_session(mock_firestore_client): sessions_ref = user_ref.collection.return_value session_doc_ref = sessions_ref.document.return_value - from google.cloud import firestore - transaction = mock_firestore_client.transaction.return_value transaction.set.assert_called_once() args, kwargs = transaction.set.call_args @@ -247,7 +235,6 @@ async def test_append_event(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" - from google.adk.sessions.session import Session session = Session(id="test_session", app_name=app_name, user_id=user_id) event = Event(invocation_id="test_inv", author="user") @@ -267,8 +254,6 @@ async def test_append_event(mock_firestore_client): with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): await service.append_event(session, event) - from google.cloud import firestore - transaction = mock_firestore_client.transaction.return_value transaction.set.assert_called() # Invoked for events appends transaction.update.assert_called_once() # Invoked for session revisions @@ -283,7 +268,6 @@ async def test_append_event_with_state_delta(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" - from google.adk.sessions.session import Session session = Session(id="test_session", app_name=app_name, user_id=user_id) @@ -320,8 +304,6 @@ async def test_append_event_with_state_delta(mock_firestore_client): assert session.state["session_key"] == "session_val" - from google.cloud import firestore - transaction.update.assert_called_once() args, kwargs = transaction.update.call_args # In modular Firestore configurations alignments, updating variables mock assertions core setups @@ -334,9 +316,6 @@ async def test_append_event_with_temp_state(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" - from google.adk.events.event import Event - from google.adk.events.event import EventActions - from google.adk.sessions.session import Session session = Session(id="test_session", app_name=app_name, user_id=user_id) @@ -589,8 +568,6 @@ async def test_create_session_already_exists(mock_firestore_client): ) doc_snapshot.exists = True - from google.adk.errors.already_exists_error import AlreadyExistsError - with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): with pytest.raises(AlreadyExistsError): await service.create_session( @@ -619,8 +596,6 @@ async def test_get_session_with_config(mock_firestore_client): mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) - from google.adk.sessions.base_session_service import GetSessionConfig - config = GetSessionConfig(after_timestamp=1234567890.0, num_recent_events=5) await service.get_session( @@ -663,7 +638,6 @@ async def to_async_iter(iterable): @pytest.mark.asyncio async def test_append_event_partial(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) - from google.adk.sessions.session import Session session = Session(id="test_session", app_name="test_app", user_id="test_user") From 85236f5017248959dbe4660842126759ab33ee14 Mon Sep 17 00:00:00 2001 From: Alpha Du Date: Mon, 11 May 2026 09:45:26 +0000 Subject: [PATCH 4/6] refactor: update state handling logic to match DB session service --- .../firestore/firestore_session_service.py | 86 ++++++++----------- .../test_firestore_session_service.py | 1 + 2 files changed, 37 insertions(+), 50 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 6f627aa6c0..43a6514e39 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -15,12 +15,14 @@ from __future__ import annotations import asyncio +from contextlib import asynccontextmanager import copy from datetime import datetime from datetime import timezone import logging import os from typing import Any +from typing import AsyncGenerator from typing import cast from typing import Optional @@ -46,6 +48,11 @@ logger = logging.getLogger("google_adk." + __name__) +_STALE_SESSION_ERROR_MESSAGE = ( + "The session has been modified in storage since it was loaded. " + "Please reload the session before appending more events." +) + DEFAULT_ROOT_COLLECTION = "adk-session" DEFAULT_SESSIONS_COLLECTION = "sessions" DEFAULT_EVENTS_COLLECTION = "events" @@ -53,40 +60,6 @@ DEFAULT_USER_STATE_COLLECTION = "user_states" -class _SessionLockContext: - """Async context manager for serializing event appends for the same session.""" - - def __init__( - self, service: FirestoreSessionService, lock_key: _SessionLockKey - ): - self.service = service - self.lock_key = lock_key - self.lock: Optional[asyncio.Lock] = None - - async def __aenter__(self) -> None: - async with self.service._session_locks_guard: - lock = self.service._session_locks.get(self.lock_key, asyncio.Lock()) - self.service._session_locks[self.lock_key] = lock - self.service._session_lock_ref_count[self.lock_key] = ( - self.service._session_lock_ref_count.get(self.lock_key, 0) + 1 - ) - self.lock = lock - await self.lock.acquire() - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self.lock: - self.lock.release() - async with self.service._session_locks_guard: - remaining = ( - self.service._session_lock_ref_count.get(self.lock_key, 0) - 1 - ) - if remaining <= 0 and not self.lock.locked(): - self.service._session_lock_ref_count.pop(self.lock_key, None) - self.service._session_locks.pop(self.lock_key, None) - else: - self.service._session_lock_ref_count[self.lock_key] = remaining - - class FirestoreSessionService(BaseSessionService): # type: ignore[misc] """Session service that uses Google Cloud Firestore as the backend. @@ -139,11 +112,32 @@ def __init__( self.app_state_collection = DEFAULT_APP_STATE_COLLECTION self.user_state_collection = DEFAULT_USER_STATE_COLLECTION - def _with_session_lock( + @asynccontextmanager + async def _with_session_lock( self, *, app_name: str, user_id: str, session_id: str - ) -> _SessionLockContext: + ) -> AsyncGenerator[None]: """Serializes event appends for the same session within this process.""" - return _SessionLockContext(self, (app_name, user_id, session_id)) + lock_key = (app_name, user_id, session_id) + async with self._session_locks_guard: + lock = self._session_locks.get(lock_key) + if lock is None: + lock = asyncio.Lock() + self._session_locks[lock_key] = lock + self._session_lock_ref_count[lock_key] = ( + self._session_lock_ref_count.get(lock_key, 0) + 1 + ) + + try: + async with lock: + yield + finally: + async with self._session_locks_guard: + remaining = self._session_lock_ref_count.get(lock_key, 0) - 1 + if remaining <= 0 and not lock.locked(): + self._session_lock_ref_count.pop(lock_key, None) + self._session_locks.pop(lock_key, None) + else: + self._session_lock_ref_count[lock_key] = remaining @staticmethod def _merge_state( @@ -508,10 +502,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: if session._storage_update_marker is not None: if session._storage_update_marker != str(current_revision): - raise ValueError( - "The session has been modified in storage since it was loaded. " - "Please reload the session before appending more events." - ) + raise ValueError(_STALE_SESSION_ERROR_MESSAGE) app_snap = ( await app_ref.get(transaction=transaction) if app_updates else None @@ -534,18 +525,12 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: transaction.set(user_ref, current_user, merge=True) new_revision = current_revision + 1 - session_only_state = { - k: v - for k, v in session.state.items() - if not k.startswith(State.APP_PREFIX) - and not k.startswith(State.USER_PREFIX) - } - for k, v in session_updates.items(): - session_only_state[k] = v + current_session_state = session_doc.get("state", {}) + current_session_state.update(session_updates) transaction.update( session_ref, { - "state": session_only_state, + "state": current_session_state, "updateTime": firestore.SERVER_TIMESTAMP, "revision": new_revision, }, @@ -571,6 +556,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: transaction_obj = self.client.transaction() new_revision_count = await _append_txn(transaction_obj) session._storage_update_marker = str(new_revision_count) + session.last_update_time = event.timestamp await super().append_event(session, event) return event diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 154d2be67c..112abcd95a 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -261,6 +261,7 @@ async def test_append_event(mock_firestore_client): args, kwargs = transaction.update.call_args assert args[1]["revision"] == 1 assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP + assert session.last_update_time == event.timestamp @pytest.mark.asyncio From 0ba4d211751150d23de1ad5b1e0e01d992d034da Mon Sep 17 00:00:00 2001 From: Alpha Du Date: Mon, 11 May 2026 09:46:01 +0000 Subject: [PATCH 5/6] fix: initialize session revision to 0 --- .../adk/integrations/firestore/firestore_session_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 43a6514e39..1b198add5d 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -204,7 +204,7 @@ async def create_session( "state": session_state, "createTime": now, "updateTime": now, - "revision": 1, + "revision": 0, } @firestore.async_transactional # type: ignore[untyped-decorator] From 485c3bd5d176f588d1adfdd4b100b616e81a4cb1 Mon Sep 17 00:00:00 2001 From: Alpha Du Date: Mon, 11 May 2026 09:56:26 +0000 Subject: [PATCH 6/6] refactor: update list_sessions to use firestore get_all method --- .../firestore/firestore_session_service.py | 6 ++---- .../firestore/test_firestore_session_service.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 1b198add5d..df4449af15 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -386,10 +386,8 @@ async def list_sessions( .document(app_name) .collection("users") ) - user_docs = await asyncio.gather( - *[users_coll.document(uid).get() for uid in unique_user_ids] - ) - for u_doc in user_docs: + refs = [users_coll.document(uid) for uid in sorted(unique_user_ids)] + async for u_doc in self.client.get_all(refs): if u_doc.exists: user_states_map[u_doc.id] = u_doc.to_dict() diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 112abcd95a..8043b94975 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -482,7 +482,11 @@ def collection_side_effect(name): user_app_doc.collection.return_value = users_coll user_doc_ref = mock.MagicMock() users_coll.document.return_value = user_doc_ref - user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + async def mock_get_all(refs): + yield user_doc + + mock_firestore_client.get_all = mock_get_all response = await service.list_sessions(app_name=app_name) @@ -544,7 +548,11 @@ def collection_side_effect(name): user_app_doc.collection.return_value = users_coll user_doc_ref = mock.MagicMock() users_coll.document.return_value = user_doc_ref - user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + async def mock_get_all(refs): + yield user_doc + + mock_firestore_client.get_all = mock_get_all response = await service.list_sessions(app_name=app_name)