diff --git a/src/google/adk/workflow/utils/_rehydration_utils.py b/src/google/adk/workflow/utils/_rehydration_utils.py index 3cfacf3e51..95e303076d 100644 --- a/src/google/adk/workflow/utils/_rehydration_utils.py +++ b/src/google/adk/workflow/utils/_rehydration_utils.py @@ -25,9 +25,9 @@ from pydantic import TypeAdapter from pydantic import ValidationError +from ...events._node_path_builder import _NodePathBuilder from ...events.event import Event from ...events.request_input import RequestInput -from ...events._node_path_builder import _NodePathBuilder from ._workflow_hitl_utils import REQUEST_INPUT_FUNCTION_CALL_NAME _RESULT_KEY = 'result' @@ -168,6 +168,25 @@ def _validate_resume_response(response_data: Any, schema: Any) -> Any: raise ValueError(f'Validation failed against schema: {e}') from e +def _process_content_object(event: Event) -> Any: + """Extracts output from event.content.""" + if not event.content or not event.content.parts: + return None + + text = ''.join( + p.text for p in event.content.parts if p.text and not p.thought + ) + text = text.strip() + + if not text: + return None + + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + return text + + def _reconstruct_node_states( events: list[Event], base_path: str, @@ -189,8 +208,9 @@ def get_owner_key(event_path_builder: _NodePathBuilder) -> str | None: segment: str = child_path._segments[-1] return segment else: - if event_path_builder == base_path_builder or event_path_builder.is_descendant_of( - base_path_builder + if ( + event_path_builder == base_path_builder + or event_path_builder.is_descendant_of(base_path_builder) ): return base_path return None @@ -266,7 +286,7 @@ def get_owner_key(event_path_builder: _NodePathBuilder) -> str | None: child.output = event.output child.branch = event.branch elif use_message_as_output: - child.output = event.content + child.output = _process_content_object(event) if event.actions and event.actions.route is not None: child.route = event.actions.route diff --git a/tests/integration/utils/test_runner.py b/tests/integration/utils/test_runner.py index 5322816137..37d2a978bc 100644 --- a/tests/integration/utils/test_runner.py +++ b/tests/integration/utils/test_runner.py @@ -29,6 +29,7 @@ class TestRunner: """Agents runner for testing.""" + __test__ = False app_name = "test_app" user_id = "test_user" diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py index 054e06d542..f61f7a1a37 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py @@ -29,6 +29,7 @@ class TestBaseLlmFlow(BaseLlmFlow): """Test implementation of BaseLlmFlow for testing purposes.""" + __test__ = False pass diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index c220e0228d..9ee05db5d7 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -189,6 +189,8 @@ class TestInMemoryRunner(AfInMemoryRunner): app_name is hardcoded as InMemoryRunner in the parent class. """ + __test__ = False + async def run_async_with_new_session( self, new_message: types.ContentUnion ) -> list[Event]: diff --git a/tests/unittests/workflow/testing_utils.py b/tests/unittests/workflow/testing_utils.py index f6de68d895..9d1fcb80c7 100644 --- a/tests/unittests/workflow/testing_utils.py +++ b/tests/unittests/workflow/testing_utils.py @@ -216,6 +216,8 @@ class TestInMemoryRunner(AfInMemoryRunner): app_name is hardcoded as InMemoryRunner in the parent class. """ + __test__ = False + async def run_async_with_new_session( self, new_message: types.ContentUnion ) -> list[Event]: diff --git a/tests/unittests/workflow/utils/test_rehydration_utils.py b/tests/unittests/workflow/utils/test_rehydration_utils.py index afaadc9a2b..ccf05c98cd 100644 --- a/tests/unittests/workflow/utils/test_rehydration_utils.py +++ b/tests/unittests/workflow/utils/test_rehydration_utils.py @@ -21,6 +21,7 @@ from google.adk.events.event import NodeInfo from google.adk.events.request_input import RequestInput from google.adk.workflow.utils._rehydration_utils import _ChildScanState +from google.adk.workflow.utils._rehydration_utils import _process_content_object from google.adk.workflow.utils._rehydration_utils import _reconstruct_node_states from google.adk.workflow.utils._rehydration_utils import _unwrap_response from google.adk.workflow.utils._rehydration_utils import _validate_resume_response @@ -157,6 +158,48 @@ class User(BaseModel): ) +# --- _process_content_object --- + + +class TestProcessContentObject: + + def test_extracts_plain_text(self): + content = types.Content(parts=[types.Part(text="hello world")]) + event = Event(content=content, invocation_id="id") + assert _process_content_object(event) == "hello world" + + def test_parses_json_text(self): + content = types.Content(parts=[types.Part(text='{"foo": "bar"}')]) + event = Event(content=content, invocation_id="id") + assert _process_content_object(event) == {"foo": "bar"} + + def test_joins_multiple_parts(self): + content = types.Content( + parts=[types.Part(text="hello "), types.Part(text="world")] + ) + event = Event(content=content, invocation_id="id") + assert _process_content_object(event) == "hello world" + + def test_filters_thought_parts(self): + content = types.Content( + parts=[ + types.Part(text="thinking...", thought=True), + types.Part(text='{"answer": 42}'), + ] + ) + event = Event(content=content, invocation_id="id") + assert _process_content_object(event) == {"answer": 42} + + def test_returns_none_for_no_content(self): + event = Event(invocation_id="id") + assert _process_content_object(event) is None + + def test_returns_none_for_empty_text(self): + content = types.Content(parts=[types.Part(text=" ")]) + event = Event(content=content, invocation_id="id") + assert _process_content_object(event) is None + + # --- _reconstruct_node_states --- @@ -188,7 +231,7 @@ def test_scan_message_as_output(self): results = _reconstruct_node_states([event], "/wf@1", invocation_id="test_id", group_by_direct_child=True) assert "node_a@1" in results - assert results["node_a@1"].output == content + assert results["node_a@1"].output == "hello" def test_scan_descendant_interrupts(self): event = Event( diff --git a/tests/unittests/workflow/workflow_testing_utils.py b/tests/unittests/workflow/workflow_testing_utils.py index 4272faf6fe..9b7ea14071 100644 --- a/tests/unittests/workflow/workflow_testing_utils.py +++ b/tests/unittests/workflow/workflow_testing_utils.py @@ -64,6 +64,7 @@ async def run_workflow(wf, message='start'): # The route can be set without the output. This means didn't produce any output # but wants to signal a route to take. class TestingNode(BaseNode): + __test__ = False model_config = ConfigDict(arbitrary_types_allowed=True) output: Optional[Any] = None @@ -97,6 +98,7 @@ async def _run_impl( class TestingNodeWithIntermediateContent(BaseNode): + __test__ = False model_config = ConfigDict(arbitrary_types_allowed=True) intermediate_content: list[types.Content] = Field(default_factory=list)