Skip to content

Commit 2db20f6

Browse files
committed
fix: preserve nontransparent live resumption
1 parent 76b9f0b commit 2db20f6

2 files changed

Lines changed: 81 additions & 4 deletions

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,14 +507,19 @@ async def run_live(
507507
attempt += 1
508508
if not llm_request.live_connect_config:
509509
llm_request.live_connect_config = types.LiveConnectConfig()
510-
if not llm_request.live_connect_config.session_resumption:
510+
session_resumption = (
511+
llm_request.live_connect_config.session_resumption
512+
)
513+
if not session_resumption:
514+
session_resumption = types.SessionResumptionConfig()
511515
llm_request.live_connect_config.session_resumption = (
512-
types.SessionResumptionConfig()
516+
session_resumption
513517
)
514-
llm_request.live_connect_config.session_resumption.handle = (
518+
session_resumption.handle = (
515519
invocation_context.live_session_resumption_handle
516520
)
517-
llm_request.live_connect_config.session_resumption.transparent = True
521+
if session_resumption.transparent is None:
522+
session_resumption.transparent = True
518523

519524
logger.info(
520525
'Establishing live connection for agent: %s',

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,78 @@ async def mock_receive_2():
623623
assert invocation_context.live_session_resumption_handle == 'test_handle'
624624

625625

626+
@pytest.mark.asyncio
627+
async def test_run_live_reconnect_preserves_nontransparent_resumption():
628+
"""Test that reconnect does not force transparent resumption."""
629+
from google.adk.agents.live_request_queue import LiveRequestQueue
630+
from websockets.exceptions import ConnectionClosed
631+
632+
real_model = Gemini()
633+
mock_connection = mock.AsyncMock()
634+
635+
async def mock_receive():
636+
yield LlmResponse(
637+
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
638+
new_handle='test_handle'
639+
)
640+
)
641+
raise ConnectionClosed(None, None)
642+
643+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
644+
645+
agent = Agent(name='test_agent', model=real_model)
646+
invocation_context = await testing_utils.create_invocation_context(
647+
agent=agent
648+
)
649+
invocation_context.live_request_queue = LiveRequestQueue()
650+
651+
flow = BaseLlmFlowForTesting()
652+
653+
async def mock_preprocess(ctx, req):
654+
req.live_connect_config.session_resumption = types.SessionResumptionConfig(
655+
transparent=False
656+
)
657+
if False:
658+
yield
659+
660+
with mock.patch.object(
661+
flow, '_preprocess_async', side_effect=mock_preprocess
662+
):
663+
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
664+
mock_connection_2 = mock.AsyncMock()
665+
666+
class StopError(Exception):
667+
pass
668+
669+
async def mock_receive_2():
670+
yield LlmResponse(
671+
content=types.Content(parts=[types.Part.from_text(text='hi')])
672+
)
673+
raise StopError('stop')
674+
675+
mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)
676+
677+
mock_aenter = mock.AsyncMock()
678+
mock_aenter.side_effect = [mock_connection, mock_connection_2]
679+
680+
with mock.patch(
681+
'google.adk.models.google_llm.Gemini.connect'
682+
) as mock_connect:
683+
mock_connect.return_value.__aenter__ = mock_aenter
684+
685+
try:
686+
async for _ in flow.run_live(invocation_context):
687+
pass
688+
except StopError:
689+
pass
690+
691+
reconnect_request = mock_connect.call_args_list[1].args[0]
692+
assert (
693+
reconnect_request.live_connect_config.session_resumption.transparent
694+
is False
695+
)
696+
697+
626698
@pytest.mark.asyncio
627699
async def test_run_live_skips_send_history_on_resumption():
628700
"""Test that run_live skips send_history when resuming a session."""

0 commit comments

Comments
 (0)