diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 83b97c33c2..df4449af15 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -16,24 +16,19 @@ import asyncio from contextlib import asynccontextmanager +import copy from datetime import datetime from datetime import timezone import logging import os from typing import Any -from typing import AsyncIterator +from typing import AsyncGenerator from typing import cast from typing import Optional -from typing import TYPE_CHECKING - -_SessionLockKey = tuple[str, str, str] - -if TYPE_CHECKING: - from google.cloud import firestore - -from pydantic import BaseModel +from ...errors.already_exists_error import AlreadyExistsError from ...events.event import Event +from ...platform import uuid as platform_uuid from ...sessions import _session_util from ...sessions.base_session_service import BaseSessionService from ...sessions.base_session_service import GetSessionConfig @@ -41,8 +36,23 @@ from ...sessions.session import Session from ...sessions.state import State +try: + from google.cloud import firestore +except ImportError as e: + raise ImportError( + "FirestoreSessionService requires google-cloud-firestore. " + "Install it with: pip install google-cloud-firestore" + ) from e + +_SessionLockKey = tuple[str, str, str] + logger = logging.getLogger("google_adk." + __name__) +_STALE_SESSION_ERROR_MESSAGE = ( + "The session has been modified in storage since it was loaded. " + "Please reload the session before appending more events." +) + DEFAULT_ROOT_COLLECTION = "adk-session" DEFAULT_SESSIONS_COLLECTION = "sessions" DEFAULT_EVENTS_COLLECTION = "events" @@ -86,14 +96,6 @@ def __init__( root_collection: The root collection name. Defaults to 'adk-session' or the value of ADK_FIRESTORE_ROOT_COLLECTION env var. """ - try: - from google.cloud import firestore - except ImportError as e: - raise ImportError( - "FirestoreSessionService requires google-cloud-firestore. " - "Install it with: pip install google-cloud-firestore" - ) from e - self.client = client or firestore.AsyncClient() self.root_collection = ( root_collection @@ -113,12 +115,14 @@ def __init__( @asynccontextmanager async def _with_session_lock( self, *, app_name: str, user_id: str, session_id: str - ) -> AsyncIterator[None]: + ) -> AsyncGenerator[None]: """Serializes event appends for the same session within this process.""" lock_key = (app_name, user_id, session_id) async with self._session_locks_guard: - lock = self._session_locks.get(lock_key, asyncio.Lock()) - self._session_locks[lock_key] = lock + lock = self._session_locks.get(lock_key) + if lock is None: + lock = asyncio.Lock() + self._session_locks[lock_key] = lock self._session_lock_ref_count[lock_key] = ( self._session_lock_ref_count.get(lock_key, 0) + 1 ) @@ -142,8 +146,6 @@ def _merge_state( session_state: dict[str, Any], ) -> dict[str, Any]: """Merge app, user, and session states into a single state dictionary.""" - import copy - merged_state = copy.deepcopy(session_state) for key, value in app_state.items(): merged_state[State.APP_PREFIX + key] = value @@ -171,11 +173,7 @@ async def create_session( session_id: Optional[str] = None, ) -> Session: """Creates a new session in Firestore.""" - from google.cloud import firestore - if not session_id: - from ...platform import uuid as platform_uuid - session_id = platform_uuid.new_uuid() initial_state = state or {} @@ -206,57 +204,42 @@ async def create_session( "state": session_state, "createTime": now, "updateTime": now, - "revision": 1, + "revision": 0, } @firestore.async_transactional # type: ignore[untyped-decorator] - async def _create_txn(transaction: firestore.AsyncTransaction) -> None: + async def _create_txn( + transaction: firestore.AsyncTransaction, + ) -> tuple[dict[str, Any], dict[str, Any]]: # 1. Reads snap = await session_ref.get(transaction=transaction) if snap.exists: - from ...errors.already_exists_error import AlreadyExistsError - raise AlreadyExistsError(f"Session {session_id} already exists.") - app_snap = ( - await app_ref.get(transaction=transaction) - if app_state_delta - else None + app_snap = await app_ref.get(transaction=transaction) + user_snap = await user_ref.get(transaction=transaction) + + current_app: dict[str, Any] = ( + (app_snap.to_dict() or {}) if app_snap.exists else {} ) - user_snap = ( - await user_ref.get(transaction=transaction) - if user_state_delta - else None + current_user: dict[str, Any] = ( + (user_snap.to_dict() or {}) if user_snap.exists else {} ) # 2. Writes if app_state_delta: - current_app = ( - app_snap.to_dict() if (app_snap and app_snap.exists) else {} - ) current_app.update(app_state_delta) transaction.set(app_ref, current_app, merge=True) if user_state_delta: - current_user = ( - user_snap.to_dict() if (user_snap and user_snap.exists) else {} - ) current_user.update(user_state_delta) transaction.set(user_ref, current_user, merge=True) transaction.set(session_ref, session_data) + return current_app, current_user transaction_obj = self.client.transaction() - await _create_txn(transaction_obj) - - storage_app_doc = await app_ref.get() - storage_app_state = ( - storage_app_doc.to_dict() if storage_app_doc.exists else {} - ) - storage_user_doc = await user_ref.get() - storage_user_state = ( - storage_user_doc.to_dict() if storage_user_doc.exists else {} - ) + storage_app_state, storage_user_state = await _create_txn(transaction_obj) merged_state = self._merge_state( storage_app_state, storage_user_state, session_state @@ -294,7 +277,7 @@ async def get_session( if not data: return None - # Fetch events + # Fetch events and shared state concurrently events_ref = session_ref.collection(self.events_collection) query = events_ref.order_by("timestamp") @@ -305,18 +288,6 @@ async def get_session( if config.num_recent_events: query = query.limit_to_last(config.num_recent_events) - events_docs = await query.get() - events = [] - for event_doc in events_docs: - event_data = event_doc.to_dict() - if event_data and "event_data" in event_data: - ed = event_data["event_data"] - events.append(Event.model_validate(ed)) - - # Let's continue getting session. - session_state = data.get("state", {}) - - # Fetch shared state app_ref = self.client.collection(self.app_state_collection).document( app_name ) @@ -326,9 +297,22 @@ async def get_session( .collection("users") .document(user_id) ) - app_doc = await app_ref.get() + + events_docs, app_doc, user_doc = await asyncio.gather( + query.get(), + app_ref.get(), + user_ref.get(), + ) + + events = [] + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + ed = event_data["event_data"] + events.append(Event.model_validate(ed)) + + session_state = data.get("state", {}) app_state = app_doc.to_dict() if app_doc.exists else {} - user_doc = await user_ref.get() user_state = user_doc.to_dict() if user_doc.exists else {} merged_state = self._merge_state(app_state, user_state, session_state) @@ -374,6 +358,8 @@ async def list_sessions( ) docs = await query.get() + sessions_data = [d for doc in docs if (d := doc.to_dict())] + # Fetch shared state once app_ref = self.client.collection(self.app_state_collection).document( app_name @@ -393,34 +379,35 @@ async def list_sessions( if user_doc.exists: user_states_map[user_id] = user_doc.to_dict() else: - users_ref = ( - self.client.collection(self.user_state_collection) - .document(app_name) - .collection("users") - ) - users_docs = await users_ref.get() - for u_doc in users_docs: - user_states_map[u_doc.id] = u_doc.to_dict() + unique_user_ids = {s["userId"] for s in sessions_data if "userId" in s} + if unique_user_ids: + users_coll = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + ) + refs = [users_coll.document(uid) for uid in sorted(unique_user_ids)] + async for u_doc in self.client.get_all(refs): + if u_doc.exists: + user_states_map[u_doc.id] = u_doc.to_dict() sessions = [] - for doc in docs: - data = doc.to_dict() - if data: - u_id = data["userId"] - s_state = data.get("state", {}) - u_state = user_states_map.get(u_id, {}) - merged = self._merge_state(app_state, u_state, s_state) - - sessions.append( - Session( - id=data["id"], - app_name=data["appName"], - user_id=data["userId"], - state=merged, - events=[], - last_update_time=0.0, - ) - ) + for data in sessions_data: + u_id = data["userId"] + s_state = data.get("state", {}) + u_state = user_states_map.get(u_id, {}) + merged = self._merge_state(app_state, u_state, s_state) + + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state=merged, + events=[], + last_update_time=0.0, + ) + ) return ListSessionsResponse(sessions=sessions) @@ -428,8 +415,6 @@ async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: """Deletes a session and its events from Firestore.""" - from google.cloud import firestore - session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) @firestore.async_transactional # type: ignore[untyped-decorator] @@ -464,8 +449,6 @@ async def _mark_deleting_txn( async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session in Firestore.""" - from google.cloud import firestore - if event.partial: return event @@ -517,10 +500,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: if session._storage_update_marker is not None: if session._storage_update_marker != str(current_revision): - raise ValueError( - "The session has been modified in storage since it was loaded. " - "Please reload the session before appending more events." - ) + raise ValueError(_STALE_SESSION_ERROR_MESSAGE) app_snap = ( await app_ref.get(transaction=transaction) if app_updates else None @@ -542,20 +522,13 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: current_user.update(user_updates) transaction.set(user_ref, current_user, merge=True) - for k, v in session_updates.items(): - session.state[k] = v - new_revision = current_revision + 1 - session_only_state = { - k: v - for k, v in session.state.items() - if not k.startswith(State.APP_PREFIX) - and not k.startswith(State.USER_PREFIX) - } + current_session_state = session_doc.get("state", {}) + current_session_state.update(session_updates) transaction.update( session_ref, { - "state": session_only_state, + "state": current_session_state, "updateTime": firestore.SERVER_TIMESTAMP, "revision": new_revision, }, @@ -581,6 +554,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: transaction_obj = self.client.transaction() new_revision_count = await _append_txn(transaction_obj) session._storage_update_marker = str(new_revision_count) + session.last_update_time = event.timestamp await super().append_event(session, event) return event diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 1445bfe0ef..8043b94975 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -16,8 +16,13 @@ from unittest import mock +from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.events.event import Event +from google.adk.events.event import EventActions from google.adk.integrations.firestore.firestore_session_service import FirestoreSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.session import Session +from google.cloud import firestore import pytest @@ -68,21 +73,6 @@ def mock_firestore_client(): return client -def test_init_missing_dependency(): - import builtins - - original_import = builtins.__import__ - - def mock_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "google.cloud" and "firestore" in fromlist: - raise ImportError("Mocked import error") - return original_import(name, globals, locals, fromlist, level) - - with mock.patch("builtins.__import__", side_effect=mock_import): - with pytest.raises(ImportError, match="requires google-cloud-firestore"): - FirestoreSessionService() - - @pytest.mark.asyncio async def test_create_session(mock_firestore_client): @@ -108,8 +98,6 @@ async def test_create_session(mock_firestore_client): sessions_ref = user_ref.collection.return_value session_doc_ref = sessions_ref.document.return_value - from google.cloud import firestore - transaction = mock_firestore_client.transaction.return_value transaction.set.assert_called_once() args, kwargs = transaction.set.call_args @@ -247,7 +235,6 @@ async def test_append_event(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" - from google.adk.sessions.session import Session session = Session(id="test_session", app_name=app_name, user_id=user_id) event = Event(invocation_id="test_inv", author="user") @@ -267,8 +254,6 @@ async def test_append_event(mock_firestore_client): with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): await service.append_event(session, event) - from google.cloud import firestore - transaction = mock_firestore_client.transaction.return_value transaction.set.assert_called() # Invoked for events appends transaction.update.assert_called_once() # Invoked for session revisions @@ -276,6 +261,7 @@ async def test_append_event(mock_firestore_client): args, kwargs = transaction.update.call_args assert args[1]["revision"] == 1 assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP + assert session.last_update_time == event.timestamp @pytest.mark.asyncio @@ -283,7 +269,6 @@ async def test_append_event_with_state_delta(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" - from google.adk.sessions.session import Session session = Session(id="test_session", app_name=app_name, user_id=user_id) @@ -320,8 +305,6 @@ async def test_append_event_with_state_delta(mock_firestore_client): assert session.state["session_key"] == "session_val" - from google.cloud import firestore - transaction.update.assert_called_once() args, kwargs = transaction.update.call_args # In modular Firestore configurations alignments, updating variables mock assertions core setups @@ -334,9 +317,6 @@ async def test_append_event_with_temp_state(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" - from google.adk.events.event import Event - from google.adk.events.event import EventActions - from google.adk.sessions.session import Session session = Session(id="test_session", app_name=app_name, user_id=user_id) @@ -494,12 +474,19 @@ def collection_side_effect(name): user_doc = mock.MagicMock() user_doc.id = "user1" + user_doc.exists = True user_doc.to_dict.return_value = {"user_key": "user_val"} user_app_doc = mock.MagicMock() user_state_coll.document.return_value = user_app_doc users_coll = mock.MagicMock() user_app_doc.collection.return_value = users_coll - users_coll.get = mock.AsyncMock(return_value=[user_doc]) + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + + async def mock_get_all(refs): + yield user_doc + + mock_firestore_client.get_all = mock_get_all response = await service.list_sessions(app_name=app_name) @@ -553,12 +540,19 @@ def collection_side_effect(name): user_doc = mock.MagicMock() user_doc.id = "user1" + user_doc.exists = True user_doc.to_dict.return_value = {"user_key": "user_val"} user_app_doc = mock.MagicMock() user_state_coll.document.return_value = user_app_doc users_coll = mock.MagicMock() user_app_doc.collection.return_value = users_coll - users_coll.get = mock.AsyncMock(return_value=[user_doc]) + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + + async def mock_get_all(refs): + yield user_doc + + mock_firestore_client.get_all = mock_get_all response = await service.list_sessions(app_name=app_name) @@ -583,8 +577,6 @@ async def test_create_session_already_exists(mock_firestore_client): ) doc_snapshot.exists = True - from google.adk.errors.already_exists_error import AlreadyExistsError - with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): with pytest.raises(AlreadyExistsError): await service.create_session( @@ -613,8 +605,6 @@ async def test_get_session_with_config(mock_firestore_client): mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) - from google.adk.sessions.base_session_service import GetSessionConfig - config = GetSessionConfig(after_timestamp=1234567890.0, num_recent_events=5) await service.get_session( @@ -657,7 +647,6 @@ async def to_async_iter(iterable): @pytest.mark.asyncio async def test_append_event_partial(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) - from google.adk.sessions.session import Session session = Session(id="test_session", app_name="test_app", user_id="test_user")