Skip to content

Commit a80da51

Browse files
Roland Kakonyirolandkakonyi
authored andcommitted
Fix Vertex session user IDs
1 parent 8f20d56 commit a80da51

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.Comparator;
3636
import java.util.List;
3737
import java.util.Map;
38+
import java.util.Objects;
3839
import java.util.Optional;
3940
import java.util.concurrent.ConcurrentHashMap;
4041
import java.util.concurrent.ConcurrentMap;
@@ -148,7 +149,7 @@ private ListSessionsResponse parseListSessionsResponse(
148149
Session session =
149150
Session.builder(sessionId)
150151
.appName(appName)
151-
.userId(userId)
152+
.userId((String) apiSession.get("userId"))
152153
.state(
153154
apiSession.get("sessionState") == null
154155
? new ConcurrentHashMap<>()
@@ -195,6 +196,16 @@ public Maybe<Session> getSession(
195196
.getSession(reasoningEngineId, sessionId)
196197
.flatMap(
197198
getSessionResponseMap -> {
199+
String responseUserId =
200+
Optional.ofNullable(getSessionResponseMap.get("userId"))
201+
.map(JsonNode::asText)
202+
.orElse(null);
203+
if (!Objects.equals(responseUserId, userId)) {
204+
return Maybe.error(
205+
new IllegalArgumentException(
206+
"Session " + sessionId + " does not belong to user " + userId + "."));
207+
}
208+
198209
String sessId =
199210
Optional.ofNullable(getSessionResponseMap.get("name"))
200211
.map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText())))

core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,33 @@ public void listSessions_success() {
274274
assertThat(sessionsList).hasSize(2);
275275
ImmutableList<String> ids = sessionsList.stream().map(Session::id).collect(toImmutableList());
276276
assertThat(ids).containsExactly("1", "2");
277+
ImmutableList<String> userIds =
278+
sessionsList.stream().map(Session::userId).collect(toImmutableList());
279+
assertThat(userIds).containsExactly("user", "user");
280+
}
281+
282+
@Test
283+
public void listSessions_usesResponseUserId() throws Exception {
284+
when(mockApiClient.request("GET", "reasoningEngines/123/sessions?filter=user_id=user1", ""))
285+
.thenAnswer(
286+
new MockApiAnswer(
287+
"""
288+
{
289+
"sessions": [
290+
{
291+
"name": "projects/test-project/locations/test-location/reasoningEngines/123/sessions/3",
292+
"userId": "user2",
293+
"updateTime": "2024-12-14T12:12:12.123456Z"
294+
}
295+
]
296+
}\
297+
"""));
298+
299+
ListSessionsResponse sessions =
300+
vertexAiSessionService.listSessions("123", "user1").blockingGet();
301+
302+
assertThat(sessions.sessions()).hasSize(1);
303+
assertThat(sessions.sessions().get(0).userId()).isEqualTo("user2");
277304
}
278305

279306
@Test
@@ -346,12 +373,25 @@ public void listEvents_empty() {
346373
public void listEmptySession_success() {
347374
assertThat(
348375
vertexAiSessionService
349-
.getSession("789", "user1", "3", Optional.empty())
376+
.getSession("789", "user2", "3", Optional.empty())
350377
.blockingGet()
351378
.events())
352379
.isEmpty();
353380
}
354381

382+
@Test
383+
public void getSession_whenResponseUserIdDiffers_throws() {
384+
IllegalArgumentException exception =
385+
assertThrows(
386+
IllegalArgumentException.class,
387+
() ->
388+
vertexAiSessionService
389+
.getSession("789", "user1", "3", Optional.empty())
390+
.blockingGet());
391+
392+
assertThat(exception).hasMessageThat().contains("Session 3 does not belong to user user1.");
393+
}
394+
355395
@Test
356396
public void appendEvent_withStateRemoved_updatesSessionState() {
357397
String userId = "userB";

0 commit comments

Comments
 (0)