Skip to content

Commit 73ea2e2

Browse files
committed
fix: preserve live resumption transparency
1 parent 7e38fc8 commit 73ea2e2

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,6 @@ async def run_live(
517517
llm_request.live_connect_config.session_resumption.handle = (
518518
invocation_context.live_session_resumption_handle
519519
)
520-
llm_request.live_connect_config.session_resumption.transparent = True
521520

522521
logger.info(
523522
'Establishing live connection for agent: %s',

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,50 @@ async def mock_receive():
681681
mock_connection.send_history.assert_not_called()
682682

683683

684+
@pytest.mark.asyncio
685+
async def test_run_live_resumption_preserves_transparent_setting():
686+
"""Test that reconnect does not force transparent resumption."""
687+
from google.adk.agents.live_request_queue import LiveRequestQueue
688+
689+
real_model = Gemini()
690+
mock_connection = mock.AsyncMock()
691+
692+
async def mock_receive():
693+
yield LlmResponse(
694+
content=types.Content(parts=[types.Part.from_text(text='hi')])
695+
)
696+
raise RuntimeError('stop')
697+
698+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
699+
700+
agent = Agent(name='test_agent', model=real_model)
701+
run_config = RunConfig(session_resumption=types.SessionResumptionConfig())
702+
invocation_context = await testing_utils.create_invocation_context(
703+
agent=agent, run_config=run_config
704+
)
705+
invocation_context.live_session_resumption_handle = 'test_handle'
706+
invocation_context.live_request_queue = LiveRequestQueue()
707+
708+
flow = BaseLlmFlowForTesting()
709+
710+
with mock.patch.object(
711+
flow, '_send_to_model', new_callable=AsyncMock
712+
) as mock_send:
713+
with mock.patch(
714+
'google.adk.models.google_llm.Gemini.connect'
715+
) as mock_connect:
716+
mock_connect.return_value.__aenter__.return_value = mock_connection
717+
718+
with pytest.raises(RuntimeError, match='stop'):
719+
async for _ in flow.run_live(invocation_context):
720+
pass
721+
722+
llm_request = mock_connect.call_args.args[0]
723+
session_resumption = llm_request.live_connect_config.session_resumption
724+
assert session_resumption.handle == 'test_handle'
725+
assert session_resumption.transparent is None
726+
727+
684728
@pytest.mark.asyncio
685729
async def test_live_session_resumption_go_away():
686730
"""Test that go_away triggers reconnection."""

0 commit comments

Comments
 (0)