Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions src/google/adk/workflow/utils/_rehydration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from pydantic import TypeAdapter
from pydantic import ValidationError

from ...events._node_path_builder import _NodePathBuilder
from ...events.event import Event
from ...events.request_input import RequestInput
from ...events._node_path_builder import _NodePathBuilder
from ._workflow_hitl_utils import REQUEST_INPUT_FUNCTION_CALL_NAME

_RESULT_KEY = 'result'
Expand Down Expand Up @@ -168,6 +168,25 @@ def _validate_resume_response(response_data: Any, schema: Any) -> Any:
raise ValueError(f'Validation failed against schema: {e}') from e


def _process_content_object(event: Event) -> Any:
"""Extracts output from event.content."""
if not event.content or not event.content.parts:
return None

text = ''.join(
p.text for p in event.content.parts if p.text and not p.thought
)
text = text.strip()

if not text:
return None

try:
return json.loads(text)
except (json.JSONDecodeError, ValueError):
return text


def _reconstruct_node_states(
events: list[Event],
base_path: str,
Expand All @@ -189,8 +208,9 @@ def get_owner_key(event_path_builder: _NodePathBuilder) -> str | None:
segment: str = child_path._segments[-1]
return segment
else:
if event_path_builder == base_path_builder or event_path_builder.is_descendant_of(
base_path_builder
if (
event_path_builder == base_path_builder
or event_path_builder.is_descendant_of(base_path_builder)
):
return base_path
return None
Expand Down Expand Up @@ -266,7 +286,7 @@ def get_owner_key(event_path_builder: _NodePathBuilder) -> str | None:
child.output = event.output
child.branch = event.branch
elif use_message_as_output:
child.output = event.content
child.output = _process_content_object(event)
if event.actions and event.actions.route is not None:
child.route = event.actions.route

Expand Down
1 change: 1 addition & 0 deletions tests/integration/utils/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class TestRunner:
"""Agents runner for testing."""

__test__ = False
app_name = "test_app"
user_id = "test_user"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class TestBaseLlmFlow(BaseLlmFlow):
"""Test implementation of BaseLlmFlow for testing purposes."""

__test__ = False
pass


Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ class TestInMemoryRunner(AfInMemoryRunner):
app_name is hardcoded as InMemoryRunner in the parent class.
"""

__test__ = False

async def run_async_with_new_session(
self, new_message: types.ContentUnion
) -> list[Event]:
Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/workflow/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ class TestInMemoryRunner(AfInMemoryRunner):
app_name is hardcoded as InMemoryRunner in the parent class.
"""

__test__ = False

async def run_async_with_new_session(
self, new_message: types.ContentUnion
) -> list[Event]:
Expand Down
45 changes: 44 additions & 1 deletion tests/unittests/workflow/utils/test_rehydration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.adk.events.event import NodeInfo
from google.adk.events.request_input import RequestInput
from google.adk.workflow.utils._rehydration_utils import _ChildScanState
from google.adk.workflow.utils._rehydration_utils import _process_content_object
from google.adk.workflow.utils._rehydration_utils import _reconstruct_node_states
from google.adk.workflow.utils._rehydration_utils import _unwrap_response
from google.adk.workflow.utils._rehydration_utils import _validate_resume_response
Expand Down Expand Up @@ -157,6 +158,48 @@ class User(BaseModel):
)


# --- _process_content_object ---


class TestProcessContentObject:

def test_extracts_plain_text(self):
content = types.Content(parts=[types.Part(text="hello world")])
event = Event(content=content, invocation_id="id")
assert _process_content_object(event) == "hello world"

def test_parses_json_text(self):
content = types.Content(parts=[types.Part(text='{"foo": "bar"}')])
event = Event(content=content, invocation_id="id")
assert _process_content_object(event) == {"foo": "bar"}

def test_joins_multiple_parts(self):
content = types.Content(
parts=[types.Part(text="hello "), types.Part(text="world")]
)
event = Event(content=content, invocation_id="id")
assert _process_content_object(event) == "hello world"

def test_filters_thought_parts(self):
content = types.Content(
parts=[
types.Part(text="thinking...", thought=True),
types.Part(text='{"answer": 42}'),
]
)
event = Event(content=content, invocation_id="id")
assert _process_content_object(event) == {"answer": 42}

def test_returns_none_for_no_content(self):
event = Event(invocation_id="id")
assert _process_content_object(event) is None

def test_returns_none_for_empty_text(self):
content = types.Content(parts=[types.Part(text=" ")])
event = Event(content=content, invocation_id="id")
assert _process_content_object(event) is None


# --- _reconstruct_node_states ---


Expand Down Expand Up @@ -188,7 +231,7 @@ def test_scan_message_as_output(self):
results = _reconstruct_node_states([event], "/wf@1", invocation_id="test_id", group_by_direct_child=True)

assert "node_a@1" in results
assert results["node_a@1"].output == content
assert results["node_a@1"].output == "hello"

def test_scan_descendant_interrupts(self):
event = Event(
Expand Down
2 changes: 2 additions & 0 deletions tests/unittests/workflow/workflow_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def run_workflow(wf, message='start'):
# The route can be set without the output. This means didn't produce any output
# but wants to signal a route to take.
class TestingNode(BaseNode):
__test__ = False
model_config = ConfigDict(arbitrary_types_allowed=True)

output: Optional[Any] = None
Expand Down Expand Up @@ -97,6 +98,7 @@ async def _run_impl(


class TestingNodeWithIntermediateContent(BaseNode):
__test__ = False
model_config = ConfigDict(arbitrary_types_allowed=True)

intermediate_content: list[types.Content] = Field(default_factory=list)
Expand Down