Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ Typical hook timing:
- `on_llm_start` / `on_llm_end`: immediately around each model call.
- `on_tool_start` / `on_tool_end`: around each local tool invocation.
For function tools, the hook `context` is typically a `ToolContext`, so you can inspect tool-call metadata such as `tool_call_id`.
- `on_tool_progress`: when a tool emits a mid-execution progress update via `await ctx.send_progress(data)`.
- `on_handoff`: when control moves from one agent to another.

Use `RunHooks` when you want a single observer for the whole workflow, and `AgentHooks` when one agent needs custom side effects.
Expand Down
86 changes: 86 additions & 0 deletions examples/basic/stream_tool_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Example: tool progress via on_tool_progress hooks.

Demonstrates how tools can emit intermediate progress updates using
await ctx.send_progress(data), consumed via RunHooks.on_tool_progress.
"""

import asyncio

from agents import Agent, RunHooks, Runner, function_tool
from agents.tool import Tool
from agents.tool_context import ToolContext


@function_tool
async def analyze_data(ctx: ToolContext, query: str) -> str:
"""Simulate a long-running data analysis task with progress updates."""
await ctx.send_progress({"status": "starting", "query": query})
await asyncio.sleep(1)

await ctx.send_progress({"status": "fetching_data", "progress": 0.25})
await asyncio.sleep(1)

await ctx.send_progress({"status": "processing", "progress": 0.5})
await asyncio.sleep(1)

await ctx.send_progress({"status": "finalizing", "progress": 1.0})
await asyncio.sleep(0.5)

return f"Analysis complete for '{query}': found 42 results with 95% confidence."


@function_tool
async def quick_lookup(ctx: ToolContext, term: str) -> str:
"""A faster tool that also emits progress."""
await ctx.send_progress({"status": "searching", "term": term})
await asyncio.sleep(0.5)
return f"Found definition for '{term}': a common search term."


class ProgressHooks(RunHooks):
async def on_tool_progress(self, ctx, agent, tool: Tool, data):
print(f" [progress] {tool.name}: {data}")


async def main():
agent = Agent(
name="Analyst",
instructions=(
"You are a data analyst. Use the analyze_data tool for complex queries "
"and quick_lookup for simple lookups. Always use the tools when asked."
),
tools=[analyze_data, quick_lookup],
)

hooks = ProgressHooks()

print("Interactive tool progress example (hooks-based).")
print("Type a message to chat, or 'quit' to exit.\n")

while True:
user_input = input("You: ").strip()
if not user_input or user_input.lower() == "quit":
print("Goodbye!")
break

result = Runner.run_streamed(agent, input=user_input, hooks=hooks)
async for event in result.stream_events():
if event.type == "raw_response_event":
data = event.data
if getattr(data, "type", None) == "response.output_text.delta":
print(getattr(data, "delta", ""), end="", flush=True)
elif event.type == "agent_updated_stream_event":
print(f"Agent: {event.new_agent.name}")
elif event.type == "run_item_stream_event":
if event.item.type == "tool_call_item":
print(f"\n-- Tool called: {getattr(event.item.raw_item, 'name', '?')}")
elif event.item.type == "tool_call_output_item":
print(f"\n-- Tool output: {event.item.output}")
elif event.item.type == "message_output_item":
print() # newline after streamed tokens

print()


if __name__ == "__main__":
asyncio.run(main())
30 changes: 30 additions & 0 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,21 @@ async def on_tool_end(
"""
pass

async def on_tool_progress(
self,
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
data: Any,
) -> None:
"""Called when a tool emits a progress update via ``send_progress()``.

Unlike ``on_tool_start``/``on_tool_end`` which fire at lifecycle boundaries,
this fires from inside the tool body at arbitrary points. For function-tool
invocations, ``context`` is typically a ``ToolContext``.
"""
pass


class AgentHooksBase(Generic[TContext, TAgent]):
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
Expand Down Expand Up @@ -172,6 +187,21 @@ async def on_tool_end(
"""
pass

async def on_tool_progress(
self,
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
data: Any,
) -> None:
"""Called when a tool emits a progress update via ``send_progress()``.

Unlike ``on_tool_start``/``on_tool_end`` which fire at lifecycle boundaries,
this fires from inside the tool body at arbitrary points. For function-tool
invocations, ``context`` is typically a ``ToolContext``.
"""
pass

async def on_llm_start(
self,
context: RunContextWrapper[TContext],
Expand Down
15 changes: 15 additions & 0 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,21 @@ async def _run_single_tool(
run_config=self.config,
)
agent_hooks = self.public_agent.hooks

async def _send_progress(data: Any) -> None:
await asyncio.gather(
self.hooks.on_tool_progress(tool_context, self.public_agent, func_tool, data),
(
agent_hooks.on_tool_progress(
tool_context, self.public_agent, func_tool, data
)
if agent_hooks
else _coro.noop_coroutine()
),
)

tool_context.set_progress_fn(_send_progress)

if self.config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments

Expand Down
15 changes: 15 additions & 0 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -104,11 +105,25 @@ def __init__(
self.agent = agent
self.run_config = run_config

_progress_fn: Callable[[Any], Awaitable[None]] | None = None

@property
def qualified_tool_name(self) -> str:
"""Return the tool name qualified by namespace when available."""
return tool_trace_name(self.tool_name, self.tool_namespace) or self.tool_name

async def send_progress(self, data: Any) -> None:
"""Emit a progress update, firing ``on_tool_progress`` hooks.

No-op if no progress handler has been set by the framework.
"""
if self._progress_fn is not None:
await self._progress_fn(data)

def set_progress_fn(self, fn: Callable[[Any], Awaitable[None]]) -> None:
"""Set the progress handler. Called by the framework during tool invocation."""
self._progress_fn = fn

@classmethod
def from_agent_context(
cls,
Expand Down
Loading
Loading