Skip to content

Commit fb92aad

Browse files
google-genai-botcopybara-github
authored andcommitted
fix(simulation): Add error message when LlmBackedUserSimulator returns empty response
PiperOrigin-RevId: 911999966
1 parent 3117e09 commit fb92aad

2 files changed

Lines changed: 123 additions & 21 deletions

File tree

src/google/adk/evaluation/simulation/llm_backed_user_simulator.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import logging
1818
from typing import ClassVar
19-
from typing import Optional
2019

2120
from google.genai import types as genai_types
2221
from pydantic import Field
@@ -72,7 +71,7 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
7271
(Not recommended) If you don't want a limit, you can set the value to -1.""",
7372
)
7473

75-
custom_instructions: Optional[str] = Field(
74+
custom_instructions: str | None = Field(
7675
default=None,
7776
description="""Custom instructions for the LlmBackedUserSimulator. The
7877
instructions must contain the following formatting placeholders following Jinja syntax:
@@ -88,7 +87,7 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
8887

8988
@field_validator("custom_instructions")
9089
@classmethod
91-
def validate_custom_instructions(cls, value: Optional[str]) -> Optional[str]:
90+
def validate_custom_instructions(cls, value: str | None) -> str | None:
9291
if value is None:
9392
return value
9493
if not is_valid_user_simulator_template(
@@ -158,11 +157,11 @@ def _summarize_conversation(
158157
async def _get_llm_response(
159158
self,
160159
rewritten_dialogue: str,
161-
) -> str:
162-
"""Sends a user message generation request to the LLM and returns the full response."""
160+
) -> tuple[str, str | None]:
161+
"""Sends a user message generation request to the LLM and returns the full response and potential error reason."""
163162
if self._invocation_count == 0:
164163
# first invocation - send the static starting prompt
165-
return self._conversation_scenario.starting_prompt
164+
return self._conversation_scenario.starting_prompt, None
166165

167166
user_agent_instructions = get_llm_backed_user_simulator_prompt(
168167
conversation_plan=self._conversation_scenario.conversation_plan,
@@ -187,19 +186,44 @@ async def _get_llm_response(
187186
add_default_retry_options_if_not_present(llm_request)
188187

189188
response = ""
189+
error_reason = None
190+
has_thought_tokens = False
190191
async with Aclosing(self._llm.generate_content_async(llm_request)) as agen:
191192
async for llm_response in agen:
193+
error_code = llm_response.error_code
194+
if error_code:
195+
logger.warning(
196+
"User simulator LLM returned error: code=%s, message=%s",
197+
error_code,
198+
getattr(llm_response, "error_message", ""),
199+
)
200+
error_reason = f"safety filters or other error (code={error_code})"
201+
response = ""
202+
break
203+
192204
generated_content: genai_types.Content = llm_response.content
193205
if (
194206
not generated_content
195207
or not hasattr(generated_content, "parts")
196208
or not generated_content.parts
197209
):
198210
continue
211+
199212
for part in generated_content.parts:
200-
if part.text and not part.thought:
213+
if part.thought:
214+
has_thought_tokens = True
215+
elif part.text:
201216
response += part.text
202-
return response
217+
218+
if not response:
219+
if error_reason:
220+
pass # Keep the error reason from error_code
221+
elif has_thought_tokens:
222+
error_reason = "LLM returned only thinking tokens"
223+
else:
224+
error_reason = "LLM returned empty response"
225+
226+
return response, error_reason
203227

204228
@override
205229
async def get_next_user_message(
@@ -234,11 +258,11 @@ async def get_next_user_message(
234258
rewritten_dialogue = self._summarize_conversation(events)
235259

236260
# query the LLM for the next user message
237-
response = await self._get_llm_response(rewritten_dialogue)
261+
response, error_reason = await self._get_llm_response(rewritten_dialogue)
238262
self._invocation_count += 1
239263

240264
# is the conversation over? (Has the user simulator output the stop signal?)
241-
if _STOP_SIGNAL.lower() in response.lower():
265+
if response and _STOP_SIGNAL.lower() in response.lower():
242266
logger.info(
243267
"Stopping user message generation as the stop signal was detected."
244268
)
@@ -256,11 +280,11 @@ async def get_next_user_message(
256280

257281
# if we are here, the user agent failed to generate a message, which is not
258282
# a valid result for the LLM backed user simulator.
259-
raise RuntimeError("Failed to generate a user message")
283+
raise RuntimeError(f"Failed to generate a user message: {error_reason}")
260284

261285
@override
262286
def get_simulation_evaluator(
263287
self,
264-
) -> Optional[Evaluator]:
288+
) -> Evaluator | None:
265289
"""Returns an Evaluator that evaluates if the simulation was successful or not."""
266290
raise NotImplementedError()

tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ async def to_async_iter(items):
129129
def mock_llm_agent(mocker):
130130
"""Provides a mock LLM agent."""
131131
mock_llm_registry_cls = mocker.patch(
132-
"google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry"
132+
"google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry",
133+
autospec=True,
133134
)
134135
mock_llm_registry = mocker.MagicMock()
135136
mock_llm_registry_cls.return_value = mock_llm_registry
@@ -207,18 +208,25 @@ async def test_get_llm_response_return_value(
207208
self, simulator, mock_llm_agent, mocker
208209
):
209210
"""Tests that _get_llm_response returns the full response correctly."""
210-
mock_llm_response = mocker.MagicMock()
211+
mock_llm_response = mocker.create_autospec(
212+
types.GenerateContentResponse, instance=True
213+
)
214+
mock_llm_response.error_code = None
211215
mock_llm_response.content = types.Content(
212216
parts=[
213217
types.Part(text="some thought", thought=True),
214218
types.Part(text="Hello world!"),
215219
]
216220
)
221+
mock_llm_response.parts = mock_llm_response.content.parts
217222
mock_llm_agent.generate_content_async.return_value = to_async_iter(
218223
[mock_llm_response]
219224
)
220-
response = await simulator._get_llm_response(rewritten_dialogue="")
225+
response, error_reason = await simulator._get_llm_response(
226+
rewritten_dialogue=""
227+
)
221228
assert response == "Hello world!"
229+
assert error_reason is None
222230

223231
@pytest.mark.asyncio
224232
async def test_get_next_user_message_first_invocation(
@@ -257,10 +265,14 @@ async def test_turn_limit_reached(self, conversation_scenario):
257265
@pytest.mark.asyncio
258266
async def test_stop_signal_detected(self, simulator, mock_llm_agent, mocker):
259267
"""Tests get_next_user_message when the stop signal is detected."""
260-
mock_llm_response = mocker.MagicMock()
268+
mock_llm_response = mocker.create_autospec(
269+
types.GenerateContentResponse, instance=True
270+
)
271+
mock_llm_response.error_code = None
261272
mock_llm_response.content = types.Content(
262273
parts=[types.Part(text="Thanks! Bye!</finished>")]
263274
)
275+
mock_llm_response.parts = mock_llm_response.content.parts
264276
mock_llm_agent.generate_content_async.return_value = to_async_iter(
265277
[mock_llm_response]
266278
)
@@ -273,22 +285,84 @@ async def test_stop_signal_detected(self, simulator, mock_llm_agent, mocker):
273285
assert next_user_message.user_message is None
274286

275287
@pytest.mark.asyncio
276-
async def test_no_message_generated(self, simulator, mock_llm_agent):
277-
"""Tests get_next_user_message when no message is generated."""
288+
async def test_no_message_generated_empty_response(
289+
self, simulator, mock_llm_agent
290+
):
291+
"""Tests get_next_user_message when no message is generated (empty stream)."""
278292
mock_llm_agent.generate_content_async.return_value = to_async_iter([])
279293

280-
with pytest.raises(RuntimeError, match="Failed to generate a user message"):
294+
with pytest.raises(
295+
RuntimeError,
296+
match="Failed to generate a user message: LLM returned empty response",
297+
):
298+
await simulator.get_next_user_message(events=_INPUT_EVENTS)
299+
300+
@pytest.mark.asyncio
301+
async def test_get_next_user_message_safety_blocked(
302+
self, simulator, mock_llm_agent, mocker
303+
):
304+
"""Tests get_next_user_message when response is safety blocked."""
305+
mock_llm_response = mocker.create_autospec(
306+
types.GenerateContentResponse, instance=True
307+
)
308+
mock_llm_response.content = None
309+
mock_llm_response.error_code = "SAFETY"
310+
mock_llm_response.error_message = "Blocked by safety"
311+
mock_llm_response.parts = []
312+
mock_llm_agent.generate_content_async.return_value = to_async_iter(
313+
[mock_llm_response]
314+
)
315+
316+
with pytest.raises(
317+
RuntimeError,
318+
match=(
319+
"Failed to generate a user message: safety filters or other error"
320+
" \\(code=SAFETY\\)"
321+
),
322+
):
323+
await simulator.get_next_user_message(events=_INPUT_EVENTS)
324+
325+
@pytest.mark.asyncio
326+
async def test_get_next_user_message_thinking_only(
327+
self, simulator, mock_llm_agent, mocker
328+
):
329+
"""Tests get_next_user_message when response contains only thinking tokens."""
330+
mock_llm_response = mocker.create_autospec(
331+
types.GenerateContentResponse, instance=True
332+
)
333+
mock_llm_response.content = types.Content(
334+
parts=[
335+
types.Part(text="thinking...", thought=True),
336+
]
337+
)
338+
mock_llm_response.error_code = None
339+
mock_llm_response.parts = mock_llm_response.content.parts
340+
mock_llm_agent.generate_content_async.return_value = to_async_iter(
341+
[mock_llm_response]
342+
)
343+
344+
with pytest.raises(
345+
RuntimeError,
346+
match=(
347+
"Failed to generate a user message: LLM returned only thinking"
348+
" tokens"
349+
),
350+
):
281351
await simulator.get_next_user_message(events=_INPUT_EVENTS)
282352

283353
@pytest.mark.asyncio
284354
async def test_get_next_user_message_success(
285355
self, simulator, mock_llm_agent, mocker
286356
):
287357
"""Tests get_next_user_message when the user message is generated successfully."""
288-
mock_llm_response = mocker.MagicMock()
358+
mock_llm_response = mocker.create_autospec(
359+
types.GenerateContentResponse, instance=True
360+
)
361+
mock_llm_response.error_code = None
289362
mock_llm_response.content = types.Content(
290363
parts=[types.Part(text="I need to book a flight.")]
291364
)
365+
mock_llm_response.parts = mock_llm_response.content.parts
292366
mock_llm_agent.generate_content_async.return_value = to_async_iter(
293367
[mock_llm_response]
294368
)
@@ -309,10 +383,14 @@ async def test_get_next_user_message_with_persona_success(
309383
self, simulator_with_persona, mock_llm_agent, mocker
310384
):
311385
"""Tests get_next_user_message when the user message is generated successfully."""
312-
mock_llm_response = mocker.MagicMock()
386+
mock_llm_response = mocker.create_autospec(
387+
types.GenerateContentResponse, instance=True
388+
)
389+
mock_llm_response.error_code = None
313390
mock_llm_response.content = types.Content(
314391
parts=[types.Part(text="I need to book a flight.")]
315392
)
393+
mock_llm_response.parts = mock_llm_response.content.parts
316394
mock_llm_agent.generate_content_async.return_value = to_async_iter(
317395
[mock_llm_response]
318396
)

0 commit comments

Comments
 (0)