Skip to content

Commit 4be0bc6

Browse files
committed
refactor: update list_sessions to use firestore get_all method
1 parent 929903b commit 4be0bc6

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

src/google/adk/integrations/firestore/firestore_session_service.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,8 @@ async def list_sessions(
386386
.document(app_name)
387387
.collection("users")
388388
)
389-
user_docs = await asyncio.gather(
390-
*[users_coll.document(uid).get() for uid in unique_user_ids]
391-
)
392-
for u_doc in user_docs:
389+
refs = [users_coll.document(uid) for uid in sorted(unique_user_ids)]
390+
async for u_doc in self.client.get_all(refs):
393391
if u_doc.exists:
394392
user_states_map[u_doc.id] = u_doc.to_dict()
395393

tests/unittests/integrations/firestore/test_firestore_session_service.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,11 @@ def collection_side_effect(name):
482482
user_app_doc.collection.return_value = users_coll
483483
user_doc_ref = mock.MagicMock()
484484
users_coll.document.return_value = user_doc_ref
485-
user_doc_ref.get = mock.AsyncMock(return_value=user_doc)
485+
486+
async def mock_get_all(refs):
487+
yield user_doc
488+
489+
mock_firestore_client.get_all = mock_get_all
486490

487491
response = await service.list_sessions(app_name=app_name)
488492

@@ -544,7 +548,11 @@ def collection_side_effect(name):
544548
user_app_doc.collection.return_value = users_coll
545549
user_doc_ref = mock.MagicMock()
546550
users_coll.document.return_value = user_doc_ref
547-
user_doc_ref.get = mock.AsyncMock(return_value=user_doc)
551+
552+
async def mock_get_all(refs):
553+
yield user_doc
554+
555+
mock_firestore_client.get_all = mock_get_all
548556

549557
response = await service.list_sessions(app_name=app_name)
550558

0 commit comments

Comments
 (0)