-
Notifications
You must be signed in to change notification settings - Fork 3.4k
refactor(firestore) : comprehensive refactoring for FirestoreSessionService #5663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9db6db5
89ef90e
ef16890
85236f5
0ba4d21
485c3bd
df12ace
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,33 +16,43 @@ | |
|
|
||
| import asyncio | ||
| from contextlib import asynccontextmanager | ||
| import copy | ||
| from datetime import datetime | ||
| from datetime import timezone | ||
| import logging | ||
| import os | ||
| from typing import Any | ||
| from typing import AsyncIterator | ||
| from typing import AsyncGenerator | ||
| from typing import cast | ||
| from typing import Optional | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| _SessionLockKey = tuple[str, str, str] | ||
|
|
||
| if TYPE_CHECKING: | ||
| from google.cloud import firestore | ||
|
|
||
| from pydantic import BaseModel | ||
|
|
||
| from ...errors.already_exists_error import AlreadyExistsError | ||
| from ...events.event import Event | ||
| from ...platform import uuid as platform_uuid | ||
| from ...sessions import _session_util | ||
| from ...sessions.base_session_service import BaseSessionService | ||
| from ...sessions.base_session_service import GetSessionConfig | ||
| from ...sessions.base_session_service import ListSessionsResponse | ||
| from ...sessions.session import Session | ||
| from ...sessions.state import State | ||
|
|
||
| try: | ||
| from google.cloud import firestore | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "FirestoreSessionService requires google-cloud-firestore. " | ||
| "Install it with: pip install google-cloud-firestore" | ||
| ) from e | ||
|
|
||
| _SessionLockKey = tuple[str, str, str] | ||
|
|
||
| logger = logging.getLogger("google_adk." + __name__) | ||
|
|
||
| _STALE_SESSION_ERROR_MESSAGE = ( | ||
| "The session has been modified in storage since it was loaded. " | ||
| "Please reload the session before appending more events." | ||
| ) | ||
|
|
||
| DEFAULT_ROOT_COLLECTION = "adk-session" | ||
| DEFAULT_SESSIONS_COLLECTION = "sessions" | ||
| DEFAULT_EVENTS_COLLECTION = "events" | ||
|
|
@@ -86,14 +96,6 @@ def __init__( | |
| root_collection: The root collection name. Defaults to 'adk-session' or | ||
| the value of ADK_FIRESTORE_ROOT_COLLECTION env var. | ||
| """ | ||
| try: | ||
| from google.cloud import firestore | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "FirestoreSessionService requires google-cloud-firestore. " | ||
| "Install it with: pip install google-cloud-firestore" | ||
| ) from e | ||
|
|
||
| self.client = client or firestore.AsyncClient() | ||
| self.root_collection = ( | ||
| root_collection | ||
|
|
@@ -113,12 +115,14 @@ def __init__( | |
| @asynccontextmanager | ||
| async def _with_session_lock( | ||
| self, *, app_name: str, user_id: str, session_id: str | ||
| ) -> AsyncIterator[None]: | ||
| ) -> AsyncGenerator[None]: | ||
| """Serializes event appends for the same session within this process.""" | ||
| lock_key = (app_name, user_id, session_id) | ||
| async with self._session_locks_guard: | ||
| lock = self._session_locks.get(lock_key, asyncio.Lock()) | ||
| self._session_locks[lock_key] = lock | ||
| lock = self._session_locks.get(lock_key) | ||
| if lock is None: | ||
| lock = asyncio.Lock() | ||
| self._session_locks[lock_key] = lock | ||
| self._session_lock_ref_count[lock_key] = ( | ||
| self._session_lock_ref_count.get(lock_key, 0) + 1 | ||
| ) | ||
|
|
@@ -142,8 +146,6 @@ def _merge_state( | |
| session_state: dict[str, Any], | ||
| ) -> dict[str, Any]: | ||
| """Merge app, user, and session states into a single state dictionary.""" | ||
| import copy | ||
|
|
||
| merged_state = copy.deepcopy(session_state) | ||
| for key, value in app_state.items(): | ||
| merged_state[State.APP_PREFIX + key] = value | ||
|
|
@@ -171,11 +173,7 @@ async def create_session( | |
| session_id: Optional[str] = None, | ||
| ) -> Session: | ||
| """Creates a new session in Firestore.""" | ||
| from google.cloud import firestore | ||
|
|
||
| if not session_id: | ||
| from ...platform import uuid as platform_uuid | ||
|
|
||
| session_id = platform_uuid.new_uuid() | ||
|
|
||
| initial_state = state or {} | ||
|
|
@@ -206,57 +204,42 @@ async def create_session( | |
| "state": session_state, | ||
| "createTime": now, | ||
| "updateTime": now, | ||
| "revision": 1, | ||
| "revision": 0, | ||
| } | ||
|
|
||
| @firestore.async_transactional # type: ignore[untyped-decorator] | ||
| async def _create_txn(transaction: firestore.AsyncTransaction) -> None: | ||
| async def _create_txn( | ||
| transaction: firestore.AsyncTransaction, | ||
| ) -> tuple[dict[str, Any], dict[str, Any]]: | ||
| # 1. Reads | ||
| snap = await session_ref.get(transaction=transaction) | ||
| if snap.exists: | ||
| from ...errors.already_exists_error import AlreadyExistsError | ||
|
|
||
| raise AlreadyExistsError(f"Session {session_id} already exists.") | ||
|
|
||
| app_snap = ( | ||
| await app_ref.get(transaction=transaction) | ||
| if app_state_delta | ||
| else None | ||
| app_snap = await app_ref.get(transaction=transaction) | ||
| user_snap = await user_ref.get(transaction=transaction) | ||
|
|
||
| current_app: dict[str, Any] = ( | ||
| (app_snap.to_dict() or {}) if app_snap.exists else {} | ||
| ) | ||
| user_snap = ( | ||
| await user_ref.get(transaction=transaction) | ||
| if user_state_delta | ||
| else None | ||
| current_user: dict[str, Any] = ( | ||
| (user_snap.to_dict() or {}) if user_snap.exists else {} | ||
| ) | ||
|
|
||
| # 2. Writes | ||
| if app_state_delta: | ||
| current_app = ( | ||
| app_snap.to_dict() if (app_snap and app_snap.exists) else {} | ||
| ) | ||
| current_app.update(app_state_delta) | ||
| transaction.set(app_ref, current_app, merge=True) | ||
|
|
||
| if user_state_delta: | ||
| current_user = ( | ||
| user_snap.to_dict() if (user_snap and user_snap.exists) else {} | ||
| ) | ||
| current_user.update(user_state_delta) | ||
| transaction.set(user_ref, current_user, merge=True) | ||
|
|
||
| transaction.set(session_ref, session_data) | ||
| return current_app, current_user | ||
|
|
||
| transaction_obj = self.client.transaction() | ||
| await _create_txn(transaction_obj) | ||
|
|
||
| storage_app_doc = await app_ref.get() | ||
| storage_app_state = ( | ||
| storage_app_doc.to_dict() if storage_app_doc.exists else {} | ||
| ) | ||
| storage_user_doc = await user_ref.get() | ||
| storage_user_state = ( | ||
| storage_user_doc.to_dict() if storage_user_doc.exists else {} | ||
| ) | ||
| storage_app_state, storage_user_state = await _create_txn(transaction_obj) | ||
|
|
||
| merged_state = self._merge_state( | ||
| storage_app_state, storage_user_state, session_state | ||
|
|
@@ -294,7 +277,7 @@ async def get_session( | |
| if not data: | ||
| return None | ||
|
|
||
| # Fetch events | ||
| # Fetch events and shared state concurrently | ||
| events_ref = session_ref.collection(self.events_collection) | ||
| query = events_ref.order_by("timestamp") | ||
|
|
||
|
|
@@ -305,18 +288,6 @@ async def get_session( | |
| if config.num_recent_events: | ||
| query = query.limit_to_last(config.num_recent_events) | ||
|
|
||
| events_docs = await query.get() | ||
| events = [] | ||
| for event_doc in events_docs: | ||
| event_data = event_doc.to_dict() | ||
| if event_data and "event_data" in event_data: | ||
| ed = event_data["event_data"] | ||
| events.append(Event.model_validate(ed)) | ||
|
|
||
| # Let's continue getting session. | ||
| session_state = data.get("state", {}) | ||
|
|
||
| # Fetch shared state | ||
| app_ref = self.client.collection(self.app_state_collection).document( | ||
| app_name | ||
| ) | ||
|
|
@@ -326,9 +297,22 @@ async def get_session( | |
| .collection("users") | ||
| .document(user_id) | ||
| ) | ||
| app_doc = await app_ref.get() | ||
|
|
||
| events_docs, app_doc, user_doc = await asyncio.gather( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using |
||
| query.get(), | ||
| app_ref.get(), | ||
| user_ref.get(), | ||
| ) | ||
|
|
||
| events = [] | ||
| for event_doc in events_docs: | ||
| event_data = event_doc.to_dict() | ||
| if event_data and "event_data" in event_data: | ||
| ed = event_data["event_data"] | ||
| events.append(Event.model_validate(ed)) | ||
|
|
||
| session_state = data.get("state", {}) | ||
| app_state = app_doc.to_dict() if app_doc.exists else {} | ||
| user_doc = await user_ref.get() | ||
| user_state = user_doc.to_dict() if user_doc.exists else {} | ||
|
|
||
| merged_state = self._merge_state(app_state, user_state, session_state) | ||
|
|
@@ -374,6 +358,8 @@ async def list_sessions( | |
| ) | ||
| docs = await query.get() | ||
|
|
||
| sessions_data = [d for doc in docs if (d := doc.to_dict())] | ||
|
|
||
| # Fetch shared state once | ||
| app_ref = self.client.collection(self.app_state_collection).document( | ||
| app_name | ||
|
|
@@ -393,43 +379,42 @@ async def list_sessions( | |
| if user_doc.exists: | ||
| user_states_map[user_id] = user_doc.to_dict() | ||
| else: | ||
| users_ref = ( | ||
| self.client.collection(self.user_state_collection) | ||
| .document(app_name) | ||
| .collection("users") | ||
| ) | ||
| users_docs = await users_ref.get() | ||
| for u_doc in users_docs: | ||
| user_states_map[u_doc.id] = u_doc.to_dict() | ||
| unique_user_ids = {s["userId"] for s in sessions_data if "userId" in s} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The batch fetch using |
||
| if unique_user_ids: | ||
| users_coll = ( | ||
| self.client.collection(self.user_state_collection) | ||
| .document(app_name) | ||
| .collection("users") | ||
| ) | ||
| refs = [users_coll.document(uid) for uid in sorted(unique_user_ids)] | ||
| async for u_doc in self.client.get_all(refs): | ||
| if u_doc.exists: | ||
| user_states_map[u_doc.id] = u_doc.to_dict() | ||
|
|
||
| sessions = [] | ||
| for doc in docs: | ||
| data = doc.to_dict() | ||
| if data: | ||
| u_id = data["userId"] | ||
| s_state = data.get("state", {}) | ||
| u_state = user_states_map.get(u_id, {}) | ||
| merged = self._merge_state(app_state, u_state, s_state) | ||
|
|
||
| sessions.append( | ||
| Session( | ||
| id=data["id"], | ||
| app_name=data["appName"], | ||
| user_id=data["userId"], | ||
| state=merged, | ||
| events=[], | ||
| last_update_time=0.0, | ||
| ) | ||
| ) | ||
| for data in sessions_data: | ||
| u_id = data["userId"] | ||
| s_state = data.get("state", {}) | ||
| u_state = user_states_map.get(u_id, {}) | ||
| merged = self._merge_state(app_state, u_state, s_state) | ||
|
|
||
| sessions.append( | ||
| Session( | ||
| id=data["id"], | ||
| app_name=data["appName"], | ||
| user_id=data["userId"], | ||
| state=merged, | ||
| events=[], | ||
| last_update_time=0.0, | ||
| ) | ||
| ) | ||
|
|
||
| return ListSessionsResponse(sessions=sessions) | ||
|
|
||
| async def delete_session( | ||
| self, *, app_name: str, user_id: str, session_id: str | ||
| ) -> None: | ||
| """Deletes a session and its events from Firestore.""" | ||
| from google.cloud import firestore | ||
|
|
||
| session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) | ||
|
|
||
| @firestore.async_transactional # type: ignore[untyped-decorator] | ||
|
|
@@ -464,8 +449,6 @@ async def _mark_deleting_txn( | |
|
|
||
| async def append_event(self, session: Session, event: Event) -> Event: | ||
| """Appends an event to a session in Firestore.""" | ||
| from google.cloud import firestore | ||
|
|
||
| if event.partial: | ||
| return event | ||
|
|
||
|
|
@@ -517,10 +500,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: | |
|
|
||
| if session._storage_update_marker is not None: | ||
| if session._storage_update_marker != str(current_revision): | ||
| raise ValueError( | ||
| "The session has been modified in storage since it was loaded. " | ||
| "Please reload the session before appending more events." | ||
| ) | ||
| raise ValueError(_STALE_SESSION_ERROR_MESSAGE) | ||
|
|
||
| app_snap = ( | ||
| await app_ref.get(transaction=transaction) if app_updates else None | ||
|
|
@@ -542,20 +522,13 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: | |
| current_user.update(user_updates) | ||
| transaction.set(user_ref, current_user, merge=True) | ||
|
|
||
| for k, v in session_updates.items(): | ||
| session.state[k] = v | ||
|
|
||
| new_revision = current_revision + 1 | ||
| session_only_state = { | ||
| k: v | ||
| for k, v in session.state.items() | ||
| if not k.startswith(State.APP_PREFIX) | ||
| and not k.startswith(State.USER_PREFIX) | ||
| } | ||
| current_session_state = session_doc.get("state", {}) | ||
| current_session_state.update(session_updates) | ||
| transaction.update( | ||
| session_ref, | ||
| { | ||
| "state": session_only_state, | ||
| "state": current_session_state, | ||
| "updateTime": firestore.SERVER_TIMESTAMP, | ||
| "revision": new_revision, | ||
| }, | ||
|
|
@@ -581,6 +554,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: | |
| transaction_obj = self.client.transaction() | ||
| new_revision_count = await _append_txn(transaction_obj) | ||
| session._storage_update_marker = str(new_revision_count) | ||
| session.last_update_time = event.timestamp | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updating |
||
|
|
||
| await super().append_event(session, event) | ||
| return event | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using `AsyncGenerator` is more appropriate here than `AsyncIterator` since it's an async context manager that yields. Good catch.