Skip to content

Commit 2f2495e

Browse files
moonbox3Copilot
andauthored
Python: Fix function_approval_response extraction in AG-UI workflow path (#4550)
* Extract function_approval_response from workflow messages (#4546) _extract_responses_from_messages now handles function_approval_response content in addition to function_result content. Previously, approval responses sent via the messages field were silently dropped because the function only checked for content.type == "function_result". The approval response is keyed by content.id and includes the approved status, id, and serialized function_call — consistent with how _coerce_content identifies approval response payloads. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Fix #4546: Update docstring and add integration tests for message-based approvals - Update _extract_responses_from_messages docstring to reflect that it now handles function_approval_response content in addition to function_result content. - Add integration tests for run_workflow_stream across two turns with approval responses provided via messages (function_approvals) rather than resume.interrupts, covering both approved and denied scenarios. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address PR review feedback for #4546 - Use safer 'not .get("interrupt")' assertion instead of 'not in' to handle Pydantic v2 model_dump() including keys with None values - Add unit test for mixed function_result and function_approval_response in the same message to TestExtractResponsesFromMessages Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent e5d6e8c commit 2f2495e

2 files changed

Lines changed: 264 additions & 5 deletions

File tree

python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,28 @@ def _request_payload_from_request_event(request_event: Any) -> dict[str, Any] |
124124

125125

126126
def _extract_responses_from_messages(messages: list[Message]) -> dict[str, Any]:
127-
"""Extract request-info responses from incoming tool/function-result messages."""
127+
"""Extract request-info responses from incoming messages.
128+
129+
Handles both ``function_result`` content (keyed by ``call_id``) and
130+
``function_approval_response`` content (keyed by ``id``), so that
131+
approval decisions sent via messages are forwarded into the workflow
132+
responses map.
133+
"""
128134
responses: dict[str, Any] = {}
129135
for message in messages:
130136
for content in message.contents:
131-
if content.type != "function_result" or not content.call_id:
132-
continue
133-
value = _coerce_json_value(content.result)
134-
responses[str(content.call_id)] = value
137+
if content.type == "function_result" and content.call_id:
138+
value = _coerce_json_value(content.result)
139+
responses[str(content.call_id)] = value
140+
elif content.type == "function_approval_response" and getattr(content, "id", None):
141+
approval_value: dict[str, Any] = {
142+
"approved": getattr(content, "approved", False),
143+
"id": str(content.id), # type: ignore[union-attr]
144+
}
145+
func_call = getattr(content, "function_call", None)
146+
if func_call is not None:
147+
approval_value["function_call"] = make_json_safe(func_call.to_dict())
148+
responses[str(content.id)] = approval_value # type: ignore[union-attr]
135149
return responses
136150

137151

python/packages/ag-ui/tests/ag_ui/test_workflow_run.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_custom_event_value,
3434
_details_code,
3535
_details_message,
36+
_extract_responses_from_messages,
3637
_interrupt_entry_for_request_event,
3738
_latest_assistant_contents,
3839
_latest_user_text,
@@ -1172,9 +1173,253 @@ def test_details_without_error_type(self):
11721173
assert _details_code(details) is None
11731174

11741175

1176+
class TestExtractResponsesFromMessages:
1177+
"""Tests for _extract_responses_from_messages helper."""
1178+
1179+
def test_function_result_extracted(self):
1180+
"""function_result content is extracted keyed by call_id."""
1181+
result = Content.from_function_result(call_id="call-1", result="ok")
1182+
messages = [Message(role="tool", contents=[result])]
1183+
responses = _extract_responses_from_messages(messages)
1184+
assert responses == {"call-1": "ok"}
1185+
1186+
def test_function_result_without_call_id_skipped(self):
1187+
"""function_result with no call_id is ignored."""
1188+
result = Content.from_function_result(call_id="", result="ok")
1189+
messages = [Message(role="tool", contents=[result])]
1190+
responses = _extract_responses_from_messages(messages)
1191+
assert responses == {}
1192+
1193+
def test_function_approval_response_extracted(self):
1194+
"""function_approval_response content is extracted keyed by id."""
1195+
func_call = Content.from_function_call(
1196+
call_id="call-1",
1197+
name="do_action",
1198+
arguments={"x": 1},
1199+
)
1200+
approval = Content.from_function_approval_response(
1201+
approved=True,
1202+
id="approval-1",
1203+
function_call=func_call,
1204+
)
1205+
messages = [Message(role="user", contents=[approval])]
1206+
responses = _extract_responses_from_messages(messages)
1207+
assert "approval-1" in responses
1208+
assert responses["approval-1"]["approved"] is True
1209+
assert responses["approval-1"]["id"] == "approval-1"
1210+
assert "function_call" in responses["approval-1"]
1211+
1212+
def test_denied_approval_response_extracted(self):
1213+
"""Denied function_approval_response is extracted with approved=False."""
1214+
func_call = Content.from_function_call(
1215+
call_id="call-2",
1216+
name="delete_item",
1217+
arguments={},
1218+
)
1219+
approval = Content.from_function_approval_response(
1220+
approved=False,
1221+
id="approval-2",
1222+
function_call=func_call,
1223+
)
1224+
messages = [Message(role="user", contents=[approval])]
1225+
responses = _extract_responses_from_messages(messages)
1226+
assert "approval-2" in responses
1227+
assert responses["approval-2"]["approved"] is False
1228+
1229+
def test_mixed_result_and_approval(self):
1230+
"""Both function_result and function_approval_response are extracted."""
1231+
result = Content.from_function_result(call_id="call-1", result="done")
1232+
func_call = Content.from_function_call(
1233+
call_id="call-2",
1234+
name="submit",
1235+
arguments={},
1236+
)
1237+
approval = Content.from_function_approval_response(
1238+
approved=True,
1239+
id="approval-1",
1240+
function_call=func_call,
1241+
)
1242+
messages = [
1243+
Message(role="tool", contents=[result]),
1244+
Message(role="user", contents=[approval]),
1245+
]
1246+
responses = _extract_responses_from_messages(messages)
1247+
assert "call-1" in responses
1248+
assert responses["call-1"] == "done"
1249+
assert "approval-1" in responses
1250+
assert responses["approval-1"]["approved"] is True
1251+
1252+
def test_mixed_result_and_approval_same_message(self):
1253+
"""Both function_result and function_approval_response in the same message are extracted."""
1254+
result = Content.from_function_result(call_id="call-1", result="done")
1255+
func_call = Content.from_function_call(
1256+
call_id="call-2",
1257+
name="submit",
1258+
arguments={},
1259+
)
1260+
approval = Content.from_function_approval_response(
1261+
approved=True,
1262+
id="approval-1",
1263+
function_call=func_call,
1264+
)
1265+
messages = [Message(role="tool", contents=[result, approval])]
1266+
responses = _extract_responses_from_messages(messages)
1267+
assert "call-1" in responses
1268+
assert responses["call-1"] == "done"
1269+
assert "approval-1" in responses
1270+
assert responses["approval-1"]["approved"] is True
1271+
1272+
def test_text_content_skipped(self):
1273+
"""Non-result, non-approval content is ignored."""
1274+
text = Content.from_text(text="hello")
1275+
messages = [Message(role="user", contents=[text])]
1276+
responses = _extract_responses_from_messages(messages)
1277+
assert responses == {}
1278+
1279+
def test_empty_messages(self):
1280+
"""Empty message list returns empty responses."""
1281+
assert _extract_responses_from_messages([]) == {}
1282+
1283+
11751284
# ── Stream integration tests ──
11761285

11771286

1287+
async def test_workflow_run_approval_via_messages_approved() -> None:
1288+
"""Approval response sent via messages (function_approvals) should satisfy the pending request."""
1289+
1290+
class ApprovalExecutor(Executor):
1291+
def __init__(self) -> None:
1292+
super().__init__(id="approval_executor")
1293+
1294+
@handler
1295+
async def start(self, message: Any, ctx: WorkflowContext) -> None:
1296+
del message
1297+
function_call = Content.from_function_call(
1298+
call_id="refund-call",
1299+
name="submit_refund",
1300+
arguments={"order_id": "12345", "amount": "$89.99"},
1301+
)
1302+
approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call)
1303+
await ctx.request_info(approval_request, Content, request_id="approval-1")
1304+
1305+
@response_handler
1306+
async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
1307+
del original_request
1308+
status = "approved" if bool(response.approved) else "rejected"
1309+
await ctx.yield_output(f"Refund {status}.")
1310+
1311+
workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()
1312+
first_events = [
1313+
event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)
1314+
]
1315+
first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump()
1316+
interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt"))
1317+
assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1
1318+
1319+
# Second turn: send approval via function_approvals on a message (not resume.interrupts)
1320+
resumed_events = [
1321+
event
1322+
async for event in run_workflow_stream(
1323+
{
1324+
"messages": [
1325+
{
1326+
"role": "user",
1327+
"content": "",
1328+
"function_approvals": [
1329+
{
1330+
"approved": True,
1331+
"id": "approval-1",
1332+
"call_id": "refund-call",
1333+
"name": "submit_refund",
1334+
"arguments": {"order_id": "12345", "amount": "$89.99"},
1335+
}
1336+
],
1337+
}
1338+
],
1339+
},
1340+
workflow,
1341+
)
1342+
]
1343+
1344+
resumed_types = [event.type for event in resumed_events]
1345+
assert "RUN_STARTED" in resumed_types
1346+
assert "RUN_FINISHED" in resumed_types
1347+
assert "RUN_ERROR" not in resumed_types
1348+
assert "TEXT_MESSAGE_CONTENT" in resumed_types
1349+
text_deltas = [event.delta for event in resumed_events if event.type == "TEXT_MESSAGE_CONTENT"]
1350+
assert any("approved" in delta for delta in text_deltas)
1351+
resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump()
1352+
assert not resumed_finished.get("interrupt")
1353+
1354+
1355+
async def test_workflow_run_approval_via_messages_denied() -> None:
1356+
"""Denied approval response sent via messages (function_approvals) should satisfy the pending request."""
1357+
1358+
class ApprovalExecutor(Executor):
1359+
def __init__(self) -> None:
1360+
super().__init__(id="approval_executor")
1361+
1362+
@handler
1363+
async def start(self, message: Any, ctx: WorkflowContext) -> None:
1364+
del message
1365+
function_call = Content.from_function_call(
1366+
call_id="delete-call",
1367+
name="delete_record",
1368+
arguments={"record_id": "abc"},
1369+
)
1370+
approval_request = Content.from_function_approval_request(id="deny-1", function_call=function_call)
1371+
await ctx.request_info(approval_request, Content, request_id="deny-1")
1372+
1373+
@response_handler
1374+
async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
1375+
del original_request
1376+
status = "approved" if bool(response.approved) else "rejected"
1377+
await ctx.yield_output(f"Delete {status}.")
1378+
1379+
workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()
1380+
first_events = [
1381+
event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)
1382+
]
1383+
first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump()
1384+
interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt"))
1385+
assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1
1386+
1387+
# Second turn: send denial via function_approvals on a message (not resume.interrupts)
1388+
resumed_events = [
1389+
event
1390+
async for event in run_workflow_stream(
1391+
{
1392+
"messages": [
1393+
{
1394+
"role": "user",
1395+
"content": "",
1396+
"function_approvals": [
1397+
{
1398+
"approved": False,
1399+
"id": "deny-1",
1400+
"call_id": "delete-call",
1401+
"name": "delete_record",
1402+
"arguments": {"record_id": "abc"},
1403+
}
1404+
],
1405+
}
1406+
],
1407+
},
1408+
workflow,
1409+
)
1410+
]
1411+
1412+
resumed_types = [event.type for event in resumed_events]
1413+
assert "RUN_STARTED" in resumed_types
1414+
assert "RUN_FINISHED" in resumed_types
1415+
assert "RUN_ERROR" not in resumed_types
1416+
assert "TEXT_MESSAGE_CONTENT" in resumed_types
1417+
text_deltas = [event.delta for event in resumed_events if event.type == "TEXT_MESSAGE_CONTENT"]
1418+
assert any("rejected" in delta for delta in text_deltas)
1419+
resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump()
1420+
assert not resumed_finished.get("interrupt")
1421+
1422+
11781423
async def test_workflow_run_available_interrupts_logged():
11791424
"""available_interrupts in input data should be logged without errors."""
11801425

0 commit comments

Comments
 (0)