From 561b9685e77d56fc30381cf111ff1d5bf7662e0f Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 24 Jun 2026 15:18:04 -0700 Subject: [PATCH] fix: exit connection cleanly on expected GoAway signal in bidi streaming Receive the GoAway signal from the Gemini Live API, set a flag on the InvocationContext indicating reconnection is requested, and exit the receive generator cleanly instead of raising a ConnectionClosed exception. This avoids throwing expected session-recycling exceptions into custom client wrappers, which helps prevent false alarms in custom client log monitors. Co-authored-by: Kathy Wu PiperOrigin-RevId: 937586604 Change-Id: Ic8d85ecadfd50647c9349dcb4c1a4d53518a7621 --- .../adk/flows/llm_flows/base_llm_flow.py | 18 +++++++++++++++--- .../flows/llm_flows/test_base_llm_flow.py | 13 +++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 39a66244ca5..9c0731e6a2b 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -57,6 +57,11 @@ # Prefix used by toolset auth credential IDs TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' + +class _ReconnectSentinel(Event): + """Internal sentinel event to signal a silent reconnection request.""" + + if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent from ...models.base_llm import BaseLlm @@ -577,6 +582,7 @@ async def run_live( self._send_to_model(llm_connection, invocation_context) ) + should_reconnect = False try: async with Aclosing( self._receive_from_model( @@ -587,6 +593,9 @@ async def run_live( ) ) as agen: async for event in agen: + if isinstance(event, _ReconnectSentinel): + should_reconnect = True + break # Empty event means the queue is closed. if not event: break @@ -667,6 +676,9 @@ async def run_live( await send_task except asyncio.CancelledError: pass + if should_reconnect: + continue + break except (ConnectionClosed, ConnectionClosedOK) as e: # If we have a session resumption handle, we attempt to reconnect. # This handle is updated dynamically during the session. @@ -805,9 +817,9 @@ def get_author_for_event(llm_response: LlmResponse) -> str: if llm_response.go_away: logger.info(f'Received go away signal: {llm_response.go_away}') # The server signals that it will close the connection soon. - # We proactively raise ConnectionClosed to trigger the reconnection - # logic in run_live, which will use the latest session handle. - raise ConnectionClosed(None, None) + # We yield a sentinel event to request reconnection internally. + yield _ReconnectSentinel(author='system') + return model_response_event = Event( id=Event.new_id(), diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index b5c3f1a612a..2b8eb92d3a6 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -22,6 +22,7 @@ from google.adk.agents.run_config import RunConfig from google.adk.events.event import Event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback +from google.adk.flows.llm_flows.base_llm_flow import _ReconnectSentinel from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini from google.adk.models.google_llm import GoogleLLMVariant @@ -728,15 +729,23 @@ async def mock_receive_2(): ) as mock_connect: mock_connect.return_value.__aenter__ = mock_aenter + yielded_events = [] try: - async for _ in flow.run_live(invocation_context): - pass + async for event in flow.run_live(invocation_context): + yielded_events.append(event) except StopError: pass # Verify that we attempted to connect twice (initial + reconnect after go_away). assert mock_connect.call_count == 2 + # Verify that the internal _ReconnectSentinel is not leaked/yielded to the caller. + assert not any(isinstance(e, _ReconnectSentinel) for e in yielded_events) + + # Verify we yielded the expected response after reconnection. + assert len(yielded_events) == 1 + assert yielded_events[0].content.parts[0].text == 'hi' + @pytest.mark.asyncio async def test_run_live_no_reconnect_without_handle():