diff --git a/go/adk/pkg/tools/remote_a2a_tool.go b/go/adk/pkg/tools/remote_a2a_tool.go index 87afe88a8..7fa20c1fc 100644 --- a/go/adk/pkg/tools/remote_a2a_tool.go +++ b/go/adk/pkg/tools/remote_a2a_tool.go @@ -22,6 +22,33 @@ import ( // userIDContextKey is the context key for passing the session user_id to the subagent. type userIDContextKey struct{} +// parentContextIDContextKey is the context key carrying this agent's own +// A2A context_id (== ADK session id) into the outbound interceptor so it can +// be stamped as the parent_context_id header on every outbound A2A call. +type parentContextIDContextKey struct{} + +// Conversation-lineage headers stamped on outbound A2A calls so a remote +// agent can correlate this turn with the originating chat conversation - +// useful when downstream code keys per-conversation state (sessions, sandbox +// pods, cache entries) on a stable identifier across A2A hops. +// +// ParentContextIDHeader is the immediate caller's A2A context_id (the +// session id of the agent that ran this tool). It changes with every hop in +// a chain of A2A calls. +// +// RootContextIDHeader is the top-of-chain context_id - the agent at the +// start of the chain (typically the user-facing chat agent). It stays +// stable across every hop and across every turn of the same conversation, +// so downstream agents can key state that should outlive a single A2A call +// (e.g. claim a per-conversation worker pod that survives between turns). +// +// Mirrors the Python ADK constants in +// python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py. +const ( + ParentContextIDHeader = "x-kagent-parent-context-id" + RootContextIDHeader = "x-kagent-root-context-id" +) + // userIDForwardingInterceptor forwards the session user_id as an x-user-id header. type userIDForwardingInterceptor struct { a2aclient.PassthroughInterceptor @@ -34,6 +61,56 @@ func (u *userIDForwardingInterceptor) Before(ctx context.Context, req *a2aclient return ctx, nil } +// lineageHeadersInterceptor stamps the parent + root context_id headers on +// every outbound A2A call. Parent comes from a context value populated by the +// caller (the tool's own ADK session id). Root is forwarded unchanged from the +// inbound A2A request when present (so the value set by the agent at the start +// of the chain survives every hop), with a legacy fallback to the inbound +// parent header for older callers, and a final fallback to our own session id +// when this agent is the chain root. +// +// Pre-existing headers on req.Meta win (analogous to Python's header_provider +// override), so a caller that sets extraHeaders for either header keeps full +// control. +type lineageHeadersInterceptor struct { + a2aclient.PassthroughInterceptor +} + +func (l *lineageHeadersInterceptor) Before(ctx context.Context, req *a2aclient.Request) (context.Context, error) { + parent, _ := ctx.Value(parentContextIDContextKey{}).(string) + if parent == "" { + return ctx, nil + } + + var inboundRoot, inboundParent string + if callCtx, ok := a2asrv.CallContextFrom(ctx); ok { + if meta := callCtx.RequestMeta(); meta != nil { + if vals, ok := meta.Get(RootContextIDHeader); ok && len(vals) > 0 { + inboundRoot = vals[0] + } + if vals, ok := meta.Get(ParentContextIDHeader); ok && len(vals) > 0 { + inboundParent = vals[0] + } + } + } + + root := inboundRoot + if root == "" { + root = inboundParent + } + if root == "" { + root = parent + } + + if len(req.Meta.Get(ParentContextIDHeader)) == 0 { + req.Meta.Append(ParentContextIDHeader, parent) + } + if len(req.Meta.Get(RootContextIDHeader)) == 0 { + req.Meta.Append(RootContextIDHeader, root) + } + return ctx, nil +} + // authzForwardingInterceptor forwards the Authorization header from the // incoming A2A request context to outbound sub-agent A2A calls. type authzForwardingInterceptor struct { @@ -150,6 +227,7 @@ func (s *remoteA2AState) ensureClient(ctx context.Context) (*a2aclient.Client, e interceptors := []a2aclient.CallInterceptor{ a2aclient.NewStaticCallMetaInjector(meta), &userIDForwardingInterceptor{}, + &lineageHeadersInterceptor{}, } if s.propagateToken { interceptors = append(interceptors, &authzForwardingInterceptor{}) @@ -192,6 +270,7 @@ func (s *remoteA2AState) handleFirstCall(ctx tool.Context, requestText string) ( message.ContextID = s.lastContextID sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID()) + sendCtx = context.WithValue(sendCtx, parentContextIDContextKey{}, ctx.SessionID()) result, err := client.SendMessage(sendCtx, &a2atype.MessageSendParams{Message: message}) if err != nil { slog.Error("Remote agent request failed", "tool", s.name, "error", err) @@ -242,6 +321,7 @@ func (s *remoteA2AState) handleResume(ctx tool.Context) (map[string]any, error) } sendCtx := context.WithValue(ctx, userIDContextKey{}, ctx.UserID()) + sendCtx = context.WithValue(sendCtx, parentContextIDContextKey{}, ctx.SessionID()) result, err := client.SendMessage(sendCtx, &a2atype.MessageSendParams{Message: message}) if err != nil { slog.Error("Remote agent resume failed", "tool", subagentName, "error", err) diff --git a/go/adk/pkg/tools/remote_a2a_tool_test.go b/go/adk/pkg/tools/remote_a2a_tool_test.go new file mode 100644 index 000000000..97580c351 --- /dev/null +++ b/go/adk/pkg/tools/remote_a2a_tool_test.go @@ -0,0 +1,123 @@ +package tools + +import ( + "context" + "testing" + + "github.com/a2aproject/a2a-go/a2aclient" + "github.com/a2aproject/a2a-go/a2asrv" +) + +// newReq returns an empty outbound client Request with an initialized CallMeta. +func newReq() *a2aclient.Request { + return &a2aclient.Request{Meta: a2aclient.CallMeta{}} +} + +// withCallContext returns a context that carries an a2asrv.CallContext whose +// RequestMeta exposes the given inbound headers, so the interceptor's +// CallContextFrom + RequestMeta path can be exercised. +func withCallContext(parent context.Context, inbound map[string][]string) context.Context { + ctx, _ := a2asrv.WithCallContext(parent, a2asrv.NewRequestMeta(inbound)) + return ctx +} + +// TestLineageHeaderPropagation covers the parent + root context_id header +// derivation. Mirrors the Python TestLineageHeaderPropagation cases in +// python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py. +func TestLineageHeaderPropagation(t *testing.T) { + const ownSession = "own-session-123" + const upstreamRoot = "root-from-upstream" + const upstreamParent = "parent-from-upstream" + + t.Run("chain root stamps own id as parent and root", func(t *testing.T) { + ctx := context.WithValue(context.Background(), parentContextIDContextKey{}, ownSession) + req := newReq() + + if _, err := (&lineageHeadersInterceptor{}).Before(ctx, req); err != nil { + t.Fatalf("Before returned error: %v", err) + } + + assertSingleHeader(t, req, ParentContextIDHeader, ownSession) + assertSingleHeader(t, req, RootContextIDHeader, ownSession) + }) + + t.Run("mid-chain forwards root unchanged and overrides parent with own id", func(t *testing.T) { + ctx := context.WithValue(context.Background(), parentContextIDContextKey{}, ownSession) + ctx = withCallContext(ctx, map[string][]string{ + RootContextIDHeader: {upstreamRoot}, + ParentContextIDHeader: {upstreamParent}, + }) + req := newReq() + + if _, err := (&lineageHeadersInterceptor{}).Before(ctx, req); err != nil { + t.Fatalf("Before returned error: %v", err) + } + + assertSingleHeader(t, req, ParentContextIDHeader, ownSession) + assertSingleHeader(t, req, RootContextIDHeader, upstreamRoot) + }) + + t.Run("legacy inbound with only parent header promotes it to root", func(t *testing.T) { + ctx := context.WithValue(context.Background(), parentContextIDContextKey{}, ownSession) + ctx = withCallContext(ctx, map[string][]string{ + ParentContextIDHeader: {upstreamParent}, + }) + req := newReq() + + if _, err := (&lineageHeadersInterceptor{}).Before(ctx, req); err != nil { + t.Fatalf("Before returned error: %v", err) + } + + assertSingleHeader(t, req, ParentContextIDHeader, ownSession) + assertSingleHeader(t, req, RootContextIDHeader, upstreamParent) + }) + + t.Run("no session id is a no-op", func(t *testing.T) { + // No parentContextIDContextKey on ctx - matches the stub tool_context + // case in Python (empty dict, no headers stamped). + ctx := context.Background() + req := newReq() + + if _, err := (&lineageHeadersInterceptor{}).Before(ctx, req); err != nil { + t.Fatalf("Before returned error: %v", err) + } + + if got := req.Meta.Get(ParentContextIDHeader); len(got) != 0 { + t.Errorf("expected no parent header, got %v", got) + } + if got := req.Meta.Get(RootContextIDHeader); len(got) != 0 { + t.Errorf("expected no root header, got %v", got) + } + }) + + t.Run("pre-existing header on req.Meta wins over lineage", func(t *testing.T) { + // Analogous to Python's header_provider override: a caller-supplied + // header that is already present on the outbound request must not be + // overwritten by the lineage interceptor. + ctx := context.WithValue(context.Background(), parentContextIDContextKey{}, ownSession) + ctx = withCallContext(ctx, map[string][]string{ + RootContextIDHeader: {upstreamRoot}, + }) + req := newReq() + req.Meta.Append(ParentContextIDHeader, "caller-override-parent") + req.Meta.Append(RootContextIDHeader, "caller-override-root") + + if _, err := (&lineageHeadersInterceptor{}).Before(ctx, req); err != nil { + t.Fatalf("Before returned error: %v", err) + } + + assertSingleHeader(t, req, ParentContextIDHeader, "caller-override-parent") + assertSingleHeader(t, req, RootContextIDHeader, "caller-override-root") + }) +} + +func assertSingleHeader(t *testing.T, req *a2aclient.Request, key, want string) { + t.Helper() + got := req.Meta.Get(key) + if len(got) != 1 { + t.Fatalf("%s: expected exactly 1 value, got %v", key, got) + } + if got[0] != want { + t.Errorf("%s: got %q, want %q", key, got[0], want) + } +} diff --git a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py index 3cbaabcb2..5f9afe109 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py +++ b/python/packages/kagent-adk/src/kagent/adk/_remote_a2a_tool.py @@ -61,6 +61,24 @@ _HEADERS_STATE_KEY = "headers" _EXTRA_HEADERS_CONTEXT_KEY = "_a2a_extra_headers" +# Conversation-lineage headers propagated on outbound A2A calls so a remote +# agent can correlate this turn with the originating chat conversation — +# useful when downstream code keys per-conversation state (sessions, sandbox +# pods, cache entries) on a stable identifier across A2A hops. +# +# `x-kagent-parent-context-id` is the immediate caller's A2A context_id +# (the session id of the agent that ran this tool). It changes with every +# hop in a chain of A2A calls. +# +# `x-kagent-root-context-id` is the top-of-chain context_id — the agent at +# the start of the chain (typically the user-facing chat agent). It stays +# stable across every hop and across every turn of the same conversation, +# so downstream agents can use it to key state that should outlive a single +# A2A call (e.g. claim a per-conversation worker pod that survives between +# turns). +PARENT_CONTEXT_ID_HEADER = "x-kagent-parent-context-id" +ROOT_CONTEXT_ID_HEADER = "x-kagent-root-context-id" + class _SubagentInterceptor(ClientCallInterceptor): """ @@ -217,12 +235,70 @@ def _get_declaration(self) -> genai_types.FunctionDeclaration: def _build_call_context(self, tool_context: ToolContext) -> ClientCallContext: state: dict[str, Any] = {_USER_ID_CONTEXT_KEY: tool_context.session.user_id} + + # Derive conversation lineage so the remote agent can correlate this + # turn with the originating chat conversation. See the header constant + # docstrings at the top of this module for the parent/root semantics. + lineage_headers = self._build_lineage_headers(tool_context) + if self._header_provider: extra_headers = self._header_provider(tool_context) if extra_headers: - state[_EXTRA_HEADERS_CONTEXT_KEY] = extra_headers + # Merge caller-supplied headers on top of lineage so a custom + # provider can override the defaults if it really wants to, + # but the typical case (no custom provider) just gets lineage. + lineage_headers.update(extra_headers) + + if lineage_headers: + state[_EXTRA_HEADERS_CONTEXT_KEY] = lineage_headers return ClientCallContext(state=state) + def _build_lineage_headers(self, tool_context: ToolContext) -> dict[str, str]: + """Compute the parent/root context_id headers for an outbound A2A call. + + - ``x-kagent-parent-context-id`` is set to this agent's own current + session id (the immediate caller of the remote agent). + - ``x-kagent-root-context-id`` is forwarded unchanged when we were + ourselves called by an upstream agent that set it. If no upstream + root header is present we are the top of the chain, so our own + session id becomes the root. As a transitional fallback we also + honor a parent header sent by older callers that did not yet set + root. + + Returns an empty dict when no session id can be resolved (e.g. tests + that pass a stub tool_context); the outbound request then matches + existing behavior. + """ + parent_context_id: Optional[str] = None + inbound_headers: dict[str, Any] = {} + + # ToolContext exposes `.session` (the ADK session for the current + # invocation). The session id IS the A2A context_id for kagent agents, + # per the request_converter that maps `request.context_id` to + # `session_id` in kagent.adk.converters.request_converter. + session = getattr(tool_context, "session", None) + if session is not None: + parent_context_id = getattr(session, "id", None) + state = getattr(session, "state", None) + if isinstance(state, dict): + hdrs = state.get(_HEADERS_STATE_KEY) + if isinstance(hdrs, dict): + inbound_headers = hdrs + + if not parent_context_id: + return {} + + root_context_id = ( + inbound_headers.get(ROOT_CONTEXT_ID_HEADER) + or inbound_headers.get(PARENT_CONTEXT_ID_HEADER) + or parent_context_id + ) + + return { + PARENT_CONTEXT_ID_HEADER: str(parent_context_id), + ROOT_CONTEXT_ID_HEADER: str(root_context_id), + } + async def run_async(self, *, args: dict[str, Any], tool_context: ToolContext) -> Any: """Execute the remote agent tool. diff --git a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py index 8682741bb..e4facf14f 100644 --- a/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py +++ b/python/packages/kagent-adk/tests/unittests/test_remote_a2a_tool.py @@ -1,6 +1,6 @@ """Tests for KAgentRemoteA2ATool.""" -from typing import Any, AsyncIterator +from typing import Any, AsyncIterator, Callable from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -37,10 +37,17 @@ class _MockSession: - """Minimal session mock providing user_id.""" + """Minimal session mock providing user_id, id, and state.""" - def __init__(self, user_id: str = _DEFAULT_USER_ID): + def __init__( + self, + user_id: str = _DEFAULT_USER_ID, + session_id: str | None = None, + state: dict[str, Any] | None = None, + ): self.user_id = user_id + self.id = session_id + self.state = state if state is not None else {} class MockToolContext: @@ -50,11 +57,13 @@ def __init__( self, tool_confirmation: ToolConfirmation | None = None, user_id: str = _DEFAULT_USER_ID, + session_id: str | None = None, + session_state: dict[str, Any] | None = None, ): self.state: dict[str, Any] = {} self.function_call_id = "outer_fc_1" self.tool_confirmation = tool_confirmation - self.session = _MockSession(user_id) + self.session = _MockSession(user_id, session_id=session_id, state=session_state) self._confirmations: dict[str, ToolConfirmation] = {} def request_confirmation(self, *, hint: str = "", payload: dict | None = None) -> None: @@ -109,12 +118,17 @@ async def _async_yield(*items) -> AsyncIterator: yield item -def _make_tool(*, httpx_client: httpx.AsyncClient | None = None) -> KAgentRemoteA2ATool: +def _make_tool( + *, + httpx_client: httpx.AsyncClient | None = None, + header_provider: Callable[[Any], dict[str, str]] | None = None, +) -> KAgentRemoteA2ATool: return KAgentRemoteA2ATool( name="k8s_agent", description="K8s subagent", agent_card_url="http://k8s-agent/.well-known/agent.json", httpx_client=httpx_client, + header_provider=header_provider, ) @@ -487,3 +501,129 @@ async def test_get_tools_returns_the_tool(self): assert isinstance(tools[0], KAgentRemoteA2ATool) assert tools[0].name == "my_agent" await mock_client.aclose() + + +# --------------------------------------------------------------------------- +# Conversation lineage header tests +# --------------------------------------------------------------------------- + + +class TestLineageHeaderPropagation: + """Tests for the parent/root context_id headers built by + ``KAgentRemoteA2ATool._build_call_context``. + + The lineage headers let a remote A2A peer correlate this turn with the + originating chat conversation. ``x-kagent-parent-context-id`` is always + the immediate caller's session id; ``x-kagent-root-context-id`` is + forwarded unchanged from the upstream caller when present, or stamped + with the immediate caller's own id when this agent is the root of the + chain. + """ + + def _build_state(self, tool: KAgentRemoteA2ATool, ctx: MockToolContext) -> dict[str, Any]: + return tool._build_call_context(ctx).state + + def test_root_agent_stamps_own_id_as_root_and_parent(self): + """An agent at the top of the chain (no inbound lineage headers) sets + both parent and root to its own session id.""" + tool = _make_tool() + ctx = MockToolContext(session_id="chat-1", session_state={"headers": {}}) + + state = self._build_state(tool, ctx) + extras = state.get("_a2a_extra_headers", {}) + + assert extras.get("x-kagent-parent-context-id") == "chat-1" + assert extras.get("x-kagent-root-context-id") == "chat-1" + + def test_mid_chain_forwards_root_and_overrides_parent(self): + """An agent in the middle of an A2A chain forwards the root header + unchanged from the inbound request and replaces parent with its own + session id.""" + tool = _make_tool() + ctx = MockToolContext( + session_id="router-2", + session_state={ + "headers": { + "x-kagent-parent-context-id": "chat-1", + "x-kagent-root-context-id": "chat-1", + } + }, + ) + + state = self._build_state(tool, ctx) + extras = state.get("_a2a_extra_headers", {}) + + assert extras.get("x-kagent-parent-context-id") == "router-2" + assert extras.get("x-kagent-root-context-id") == "chat-1" + + def test_legacy_inbound_with_only_parent_promotes_to_root(self): + """If the upstream caller predates the root header, promote its + parent header to root so downstream peers still see a stable + chain-root identifier.""" + tool = _make_tool() + ctx = MockToolContext( + session_id="router-2", + session_state={"headers": {"x-kagent-parent-context-id": "legacy-1"}}, + ) + + state = self._build_state(tool, ctx) + extras = state.get("_a2a_extra_headers", {}) + + assert extras.get("x-kagent-parent-context-id") == "router-2" + assert extras.get("x-kagent-root-context-id") == "legacy-1" + + def test_no_session_id_emits_no_lineage_headers(self): + """When the caller cannot resolve a session id (e.g. a stub + ToolContext), the outbound request gets no lineage headers — matches + pre-feature behavior so this change is non-breaking for callers that + don't yet plumb session ids.""" + tool = _make_tool() + ctx = MockToolContext(session_id=None, session_state={"headers": {}}) + + state = self._build_state(tool, ctx) + extras = state.get("_a2a_extra_headers") + + assert extras is None or ( + "x-kagent-parent-context-id" not in extras and "x-kagent-root-context-id" not in extras + ) + + def test_header_provider_overrides_lineage(self): + """A constructor-supplied header_provider can override lineage + headers — escape hatch for custom propagation logic.""" + tool = _make_tool(header_provider=lambda _ctx: {"x-kagent-root-context-id": "forced"}) + ctx = MockToolContext(session_id="router-2", session_state={"headers": {}}) + + state = self._build_state(tool, ctx) + extras = state.get("_a2a_extra_headers", {}) + + assert extras.get("x-kagent-parent-context-id") == "router-2" + assert extras.get("x-kagent-root-context-id") == "forced" + + async def test_lineage_headers_reach_outbound_http(self): + """End-to-end: lineage headers built by _build_call_context flow + through _SubagentInterceptor onto the outbound HTTP request.""" + from a2a.client.middleware import ClientCallContext + + tool = _make_tool() + ctx = MockToolContext( + session_id="router-2", + session_state={ + "headers": { + "x-kagent-root-context-id": "chat-1", + } + }, + ) + call_ctx = tool._build_call_context(ctx) + + interceptor = _SubagentInterceptor() + _, http_kwargs = await interceptor.intercept( + method_name="message/send", + request_payload={}, + http_kwargs={}, + agent_card=None, + context=ClientCallContext(state=call_ctx.state), + ) + headers = http_kwargs.get("headers", {}) + + assert headers.get("x-kagent-parent-context-id") == "router-2" + assert headers.get("x-kagent-root-context-id") == "chat-1"