@function_tool
async def analyze_data(ctx: ToolContext, query: str) -> str:
    """Simulate a long-running data analysis task with progress updates."""
    # NOTE(review): docstring left byte-identical on purpose; presumably the
    # decorator surfaces it as the tool description. Confirm before rewording.

    # Each entry pairs the payload handed to ``ctx.send_progress`` with the
    # pause (in seconds) that follows it, listed in emission order.
    stages = (
        ({"status": "starting", "query": query}, 1),
        ({"status": "fetching_data", "progress": 0.25}, 1),
        ({"status": "processing", "progress": 0.5}, 1),
        ({"status": "finalizing", "progress": 1.0}, 0.5),
    )
    for payload, pause_seconds in stages:
        await ctx.send_progress(payload)
        await asyncio.sleep(pause_seconds)

    return f"Analysis complete for '{query}': found 42 results with 95% confidence."
async def main() -> None:
    """Run an interactive chat loop that streams agent output and tool progress.

    Prompts are read from stdin until the user submits an empty line, types
    ``quit`` (case-insensitive), or stdin is closed. Each prompt is executed
    with ``Runner.run_streamed`` and a ``ProgressHooks`` instance, so progress
    updates emitted by the tools are printed by the hook while the stream
    events below print text deltas, agent changes, tool calls and tool outputs.
    """
    agent = Agent(
        name="Analyst",
        instructions=(
            "You are a data analyst. Use the analyze_data tool for complex queries "
            "and quick_lookup for simple lookups. Always use the tools when asked."
        ),
        tools=[analyze_data, quick_lookup],
    )

    hooks = ProgressHooks()

    print("Interactive tool progress example (hooks-based).")
    print("Type a message to chat, or 'quit' to exit.\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except EOFError:
            # stdin was closed (Ctrl-D, or piped input ran out). Finish the
            # prompt line and fall through to the normal exit path instead of
            # crashing with a traceback.
            print()
            user_input = ""

        if not user_input or user_input.lower() == "quit":
            print("Goodbye!")
            break

        result = Runner.run_streamed(agent, input=user_input, hooks=hooks)
        async for event in result.stream_events():
            if event.type == "raw_response_event":
                data = event.data
                # Only text deltas are echoed; other raw events are ignored.
                if getattr(data, "type", None) == "response.output_text.delta":
                    print(getattr(data, "delta", ""), end="", flush=True)
            elif event.type == "agent_updated_stream_event":
                print(f"Agent: {event.new_agent.name}")
            elif event.type == "run_item_stream_event":
                if event.item.type == "tool_call_item":
                    print(f"\n-- Tool called: {getattr(event.item.raw_item, 'name', '?')}")
                elif event.item.type == "tool_call_output_item":
                    print(f"\n-- Tool output: {event.item.output}")
                elif event.item.type == "message_output_item":
                    print()  # newline after streamed tokens

        print()
async def on_tool_progress(
    self,
    context: RunContextWrapper[TContext],
    agent: TAgent,
    tool: Tool,
    data: Any,
) -> None:
    """Handle a progress update that a tool emitted through ``send_progress()``.

    ``on_tool_start`` and ``on_tool_end`` mark the lifecycle boundaries of a
    tool call; this callback instead fires from inside the tool body, at
    whatever points the tool chooses to report progress.

    Args:
        context: The run context. For function-tool invocations this is
            typically a ``ToolContext``.
        agent: The agent on whose behalf the tool is running.
        tool: The tool that emitted the update.
        data: The payload the tool passed to ``send_progress()``.

    The default implementation does nothing; override it to observe updates.
    """
    return None
async def send_progress(self, data: Any) -> None:
    """Emit a progress update, firing ``on_tool_progress`` hooks.

    Args:
        data: Payload handed unchanged to the registered progress handler.

    When the framework has not installed a handler (``_progress_fn`` is
    ``None``) the call returns without doing anything.
    """
    handler = self._progress_fn
    if handler is None:
        return
    await handler(data)
class TestSendProgress:
    """Unit tests for ``ToolContext.send_progress`` without running an agent."""

    @staticmethod
    def _make_context() -> ToolContext[None]:
        """Build a ``ToolContext`` with placeholder tool-call metadata."""
        return ToolContext(
            context=None,
            tool_name="test_tool",
            tool_call_id="call-1",
            tool_arguments="{}",
        )

    @pytest.mark.asyncio
    async def test_send_progress_fires_callback(self) -> None:
        """send_progress calls the _progress_fn when set."""
        captured: list[Any] = []

        async def _record(payload: Any) -> None:
            captured.append(payload)

        tool_ctx = self._make_context()
        tool_ctx.set_progress_fn(_record)
        await tool_ctx.send_progress({"status": "working"})

        assert captured == [{"status": "working"}]

    @pytest.mark.asyncio
    async def test_send_progress_noop_without_fn(self) -> None:
        """send_progress is a no-op when _progress_fn is None."""
        tool_ctx = self._make_context()
        # No handler registered: the call only has to complete without raising.
        await tool_ctx.send_progress({"status": "working"})

    @pytest.mark.asyncio
    async def test_send_progress_multiple_events(self) -> None:
        """Multiple send_progress calls all arrive in order."""
        captured: list[Any] = []

        async def _record(payload: Any) -> None:
            captured.append(payload)

        tool_ctx = self._make_context()
        tool_ctx.set_progress_fn(_record)
        for step in (1, 2, 3):
            await tool_ctx.send_progress({"step": step})

        # List equality checks both the count and the emission order.
        assert captured == [{"step": 1}, {"step": 2}, {"step": 3}]
@pytest.mark.asyncio
async def test_both_hooks_fire(self) -> None:
    """Both RunHooks and AgentHooks on_tool_progress fire."""

    async def _emit_once(ctx: ToolContext) -> str:
        await ctx.send_progress({"status": "both"})
        return "done"

    progress_tool = function_tool(_emit_once, name_override="progress_tool")

    # Turn 1 makes the model call the tool; turn 2 ends the run with text.
    scripted_turns = [
        [get_function_tool_call("progress_tool", "{}")],
        [get_text_message("final")],
    ]
    fake_model = FakeModel()
    fake_model.add_multiple_turn_outputs(scripted_turns)

    per_run_hooks = _CollectingRunHooks()
    per_agent_hooks = _CollectingAgentHooks()
    agent = Agent(
        name="test",
        model=fake_model,
        tools=[progress_tool],
        hooks=per_agent_hooks,
    )

    await Runner.run(agent, input="test", hooks=per_run_hooks)

    # A single send_progress call must reach each hook object exactly once.
    assert len(per_run_hooks.progress_events) == 1
    assert len(per_agent_hooks.progress_events) == 1
@pytest.mark.asyncio
async def test_parallel_tools_with_progress(self) -> None:
    """Parallel tools report progress with correct tool identity.

    Two tools are requested in the same model turn and each emits one
    payload naming itself. The run-level hook must see each payload attached
    to the tool that produced it.
    """

    async def _tool_a(ctx: ToolContext) -> str:
        await ctx.send_progress({"tool": "a"})
        return "a_done"

    async def _tool_b(ctx: ToolContext) -> str:
        await ctx.send_progress({"tool": "b"})
        return "b_done"

    tool_a = function_tool(_tool_a, name_override="tool_a")
    tool_b = function_tool(_tool_b, name_override="tool_b")
    hooks = _CollectingRunHooks()
    model = FakeModel()
    agent = Agent(name="test", model=model, tools=[tool_a, tool_b])

    model.add_multiple_turn_outputs(
        [
            [
                get_function_tool_call("tool_a", "{}", call_id="call_a"),
                get_function_tool_call("tool_b", "{}", call_id="call_b"),
            ],
            [get_text_message("final")],
        ]
    )

    await Runner.run(agent, input="test", hooks=hooks)
    a_events = [e for e in hooks.progress_events if e["tool_name"] == "tool_a"]
    b_events = [e for e in hooks.progress_events if e["tool_name"] == "tool_b"]
    assert len(a_events) == 1
    assert len(b_events) == 1
    # Counting per name alone would still pass if the payloads were swapped
    # between the two tools, so also pin each payload to its emitting tool.
    assert a_events[0]["data"] == {"tool": "a"}
    assert b_events[0]["data"] == {"tool": "b"}
hooks.progress_events] == [ + {"step": 1}, + {"step": 2}, + {"step": 3}, + ]