Skip to content

Commit 4099293

Browse files
authored
Merge branch 'main' into fix/inject-session-state-key-error
2 parents 89d9c5f + baf7efb commit 4099293

14 files changed

Lines changed: 738 additions & 939 deletions

File tree

src/google/adk/a2a/converters/event_converter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,10 @@ def convert_event_to_a2a_events(
570570

571571
# Handle regular message content
572572
message = convert_event_to_a2a_message(
573-
event, invocation_context, part_converter=part_converter
573+
event,
574+
invocation_context,
575+
part_converter=part_converter,
576+
role=Role.user if event.author == "user" else Role.agent,
574577
)
575578
if message:
576579
running_event = _create_status_update_event(

src/google/adk/evaluation/simulation/llm_backed_user_simulator.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
8585
""",
8686
)
8787

88+
include_function_calls: bool = Field(
89+
default=False,
90+
description="""Whether to include function calls and responses in the
91+
conversation history prompt provided to the user simulator.""",
92+
)
93+
8894
@field_validator("custom_instructions")
8995
@classmethod
9096
def validate_custom_instructions(cls, value: str | None) -> str | None:
@@ -132,13 +138,15 @@ def __init__(
132138
def _summarize_conversation(
133139
cls,
134140
events: list[Event],
141+
include_function_calls: bool = False,
135142
) -> str:
136143
"""Summarize the conversation to add to the prompt.
137144
138-
Removes tool calls, responses, and thoughts.
145+
Removes responses, thoughts, optionally tool calls and tool responses.
139146
140147
Args:
141148
events: The conversation history to rewrite.
149+
include_function_calls: Whether to include function calls and responses.
142150
143151
Returns:
144152
The summarized conversation history as a string.
@@ -151,6 +159,16 @@ def _summarize_conversation(
151159
for part in e.content.parts:
152160
if part.text and not part.thought:
153161
rewritten_dialogue.append(f"{author}: {part.text}")
162+
elif include_function_calls and part.function_call:
163+
rewritten_dialogue.append(
164+
f"{author} called tool '{part.function_call.name}' with args:"
165+
f" {part.function_call.args}"
166+
)
167+
elif include_function_calls and part.function_response:
168+
rewritten_dialogue.append(
169+
f"Tool '{part.function_response.name}' returned:"
170+
f" {part.function_response.response}"
171+
)
154172

155173
return "\n\n".join(rewritten_dialogue)
156174

@@ -255,7 +273,9 @@ async def get_next_user_message(
255273
return NextUserMessage(status=Status.TURN_LIMIT_REACHED)
256274

257275
# rewrite events for the user simulator
258-
rewritten_dialogue = self._summarize_conversation(events)
276+
rewritten_dialogue = self._summarize_conversation(
277+
events, self._config.include_function_calls
278+
)
259279

260280
# query the LLM for the next user message
261281
response, error_reason = await self._get_llm_response(rewritten_dialogue)

src/google/adk/models/lite_llm.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -807,9 +807,14 @@ async def _content_to_message_param(
807807
if isinstance(response, str)
808808
else _safe_json_serialize(response)
809809
)
810+
# gemma4 requires role='tool_responses' for recognizing function_response parts as responses
811+
# from the tool call, instead of OpenAI-compatible 'tool' role used by other models.
812+
# Earlier Gemma versions before version 4 do not support tool use,
813+
# so this check is intentionally scoped to only look for "gemma4" in the model name.
814+
tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
810815
tool_messages.append(
811816
ChatCompletionToolMessage(
812-
role="tool",
817+
role=tool_role,
813818
tool_call_id=part.function_response.id,
814819
content=response_content,
815820
)
@@ -824,6 +829,7 @@ async def _content_to_message_param(
824829
follow_up = await _content_to_message_param(
825830
types.Content(role=content.role, parts=non_tool_parts),
826831
provider=provider,
832+
model=model,
827833
)
828834
follow_up_messages = (
829835
follow_up if isinstance(follow_up, list) else [follow_up]
@@ -934,12 +940,16 @@ async def _content_to_message_param(
934940
)
935941

936942

937-
def _ensure_tool_results(messages: List[Message]) -> List[Message]:
943+
def _ensure_tool_results(messages: List[Message], model: str) -> List[Message]:
938944
"""Insert placeholder tool messages for missing tool results.
939945
940946
LiteLLM-backed providers like OpenAI and Anthropic reject histories where an
941947
assistant tool call is not followed by tool responses before the next
942948
non-tool message. This helps recover from interrupted tool execution.
949+
950+
For models that expect a different tool response role (e.g. Gemma4 models,
951+
which require 'tool_responses' instead of 'tool'), the role is adjusted
952+
accordingly.
943953
"""
944954
if not messages:
945955
return messages
@@ -948,17 +958,19 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:
948958

949959
healed_messages: List[Message] = []
950960
pending_tool_call_ids: List[str] = []
961+
expected_tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
951962

952963
for message in messages:
953964
role = message.get("role")
954-
if pending_tool_call_ids and role != "tool":
965+
966+
if pending_tool_call_ids and role != expected_tool_role:
955967
logger.warning(
956968
"Missing tool results for tool_call_id(s): %s",
957969
pending_tool_call_ids,
958970
)
959971
healed_messages.extend(
960972
ChatCompletionToolMessage(
961-
role="tool",
973+
role=expected_tool_role,
962974
tool_call_id=tool_call_id,
963975
content=_MISSING_TOOL_RESULT_MESSAGE,
964976
)
@@ -971,21 +983,22 @@ def _ensure_tool_results(messages: List[Message]) -> List[Message]:
971983
pending_tool_call_ids = [
972984
tool_call.get("id") for tool_call in tool_calls if tool_call.get("id")
973985
]
974-
elif role == "tool":
986+
elif role == expected_tool_role:
975987
tool_call_id = message.get("tool_call_id")
976988
if tool_call_id in pending_tool_call_ids:
977989
pending_tool_call_ids.remove(tool_call_id)
978990

979991
healed_messages.append(message)
980992

993+
# Final block also uses expected_tool_role
981994
if pending_tool_call_ids:
982995
logger.warning(
983996
"Missing tool results for tool_call_id(s): %s",
984997
pending_tool_call_ids,
985998
)
986999
healed_messages.extend(
9871000
ChatCompletionToolMessage(
988-
role="tool",
1001+
role=expected_tool_role,
9891002
tool_call_id=tool_call_id,
9901003
content=_MISSING_TOOL_RESULT_MESSAGE,
9911004
)
@@ -1905,7 +1918,7 @@ async def _get_completion_inputs(
19051918
content=llm_request.config.system_instruction,
19061919
),
19071920
)
1908-
messages = _ensure_tool_results(messages)
1921+
messages = _ensure_tool_results(messages, model)
19091922

19101923
# 2. Convert tool declarations
19111924
tools: Optional[List[Dict]] = None

src/google/adk/telemetry/tracing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,12 @@ def trace_call_llm(
358358
except AttributeError:
359359
pass
360360

361-
try:
362-
llm_response_json = llm_response.model_dump_json(exclude_none=True)
363-
except Exception: # pylint: disable=broad-exception-caught
364-
llm_response_json = '<not serializable>'
365-
366361
if _should_add_request_response_to_spans():
362+
try:
363+
llm_response_json = llm_response.model_dump_json(exclude_none=True)
364+
except Exception: # pylint: disable=broad-exception-caught
365+
llm_response_json = '<not serializable>'
366+
367367
span.set_attribute(
368368
'gcp.vertex.agent.llm_response',
369369
llm_response_json,
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import json
17+
from typing import Any
18+
19+
import requests
20+
21+
22+
def get_stream(
23+
url: str,
24+
ca_payload: dict[str, Any],
25+
headers: dict[str, str],
26+
max_query_result_rows: int,
27+
) -> list[dict[str, Any]]:
28+
"""Sends a JSON request to a streaming API and returns a list of messages."""
29+
with requests.Session() as s:
30+
accumulator = ""
31+
messages = []
32+
data_msg_idx = -1
33+
34+
with s.post(url, json=ca_payload, headers=headers, stream=True) as resp:
35+
resp.raise_for_status()
36+
for line in resp.iter_lines():
37+
if not line:
38+
continue
39+
40+
decoded_line = line.decode("utf-8")
41+
42+
if decoded_line == "[{":
43+
accumulator = "{"
44+
elif decoded_line == "}]":
45+
accumulator += "}"
46+
elif decoded_line == ",":
47+
continue
48+
else:
49+
accumulator += decoded_line
50+
51+
try:
52+
data_json = json.loads(accumulator)
53+
except ValueError:
54+
continue
55+
56+
accumulator = ""
57+
58+
if not isinstance(data_json, dict):
59+
messages.append(data_json)
60+
continue
61+
62+
processed_msg = None
63+
data_result = _extract_data_result(data_json)
64+
if data_result is not None:
65+
processed_msg = _format_data_retrieved(
66+
data_result, max_query_result_rows
67+
)
68+
if data_msg_idx >= 0:
69+
messages[data_msg_idx] = {
70+
"Data Retrieved": "Intermediate result omitted"
71+
}
72+
data_msg_idx = len(messages)
73+
elif isinstance(data_json.get("systemMessage"), dict):
74+
processed_msg = data_json["systemMessage"]
75+
else:
76+
processed_msg = data_json
77+
78+
if processed_msg is not None:
79+
messages.append(processed_msg)
80+
81+
return messages
82+
83+
84+
def _extract_data_result(msg: dict[str, Any]) -> dict[str, Any] | None:
85+
"""Attempts to find the result.data deep inside the generic dict."""
86+
sm = msg.get("systemMessage")
87+
if not isinstance(sm, dict):
88+
return None
89+
data = sm.get("data")
90+
if not isinstance(data, dict):
91+
return None
92+
result = data.get("result")
93+
if not isinstance(result, dict):
94+
return None
95+
if "data" in result and isinstance(result["data"], list):
96+
return result
97+
return None
98+
99+
100+
def _format_data_retrieved(
101+
result: dict[str, Any], max_rows: int
102+
) -> dict[str, Any]:
103+
"""Transforms the raw result dict into the simplified Toolbox format."""
104+
raw_data = result.get("data", [])
105+
106+
fields = []
107+
schema = result.get("schema")
108+
if isinstance(schema, dict):
109+
schema_fields = schema.get("fields")
110+
if isinstance(schema_fields, list):
111+
fields = schema_fields
112+
113+
headers = []
114+
for f in fields:
115+
if isinstance(f, dict):
116+
name = f.get("name")
117+
if isinstance(name, str):
118+
headers.append(name)
119+
120+
if not headers and raw_data:
121+
first_row = raw_data[0]
122+
if isinstance(first_row, dict):
123+
headers = list(first_row.keys())
124+
125+
total_rows = len(raw_data)
126+
num_to_display = min(total_rows, max_rows)
127+
128+
rows = []
129+
for r in raw_data[:num_to_display]:
130+
if isinstance(r, dict):
131+
row = [r.get(h) for h in headers]
132+
rows.append(row)
133+
134+
summary = f"Showing all {total_rows} rows."
135+
if total_rows > max_rows:
136+
summary = f"Showing the first {num_to_display} of {total_rows} total rows."
137+
138+
return {
139+
"Data Retrieved": {
140+
"headers": headers,
141+
"rows": rows,
142+
"summary": summary,
143+
}
144+
}

0 commit comments

Comments
 (0)