|
33 | 33 | _custom_event_value, |
34 | 34 | _details_code, |
35 | 35 | _details_message, |
| 36 | + _extract_responses_from_messages, |
36 | 37 | _interrupt_entry_for_request_event, |
37 | 38 | _latest_assistant_contents, |
38 | 39 | _latest_user_text, |
@@ -1172,9 +1173,253 @@ def test_details_without_error_type(self): |
1172 | 1173 | assert _details_code(details) is None |
1173 | 1174 |
|
1174 | 1175 |
|
class TestExtractResponsesFromMessages:
    """Unit tests for the ``_extract_responses_from_messages`` helper."""

    @staticmethod
    def _approval(
        approval_id: str,
        *,
        approved: bool,
        call_id: str,
        name: str,
        arguments: dict[str, object],
    ) -> Content:
        """Build a function_approval_response content wrapping a newly created function call."""
        wrapped_call = Content.from_function_call(call_id=call_id, name=name, arguments=arguments)
        return Content.from_function_approval_response(
            approved=approved,
            id=approval_id,
            function_call=wrapped_call,
        )

    def test_function_result_extracted(self):
        """A function_result content yields an entry mapping its call_id to its result."""
        tool_message = Message(
            role="tool",
            contents=[Content.from_function_result(call_id="call-1", result="ok")],
        )
        assert _extract_responses_from_messages([tool_message]) == {"call-1": "ok"}

    def test_function_result_without_call_id_skipped(self):
        """A function_result whose call_id is the empty string produces no entry."""
        tool_message = Message(
            role="tool",
            contents=[Content.from_function_result(call_id="", result="ok")],
        )
        assert _extract_responses_from_messages([tool_message]) == {}

    def test_function_approval_response_extracted(self):
        """An approved function_approval_response is stored under its own id."""
        granted = self._approval(
            "approval-1",
            approved=True,
            call_id="call-1",
            name="do_action",
            arguments={"x": 1},
        )
        extracted = _extract_responses_from_messages([Message(role="user", contents=[granted])])
        assert "approval-1" in extracted
        entry = extracted["approval-1"]
        assert entry["approved"] is True
        assert entry["id"] == "approval-1"
        assert "function_call" in entry

    def test_denied_approval_response_extracted(self):
        """A denied function_approval_response is stored with ``approved`` set to False."""
        refused = self._approval(
            "approval-2",
            approved=False,
            call_id="call-2",
            name="delete_item",
            arguments={},
        )
        extracted = _extract_responses_from_messages([Message(role="user", contents=[refused])])
        assert "approval-2" in extracted
        assert extracted["approval-2"]["approved"] is False

    def test_mixed_result_and_approval(self):
        """A tool result and an approval carried by separate messages are both extracted."""
        tool_output = Content.from_function_result(call_id="call-1", result="done")
        granted = self._approval(
            "approval-1",
            approved=True,
            call_id="call-2",
            name="submit",
            arguments={},
        )
        conversation = [
            Message(role="tool", contents=[tool_output]),
            Message(role="user", contents=[granted]),
        ]
        extracted = _extract_responses_from_messages(conversation)
        assert "call-1" in extracted
        assert extracted["call-1"] == "done"
        assert "approval-1" in extracted
        assert extracted["approval-1"]["approved"] is True

    def test_mixed_result_and_approval_same_message(self):
        """A tool result and an approval carried by one message are both extracted."""
        tool_output = Content.from_function_result(call_id="call-1", result="done")
        granted = self._approval(
            "approval-1",
            approved=True,
            call_id="call-2",
            name="submit",
            arguments={},
        )
        combined = Message(role="tool", contents=[tool_output, granted])
        extracted = _extract_responses_from_messages([combined])
        assert "call-1" in extracted
        assert extracted["call-1"] == "done"
        assert "approval-1" in extracted
        assert extracted["approval-1"]["approved"] is True

    def test_text_content_skipped(self):
        """Plain text content contributes nothing to the extracted responses."""
        user_message = Message(role="user", contents=[Content.from_text(text="hello")])
        assert _extract_responses_from_messages([user_message]) == {}

    def test_empty_messages(self):
        """An empty message list yields an empty mapping."""
        no_messages = []
        assert _extract_responses_from_messages(no_messages) == {}
| 1282 | + |
| 1283 | + |
1175 | 1284 | # ── Stream integration tests ── |
1176 | 1285 |
|
1177 | 1286 |
|
async def test_workflow_run_approval_via_messages_approved() -> None:
    """An approval delivered through a message's ``function_approvals`` resolves the pending request.

    The first turn must finish with exactly one interrupt. The second turn carries the approval on a
    user message rather than in ``resume.interrupts`` and must complete without an error, stream text
    containing "approved", and finish with no interrupt left.
    """

    # Executor that asks for approval of a refund call and reports the decision it receives.
    class ApprovalExecutor(Executor):
        def __init__(self) -> None:
            super().__init__(id="approval_executor")

        @handler
        async def start(self, message: Any, ctx: WorkflowContext) -> None:
            del message  # the triggering input is not used; the handler only raises the request
            refund_call = Content.from_function_call(
                call_id="refund-call",
                name="submit_refund",
                arguments={"order_id": "12345", "amount": "$89.99"},
            )
            pending_request = Content.from_function_approval_request(id="approval-1", function_call=refund_call)
            await ctx.request_info(pending_request, Content, request_id="approval-1")

        @response_handler
        async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
            del original_request
            if response.approved:
                outcome = "approved"
            else:
                outcome = "rejected"
            await ctx.yield_output(f"Refund {outcome}.")

    workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()

    # First turn: start the workflow and confirm it stops on a single interrupt.
    opening_turn: dict[str, Any] = {"messages": [{"role": "user", "content": "go"}]}
    opening_events = [evt async for evt in run_workflow_stream(opening_turn, workflow)]
    opening_finished = [evt for evt in opening_events if evt.type == "RUN_FINISHED"]
    pending = cast(list[dict[str, Any]], opening_finished[0].model_dump().get("interrupt"))
    assert isinstance(pending, list)
    assert len(pending) == 1

    # Second turn: send approval via function_approvals on a message (not resume.interrupts).
    approval_entry: dict[str, Any] = {
        "approved": True,
        "id": "approval-1",
        "call_id": "refund-call",
        "name": "submit_refund",
        "arguments": {"order_id": "12345", "amount": "$89.99"},
    }
    approval_turn: dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": "",
                "function_approvals": [approval_entry],
            }
        ],
    }
    resumed_events = [evt async for evt in run_workflow_stream(approval_turn, workflow)]

    seen_types = [evt.type for evt in resumed_events]
    assert "RUN_STARTED" in seen_types
    assert "RUN_FINISHED" in seen_types
    assert "RUN_ERROR" not in seen_types
    assert "TEXT_MESSAGE_CONTENT" in seen_types

    streamed_chunks = [evt.delta for evt in resumed_events if evt.type == "TEXT_MESSAGE_CONTENT"]
    assert any("approved" in chunk for chunk in streamed_chunks)

    closing_finished = [evt for evt in resumed_events if evt.type == "RUN_FINISHED"]
    assert not closing_finished[0].model_dump().get("interrupt")
| 1353 | + |
| 1354 | + |
async def test_workflow_run_approval_via_messages_denied() -> None:
    """A denial delivered through a message's ``function_approvals`` resolves the pending request.

    The first turn must finish with exactly one interrupt. The second turn carries the denial on a
    user message rather than in ``resume.interrupts`` and must complete without an error, stream text
    containing "rejected", and finish with no interrupt left.
    """

    # Executor that asks for approval of a delete call and reports the decision it receives.
    class ApprovalExecutor(Executor):
        def __init__(self) -> None:
            super().__init__(id="approval_executor")

        @handler
        async def start(self, message: Any, ctx: WorkflowContext) -> None:
            del message  # the triggering input is not used; the handler only raises the request
            delete_call = Content.from_function_call(
                call_id="delete-call",
                name="delete_record",
                arguments={"record_id": "abc"},
            )
            pending_request = Content.from_function_approval_request(id="deny-1", function_call=delete_call)
            await ctx.request_info(pending_request, Content, request_id="deny-1")

        @response_handler
        async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None:
            del original_request
            if response.approved:
                outcome = "approved"
            else:
                outcome = "rejected"
            await ctx.yield_output(f"Delete {outcome}.")

    workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build()

    # First turn: start the workflow and confirm it stops on a single interrupt.
    opening_turn: dict[str, Any] = {"messages": [{"role": "user", "content": "go"}]}
    opening_events = [evt async for evt in run_workflow_stream(opening_turn, workflow)]
    opening_finished = [evt for evt in opening_events if evt.type == "RUN_FINISHED"]
    pending = cast(list[dict[str, Any]], opening_finished[0].model_dump().get("interrupt"))
    assert isinstance(pending, list)
    assert len(pending) == 1

    # Second turn: send denial via function_approvals on a message (not resume.interrupts).
    denial_entry: dict[str, Any] = {
        "approved": False,
        "id": "deny-1",
        "call_id": "delete-call",
        "name": "delete_record",
        "arguments": {"record_id": "abc"},
    }
    denial_turn: dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": "",
                "function_approvals": [denial_entry],
            }
        ],
    }
    resumed_events = [evt async for evt in run_workflow_stream(denial_turn, workflow)]

    seen_types = [evt.type for evt in resumed_events]
    assert "RUN_STARTED" in seen_types
    assert "RUN_FINISHED" in seen_types
    assert "RUN_ERROR" not in seen_types
    assert "TEXT_MESSAGE_CONTENT" in seen_types

    streamed_chunks = [evt.delta for evt in resumed_events if evt.type == "TEXT_MESSAGE_CONTENT"]
    assert any("rejected" in chunk for chunk in streamed_chunks)

    closing_finished = [evt for evt in resumed_events if evt.type == "RUN_FINISHED"]
    assert not closing_finished[0].model_dump().get("interrupt")
| 1421 | + |
| 1422 | + |
1178 | 1423 | async def test_workflow_run_available_interrupts_logged(): |
1179 | 1424 | """available_interrupts in input data should be logged without errors.""" |
1180 | 1425 |
|
|
0 commit comments