Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 87 additions & 113 deletions src/google/adk/integrations/firestore/firestore_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,43 @@

import asyncio
from contextlib import asynccontextmanager
import copy
from datetime import datetime
from datetime import timezone
import logging
import os
from typing import Any
from typing import AsyncIterator
from typing import AsyncGenerator
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING

_SessionLockKey = tuple[str, str, str]

if TYPE_CHECKING:
from google.cloud import firestore

from pydantic import BaseModel

from ...errors.already_exists_error import AlreadyExistsError
from ...events.event import Event
from ...platform import uuid as platform_uuid
from ...sessions import _session_util
from ...sessions.base_session_service import BaseSessionService
from ...sessions.base_session_service import GetSessionConfig
from ...sessions.base_session_service import ListSessionsResponse
from ...sessions.session import Session
from ...sessions.state import State

try:
from google.cloud import firestore
except ImportError as e:
raise ImportError(
"FirestoreSessionService requires google-cloud-firestore. "
"Install it with: pip install google-cloud-firestore"
) from e

_SessionLockKey = tuple[str, str, str]

logger = logging.getLogger("google_adk." + __name__)

_STALE_SESSION_ERROR_MESSAGE = (
"The session has been modified in storage since it was loaded. "
"Please reload the session before appending more events."
)

DEFAULT_ROOT_COLLECTION = "adk-session"
DEFAULT_SESSIONS_COLLECTION = "sessions"
DEFAULT_EVENTS_COLLECTION = "events"
Expand Down Expand Up @@ -86,14 +96,6 @@ def __init__(
root_collection: The root collection name. Defaults to 'adk-session' or
the value of ADK_FIRESTORE_ROOT_COLLECTION env var.
"""
try:
from google.cloud import firestore
except ImportError as e:
raise ImportError(
"FirestoreSessionService requires google-cloud-firestore. "
"Install it with: pip install google-cloud-firestore"
) from e

self.client = client or firestore.AsyncClient()
self.root_collection = (
root_collection
Expand All @@ -113,12 +115,14 @@ def __init__(
@asynccontextmanager
async def _with_session_lock(
self, *, app_name: str, user_id: str, session_id: str
) -> AsyncIterator[None]:
) -> AsyncGenerator[None]:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using AsyncGenerator is more appropriate here than AsyncIterator since it's an async context manager that yields. Good catch.

"""Serializes event appends for the same session within this process."""
lock_key = (app_name, user_id, session_id)
async with self._session_locks_guard:
lock = self._session_locks.get(lock_key, asyncio.Lock())
self._session_locks[lock_key] = lock
lock = self._session_locks.get(lock_key)
if lock is None:
lock = asyncio.Lock()
self._session_locks[lock_key] = lock
self._session_lock_ref_count[lock_key] = (
self._session_lock_ref_count.get(lock_key, 0) + 1
)
Expand All @@ -142,8 +146,6 @@ def _merge_state(
session_state: dict[str, Any],
) -> dict[str, Any]:
"""Merge app, user, and session states into a single state dictionary."""
import copy

merged_state = copy.deepcopy(session_state)
for key, value in app_state.items():
merged_state[State.APP_PREFIX + key] = value
Expand Down Expand Up @@ -171,11 +173,7 @@ async def create_session(
session_id: Optional[str] = None,
) -> Session:
"""Creates a new session in Firestore."""
from google.cloud import firestore

if not session_id:
from ...platform import uuid as platform_uuid

session_id = platform_uuid.new_uuid()

initial_state = state or {}
Expand Down Expand Up @@ -206,57 +204,42 @@ async def create_session(
"state": session_state,
"createTime": now,
"updateTime": now,
"revision": 1,
"revision": 0,
}

@firestore.async_transactional # type: ignore[untyped-decorator]
async def _create_txn(transaction: firestore.AsyncTransaction) -> None:
async def _create_txn(
transaction: firestore.AsyncTransaction,
) -> tuple[dict[str, Any], dict[str, Any]]:
# 1. Reads
snap = await session_ref.get(transaction=transaction)
if snap.exists:
from ...errors.already_exists_error import AlreadyExistsError

raise AlreadyExistsError(f"Session {session_id} already exists.")

app_snap = (
await app_ref.get(transaction=transaction)
if app_state_delta
else None
app_snap = await app_ref.get(transaction=transaction)
user_snap = await user_ref.get(transaction=transaction)

current_app: dict[str, Any] = (
(app_snap.to_dict() or {}) if app_snap.exists else {}
)
user_snap = (
await user_ref.get(transaction=transaction)
if user_state_delta
else None
current_user: dict[str, Any] = (
(user_snap.to_dict() or {}) if user_snap.exists else {}
)

# 2. Writes
if app_state_delta:
current_app = (
app_snap.to_dict() if (app_snap and app_snap.exists) else {}
)
current_app.update(app_state_delta)
transaction.set(app_ref, current_app, merge=True)

if user_state_delta:
current_user = (
user_snap.to_dict() if (user_snap and user_snap.exists) else {}
)
current_user.update(user_state_delta)
transaction.set(user_ref, current_user, merge=True)

transaction.set(session_ref, session_data)
return current_app, current_user

transaction_obj = self.client.transaction()
await _create_txn(transaction_obj)

storage_app_doc = await app_ref.get()
storage_app_state = (
storage_app_doc.to_dict() if storage_app_doc.exists else {}
)
storage_user_doc = await user_ref.get()
storage_user_state = (
storage_user_doc.to_dict() if storage_user_doc.exists else {}
)
storage_app_state, storage_user_state = await _create_txn(transaction_obj)

merged_state = self._merge_state(
storage_app_state, storage_user_state, session_state
Expand Down Expand Up @@ -294,7 +277,7 @@ async def get_session(
if not data:
return None

# Fetch events
# Fetch events and shared state concurrently
events_ref = session_ref.collection(self.events_collection)
query = events_ref.order_by("timestamp")

Expand All @@ -305,18 +288,6 @@ async def get_session(
if config.num_recent_events:
query = query.limit_to_last(config.num_recent_events)

events_docs = await query.get()
events = []
for event_doc in events_docs:
event_data = event_doc.to_dict()
if event_data and "event_data" in event_data:
ed = event_data["event_data"]
events.append(Event.model_validate(ed))

# Let's continue getting session.
session_state = data.get("state", {})

# Fetch shared state
app_ref = self.client.collection(self.app_state_collection).document(
app_name
)
Expand All @@ -326,9 +297,22 @@ async def get_session(
.collection("users")
.document(user_id)
)
app_doc = await app_ref.get()

events_docs, app_doc, user_doc = await asyncio.gather(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using asyncio.gather to fetch events, app state, and user state concurrently is a great performance improvement.

query.get(),
app_ref.get(),
user_ref.get(),
)

events = []
for event_doc in events_docs:
event_data = event_doc.to_dict()
if event_data and "event_data" in event_data:
ed = event_data["event_data"]
events.append(Event.model_validate(ed))

session_state = data.get("state", {})
app_state = app_doc.to_dict() if app_doc.exists else {}
user_doc = await user_ref.get()
user_state = user_doc.to_dict() if user_doc.exists else {}

merged_state = self._merge_state(app_state, user_state, session_state)
Expand Down Expand Up @@ -374,6 +358,8 @@ async def list_sessions(
)
docs = await query.get()

sessions_data = [d for doc in docs if (d := doc.to_dict())]

# Fetch shared state once
app_ref = self.client.collection(self.app_state_collection).document(
app_name
Expand All @@ -393,43 +379,42 @@ async def list_sessions(
if user_doc.exists:
user_states_map[user_id] = user_doc.to_dict()
else:
users_ref = (
self.client.collection(self.user_state_collection)
.document(app_name)
.collection("users")
)
users_docs = await users_ref.get()
for u_doc in users_docs:
user_states_map[u_doc.id] = u_doc.to_dict()
unique_user_ids = {s["userId"] for s in sessions_data if "userId" in s}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The batch fetch using get_all for unique user IDs is much more efficient than fetching all users or fetching them one by one.

if unique_user_ids:
users_coll = (
self.client.collection(self.user_state_collection)
.document(app_name)
.collection("users")
)
refs = [users_coll.document(uid) for uid in sorted(unique_user_ids)]
async for u_doc in self.client.get_all(refs):
if u_doc.exists:
user_states_map[u_doc.id] = u_doc.to_dict()

sessions = []
for doc in docs:
data = doc.to_dict()
if data:
u_id = data["userId"]
s_state = data.get("state", {})
u_state = user_states_map.get(u_id, {})
merged = self._merge_state(app_state, u_state, s_state)

sessions.append(
Session(
id=data["id"],
app_name=data["appName"],
user_id=data["userId"],
state=merged,
events=[],
last_update_time=0.0,
)
)
for data in sessions_data:
u_id = data["userId"]
s_state = data.get("state", {})
u_state = user_states_map.get(u_id, {})
merged = self._merge_state(app_state, u_state, s_state)

sessions.append(
Session(
id=data["id"],
app_name=data["appName"],
user_id=data["userId"],
state=merged,
events=[],
last_update_time=0.0,
)
)

return ListSessionsResponse(sessions=sessions)

async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
"""Deletes a session and its events from Firestore."""
from google.cloud import firestore

session_ref = self._get_sessions_ref(app_name, user_id).document(session_id)

@firestore.async_transactional # type: ignore[untyped-decorator]
Expand Down Expand Up @@ -464,8 +449,6 @@ async def _mark_deleting_txn(

async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session in Firestore."""
from google.cloud import firestore

if event.partial:
return event

Expand Down Expand Up @@ -517,10 +500,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int:

if session._storage_update_marker is not None:
if session._storage_update_marker != str(current_revision):
raise ValueError(
"The session has been modified in storage since it was loaded. "
"Please reload the session before appending more events."
)
raise ValueError(_STALE_SESSION_ERROR_MESSAGE)

app_snap = (
await app_ref.get(transaction=transaction) if app_updates else None
Expand All @@ -542,20 +522,13 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int:
current_user.update(user_updates)
transaction.set(user_ref, current_user, merge=True)

for k, v in session_updates.items():
session.state[k] = v

new_revision = current_revision + 1
session_only_state = {
k: v
for k, v in session.state.items()
if not k.startswith(State.APP_PREFIX)
and not k.startswith(State.USER_PREFIX)
}
current_session_state = session_doc.get("state", {})
current_session_state.update(session_updates)
transaction.update(
session_ref,
{
"state": session_only_state,
"state": current_session_state,
"updateTime": firestore.SERVER_TIMESTAMP,
"revision": new_revision,
},
Expand All @@ -581,6 +554,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int:
transaction_obj = self.client.transaction()
new_revision_count = await _append_txn(transaction_obj)
session._storage_update_marker = str(new_revision_count)
session.last_update_time = event.timestamp
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updating session.last_update_time here ensures the local session object is in sync with the storage update.


await super().append_event(session, event)
return event
Loading
Loading