Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
# Prefix used by toolset auth credential IDs
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'


class _ReconnectSentinel(Event):
"""Internal sentinel event to signal a silent reconnection request."""


if TYPE_CHECKING:
from ...agents.llm_agent import LlmAgent
from ...models.base_llm import BaseLlm
Expand Down Expand Up @@ -577,6 +582,7 @@ async def run_live(
self._send_to_model(llm_connection, invocation_context)
)

should_reconnect = False
try:
async with Aclosing(
self._receive_from_model(
Expand All @@ -587,6 +593,9 @@ async def run_live(
)
) as agen:
async for event in agen:
if isinstance(event, _ReconnectSentinel):
should_reconnect = True
break
# Empty event means the queue is closed.
if not event:
break
Expand Down Expand Up @@ -667,6 +676,9 @@ async def run_live(
await send_task
except asyncio.CancelledError:
pass
if should_reconnect:
continue
break
except (ConnectionClosed, ConnectionClosedOK) as e:
# If we have a session resumption handle, we attempt to reconnect.
# This handle is updated dynamically during the session.
Expand Down Expand Up @@ -805,9 +817,9 @@ def get_author_for_event(llm_response: LlmResponse) -> str:
if llm_response.go_away:
logger.info(f'Received go away signal: {llm_response.go_away}')
# The server signals that it will close the connection soon.
# We proactively raise ConnectionClosed to trigger the reconnection
# logic in run_live, which will use the latest session handle.
raise ConnectionClosed(None, None)
# We yield a sentinel event to request reconnection internally.
yield _ReconnectSentinel(author='system')
return

model_response_event = Event(
id=Event.new_id(),
Expand Down
13 changes: 11 additions & 2 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.adk.agents.run_config import RunConfig
from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
from google.adk.flows.llm_flows.base_llm_flow import _ReconnectSentinel
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.models.google_llm import Gemini
from google.adk.models.google_llm import GoogleLLMVariant
Expand Down Expand Up @@ -728,15 +729,23 @@ async def mock_receive_2():
) as mock_connect:
mock_connect.return_value.__aenter__ = mock_aenter

yielded_events = []
try:
async for _ in flow.run_live(invocation_context):
pass
async for event in flow.run_live(invocation_context):
yielded_events.append(event)
except StopError:
pass

# Verify that we attempted to connect twice (initial + reconnect after go_away).
assert mock_connect.call_count == 2

# Verify that the internal _ReconnectSentinel is not leaked/yielded to the caller.
assert not any(isinstance(e, _ReconnectSentinel) for e in yielded_events)

# Verify we yielded the expected response after reconnection.
assert len(yielded_events) == 1
assert yielded_events[0].content.parts[0].text == 'hi'


@pytest.mark.asyncio
async def test_run_live_no_reconnect_without_handle():
Expand Down
Loading