From 19d7450f33f9f9fa443838975a97023180a0abb8 Mon Sep 17 00:00:00 2001 From: Hubert Zub Date: Tue, 5 May 2026 14:35:04 +0200 Subject: [PATCH] feat(appkit): supervisor api adapter --- apps/dev-playground/server/index.ts | 36 +- docs/docs/plugins/agents.md | 75 ++ packages/appkit/src/agents/supervisor-api.ts | 577 +++++++++++++++ .../src/agents/tests/supervisor-api.test.ts | 662 ++++++++++++++++++ packages/appkit/src/beta.ts | 10 + .../appkit/src/connectors/serving/client.ts | 63 +- packages/appkit/src/stream/index.ts | 1 + packages/appkit/src/stream/sse-reader.ts | 114 +++ .../src/stream/tests/sse-reader.test.ts | 182 +++++ 9 files changed, 1704 insertions(+), 16 deletions(-) create mode 100644 packages/appkit/src/agents/supervisor-api.ts create mode 100644 packages/appkit/src/agents/tests/supervisor-api.test.ts create mode 100644 packages/appkit/src/stream/sse-reader.ts create mode 100644 packages/appkit/src/stream/tests/sse-reader.test.ts diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index ecbd18e78..67187dcbe 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -11,7 +11,12 @@ import { serving, WRITE_ACTIONS, } from "@databricks/appkit"; -import { agents, createAgent, tool } from "@databricks/appkit/beta"; +import { + agents, + createAgent, + fromSupervisorApi, + tool, +} from "@databricks/appkit/beta"; import { WorkspaceClient } from "@databricks/sdk-experimental"; import { z } from "zod"; import { lakebaseExamples } from "./lakebase-examples-plugin"; @@ -68,6 +73,33 @@ const helper = createAgent({ }, }); +// Supervisor API demo agent. Tools are configured on the adapter (the SA +// endpoint executes them server-side), not on the createAgent definition. +// Uncomment a `supervisorTools.*` entry (and import 'supervisorTools' from +// '@databricks/appkit/beta') to give the model real powers. 
+// +// `fromSupervisorApi` is async, so `model` receives an adapter promise, which +// `createAgent` accepts. A misconfigured workspace (missing host, bad +// credentials) therefore rejects that promise rather than throwing at module +// init; `await` the factory here instead if you want it to fail fast. +const supervisor = createAgent({ + instructions: + "You are an assistant powered by the Databricks Supervisor API.", + model: fromSupervisorApi({ + model: "databricks-claude-sonnet-4-5", + tools: [ + // supervisorTools.genieSpace( + // "01ABCDEF12345678", + // "NYC taxi trip records and zones", + // ), + // supervisorTools.ucFunction( + // "main.default.add", + // "Adds two integers and returns the sum.", + // ), + ], + }), +}); + /* * Smart-Dashboard agents. * @@ -385,7 +417,7 @@ createApp({ }), serving(), agents({ - agents: { helper, sql_analyst, dashboard_pilot }, + agents: { helper, sql_analyst, dashboard_pilot, supervisor }, // `query` (markdown dispatcher) + `sql_analyst` + `dashboard_pilot` // wire the /smart-dashboard route. `insights` and `anomaly` are // ephemeral markdown agents auto-fired by the route's AgentSidebar. diff --git a/docs/docs/plugins/agents.md b/docs/docs/plugins/agents.md index 0ba2ab301..c228551e2 100644 --- a/docs/docs/plugins/agents.md +++ b/docs/docs/plugins/agents.md @@ -16,6 +16,8 @@ This page covers the full lifecycle. For the hand-written primitives (`tool()`, The agents plugin drives the LLM over Server-Sent Events. Foundation Model APIs (Claude, Llama, GPT, etc.) and other chat-style endpoints support streaming and work out of the box. Custom model endpoints that return a single JSON response (e.g. typical `sklearn` or MLflow `pyfunc` deployments) do **not** stream — pointing an agent at one will fail with "Response body is null — streaming not supported" on the first turn. 
If you list a serving endpoint in `apps init`, pick one whose model implements the chat-completions streaming protocol; the agents plugin reads its name from `DATABRICKS_SERVING_ENDPOINT_NAME` whenever an agent doesn't pin `model:` itself. For the non-streaming path against a custom endpoint, use the `serving` plugin's `/invoke` route with `useServingInvoke` instead. + +Or skip serving-endpoint setup entirely with the managed [Supervisor API adapter](#managed-agents-the-supervisor-api-adapter) (beta). ::: ## Install @@ -217,6 +219,79 @@ const result = await runAgent(classifier, { Hosted tools (MCP) are still `agents()`-only since they require the live MCP client. Plugin tool dispatch in standalone mode runs as the service principal (no OBO) and **bypasses the agents-plugin approval gate** — treat standalone runAgent as a trusted-prompt environment (CI, batch eval, internal scripts), not as an exposed user-facing surface. +## Managed agents: the Supervisor API adapter + +`fromSupervisorApi` (beta) is the zero-config way to run an agent: instead of provisioning and pointing at a model-serving endpoint, you run the agentic loop in the Databricks workspace by targeting the AI Gateway Responses API (`/ai-gateway/mlflow/v1/responses`), which runs the LLM — and any hosted tools — as a managed service on Databricks. No `DATABRICKS_SERVING_ENDPOINT_NAME`, no stream-capability check, no JS tool plumbing for the common cases. 
+ +The minimal agent is one extra line versus a markdown agent: + +```ts +import { createApp, createAgent } from "@databricks/appkit"; +import { agents, fromSupervisorApi } from "@databricks/appkit/beta"; + +await createApp({ + plugins: [ + agents({ + agents: { + assistant: createAgent({ + instructions: "You are a helpful assistant.", + model: fromSupervisorApi({ model: "databricks-claude-sonnet-4-5" }), + }), + }, + }), + ], +}); +``` + +`createAgent({ model })` already accepts adapters and adapter promises in addition to the model-name string used in earlier examples, so you can drop the factory result straight in. `fromSupervisorApi` resolves credentials through the SDK chain (`DATABRICKS_HOST`, OAuth, PAT, …); pass `workspaceClient` to reuse an existing client. + +### Hosted tools + +Expose Genie spaces, Unity Catalog functions/connections, Knowledge Assistants, or other AppKit apps to the model by listing them on the adapter — execution stays server-side, you write no tool code: + +```ts +import { fromSupervisorApi, supervisorTools } from "@databricks/appkit/beta"; + +const model = fromSupervisorApi({ + model: "databricks-claude-sonnet-4-5", + tools: [ + supervisorTools.genieSpace( + "01ABCDEF12345678", + "NYC taxi trip records and zones", + ), + supervisorTools.ucFunction( + "main.default.add", + "Adds two integers and returns the sum.", + ), + ], +}); +``` + +`description` is **required and non-empty** — the LLM uses it to route between tools, so two Genie spaces both labelled "Genie space" will be indistinguishable. 
+ +| Factory | Tool kind | Identifier | +|---|---|---| +| `supervisorTools.genieSpace(id, description)` | Genie space | space id | +| `supervisorTools.ucFunction(name, description)` | Unity Catalog function | three-part name | +| `supervisorTools.knowledgeAssistant(id, description)` | Knowledge Assistant | assistant id | +| `supervisorTools.app(name, description)` | Databricks App | app name | +| `supervisorTools.ucConnection(name, description)` | UC connection | connection name | + +### What does *not* apply to Supervisor-API agents + +The managed runtime owns its own tool execution, so the adapter intentionally **ignores the agents-plugin tool index**. For any agent whose `model:` is a Supervisor adapter: + +- Tools wired via markdown `tools:` or the `tools(plugins)` function form are not exposed to the model — declare hosted tools via `fromSupervisorApi({ tools: […] })` instead. +- The **human-in-the-loop approval gate** does not fire (tool calls never enter the Node process; `effect: "destructive"` annotations on plugin tools are irrelevant here). +- `limits.maxToolCalls` is not enforced (the managed runtime accounts for its own calls). +- Per-call **OBO** does not apply to hosted tools; they run with the credentials the managed runtime uses for the target resource. + +Standard-adapter agents and Supervisor-API agents can coexist in the same `agents({ agents: { … } })` map and can be composed as sub-agents (Level 4) — only the agent whose `model:` points at a Supervisor adapter is exempt from the items above. + +:::note Recovery path for non-streaming tool turns +Some hosted tool kinds return their final assistant text without incremental `output_text.delta` events. The adapter has a recovery path that pulls the text out of the `response.completed` event's `response.output[]` so the turn is not silently empty. Set `DEBUG=appkit:agents:supervisor-api` to log when that recovery path engages. If a turn still ends with no text, the adapter logs a warning that includes the per-turn event-type histogram — no `DEBUG` needed. 
+::: + ## Configuration reference ```ts diff --git a/packages/appkit/src/agents/supervisor-api.ts b/packages/appkit/src/agents/supervisor-api.ts new file mode 100644 index 000000000..228eb8be9 --- /dev/null +++ b/packages/appkit/src/agents/supervisor-api.ts @@ -0,0 +1,577 @@ +import type { + AgentAdapter, + AgentEvent, + AgentInput, + AgentRunContext, + Message, + ResponseStreamEvent, +} from "shared"; +import { type ApiClientLike, streamPath } from "../connectors/serving/client"; +import { createLogger } from "../logging/logger"; +import { readSseEvents } from "../stream"; + +const logger = createLogger("agents:supervisor-api"); + +/** + * Transport shim: given a request body, returns the raw SSE byte stream from + * the Supervisor API endpoint. Injected at construction time so callers can + * swap in the workspace SDK (the {@link fromSupervisorApi} factory), a bare + * `fetch` (a reverse proxy / mock), or a test fake. Mirrors `StreamBody` in + * `agents/databricks.ts` so both adapters share one transport surface. + */ +type StreamBody = ( + body: Record, + signal?: AbortSignal, +) => Promise>; + +/** + * Structural shape of a Databricks SDK client used by {@link fromSupervisorApi}. + * Only what we need: `apiClient.request` for streaming and + * `config.ensureResolved` to materialise the host/credentials. + */ +interface WorkspaceClientLike extends ApiClientLike { + config: { ensureResolved(): Promise }; +} + +// --------------------------------------------------------------------------- +// Supervisor API tool surface (wire format) +// --------------------------------------------------------------------------- + +/** + * Tools supported by the Databricks AI Gateway Responses API. The shapes match + * the wire format the endpoint expects, so the adapter passes the array + * straight into the request body. + * + * Prefer the {@link supervisorTools} factories — they fill in the + * SA-validation-bug workaround for `description` (must be non-empty). 
+ */ +export type SupervisorTool = + | { type: "genie_space"; genie_space: { id: string; description: string } } + | { type: "uc_function"; uc_function: { name: string; description: string } } + | { + type: "knowledge_assistant"; + knowledge_assistant: { + knowledge_assistant_id: string; + description: string; + }; + } + | { type: "app"; app: { name: string; description: string } } + | { + type: "uc_connection"; + uc_connection: { name: string; description: string }; + }; + +/** + * Concise factories for declaring Supervisor API tools. + * + * `description` is required: SA's protobuf validation rejects `null`/`""`, + * AND the LLM running on SA reads this string to decide when to route to + * the tool. Two genie spaces both labelled "Genie space" give the model + * nothing to discriminate on, so callers always own the routing hint. + * + * @example + * ```ts + * fromSupervisorApi({ + * model: "databricks-claude-sonnet-4", + * tools: [ + * supervisorTools.genieSpace( + * "01ABCDEF12345678", + * "NYC taxi trip records and zones", + * ), + * supervisorTools.ucFunction( + * "main.default.add", + * "Adds two integers and returns the sum.", + * ), + * ], + * }); + * ``` + */ +export const supervisorTools = { + genieSpace: (id: string, description: string): SupervisorTool => ({ + type: "genie_space", + genie_space: { id, description }, + }), + ucFunction: (name: string, description: string): SupervisorTool => ({ + type: "uc_function", + uc_function: { name, description }, + }), + knowledgeAssistant: ( + knowledgeAssistantId: string, + description: string, + ): SupervisorTool => ({ + type: "knowledge_assistant", + knowledge_assistant: { + knowledge_assistant_id: knowledgeAssistantId, + description, + }, + }), + app: (name: string, description: string): SupervisorTool => ({ + type: "app", + app: { name, description }, + }), + ucConnection: (name: string, description: string): SupervisorTool => ({ + type: "uc_connection", + uc_connection: { name, description }, + }), +}; + 
+// --------------------------------------------------------------------------- +// Adapter +// --------------------------------------------------------------------------- + +export interface SupervisorApiAdapterOptions { + /** + * Model identifier to pass in the request body + * (e.g. "databricks-claude-sonnet-4"). + */ + model: string; + /** + * Hosted tools the SA endpoint should expose to the model. Use the + * {@link supervisorTools} factories for the most common shapes. + */ + tools?: SupervisorTool[]; + /** + * A WorkspaceClient (or structural equivalent) used for host resolution + * and per-request authentication. When omitted, a `WorkspaceClient({})` + * is created internally using the default SDK credential chain + * (`DATABRICKS_HOST`, OAuth, PAT, etc.). + */ + workspaceClient?: WorkspaceClientLike; +} + +export interface SupervisorApiAdapterCtorOptions { + streamBody: StreamBody; + model: string; + tools?: SupervisorTool[]; +} + +/** + * Adapter that calls the Databricks AI Gateway Responses API + * (`/ai-gateway/mlflow/v1/responses`). + * + * Streams SSE events in the OpenAI Responses API wire format and maps them + * to the AppKit `AgentEvent` protocol. Tool execution is handled + * server-side, so the adapter ignores the agents-plugin tool index. + * + * Authentication is handled via the Databricks SDK credential chain — the + * same mechanism used by `DatabricksAdapter.fromModelServing`. The transport + * is injected via {@link SupervisorApiAdapterCtorOptions.streamBody}; the + * {@link fromSupervisorApi} factory wires it through the SDK's + * `apiClient.request({ raw: true })`. + * + * Set `DEBUG=appkit:agents:supervisor-api` to log the outbound request + * shape (model, instructions length, input shape, tool count) and to be + * notified when the recovery path engages (no incremental deltas, text + * pulled from `response.completed.output[]`). 
The no-delta warning includes + * a per-turn event-type histogram and the SA-reported status/error/ + * incomplete_details, so it's already actionable without DEBUG. + * + * @example + * ```ts + * import { createApp } from "@databricks/appkit"; + * import { + * agents, createAgent, + * fromSupervisorApi, supervisorTools, + * } from "@databricks/appkit/beta"; + * + * const adapter = await fromSupervisorApi({ + * model: "databricks-claude-sonnet-4", + * tools: [ + * supervisorTools.genieSpace( + * "01ABCDEF12345678", + * "NYC taxi trip records and zones", + * ), + * ], + * }); + * + * await createApp({ + * plugins: [ + * agents({ + * agents: { + * assistant: createAgent({ + * instructions: "You are a helpful assistant.", + * model: adapter, + * }), + * }, + * }), + * ], + * }); + * ``` + */ +export class SupervisorApiAdapter implements AgentAdapter { + private streamBody: StreamBody; + private model: string; + private tools: SupervisorTool[]; + + constructor(options: SupervisorApiAdapterCtorOptions) { + this.streamBody = options.streamBody; + this.model = options.model; + this.tools = options.tools ?? []; + } + + async *run( + input: AgentInput, + context: AgentRunContext, + ): AsyncGenerator { + if (context.signal?.aborted) return; + + yield { type: "status", status: "running" }; + + const { instructions, input: payloadInput } = this.buildInput( + input.messages, + ); + yield* this.streamResponse(instructions, payloadInput, context.signal); + } + + private async *streamResponse( + instructions: string | undefined, + input: ResponseInput, + signal?: AbortSignal, + ): AsyncGenerator { + const body: Record = { + model: this.model, + input, + stream: true, + }; + if (instructions) { + body.instructions = instructions; + } + // SA's protobuf validation rejects `tools: []` and `tools: null`. Only + // include the field when at least one tool is configured. 
+ if (this.tools.length > 0) { + body.tools = this.tools; + } + + logger.debug( + "model=%s instructionsLen=%d inputType=%s tools=%d", + this.model, + instructions?.length ?? 0, + typeof input === "string" ? "string" : `array[${input.length}]`, + this.tools.length, + ); + + let stream: ReadableStream; + try { + stream = await this.streamBody(body, signal); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + logger.warn("Supervisor API request failed: %s", message); + yield { + type: "status", + status: "error", + error: `Supervisor API error: ${message}`, + }; + return; + } + + let receivedAnyDelta = false; + // Tracks `item_id`s we've already streamed text deltas for. Used by + // `mapEvent` to fall back to the final item text on `output_item.done` + // only when no incremental deltas streamed for that item — avoids + // double-emitting text when SA does both delta and done. + const streamedItemIds = new Set(); + // Histogram of received event types — surfaced in the no-delta warning + // so it's actionable without re-running with DEBUG. + const eventCounts = new Map(); + // Set to true once we've yielded a terminal `{status:"error"}` event so + // the recovery / completion / no-delta-warning blocks below all bail + // out — the consumer's already seen the terminal status, anything + // further would contradict the protocol's terminal-event semantics. + let terminated = false; + // Diagnostic snapshot of the last `response.completed` event. SA stuffs + // the final assistant message into `response.output[]` even when it + // didn't emit any deltas (e.g. when a tool failed or the model produced + // nothing). Keeping it lets us recover the text and surface useful + // errors instead of a silent empty turn. 
+ let lastCompleted: + | { + status?: string; + output?: Array<{ + type?: string; + content?: Array<{ type?: string; text?: string }>; + }>; + error?: unknown; + incomplete_details?: unknown; + } + | undefined; + + for await (const { event, data } of readSseEvents(stream, signal)) { + if (data === "[DONE]") continue; + + let parsed: Record; + try { + parsed = JSON.parse(data); + } catch (err) { + logger.debug( + "Failed to parse SSE data line: %s (%O)", + data.slice(0, 200), + err, + ); + continue; + } + + const eventType = event || (parsed.type as string) || ""; + eventCounts.set(eventType, (eventCounts.get(eventType) ?? 0) + 1); + + // `response.completed` is held back until after the loop so we can + // synthesise a `message_delta` from `response.output[]` when the + // stream produced no incremental deltas (intermittent SA behaviour). + // Emitting `complete` first would let UIs finalise the turn before the + // recovered text arrives. + if (eventType === "response.completed") { + lastCompleted = parsed.response as typeof lastCompleted; + continue; + } + + const out = mapEvent(eventType, parsed, streamedItemIds); + if (out) { + if (out.type === "message_delta") receivedAnyDelta = true; + yield out; + if (out.type === "status" && out.status === "error") { + terminated = true; + break; + } + } + } + + if (signal?.aborted) return; + + if (eventCounts.size === 0) { + logger.warn( + "Supervisor API stream closed without emitting any SSE events.", + ); + return; + } + + if (terminated) return; + + // Recovery path: no deltas streamed but SA finished — pull the assistant + // text out of `response.completed.response.output[]`. 
+ if (!receivedAnyDelta) { + const recovered = extractTextFromCompletedResponse(lastCompleted); + if (recovered) { + logger.debug( + "Recovered %d chars from response.completed.output[]", + recovered.length, + ); + yield { type: "message_delta", content: recovered }; + receivedAnyDelta = true; + } + } + + if (eventCounts.has("response.completed")) { + yield { type: "status", status: "complete" }; + } + + if (!receivedAnyDelta) { + const histogram = [...eventCounts.entries()] + .map(([t, n]) => `${t}=${n}`) + .join(", "); + const completedError = lastCompleted?.error + ? JSON.stringify(lastCompleted.error) + : undefined; + const completedIncomplete = lastCompleted?.incomplete_details + ? JSON.stringify(lastCompleted.incomplete_details) + : undefined; + logger.warn( + "Supervisor API stream completed without any output_text deltas. " + + "events={%s} completed.status=%s completed.error=%s completed.incomplete=%s", + histogram, + lastCompleted?.status ?? "", + completedError ?? "", + completedIncomplete ?? "", + ); + } + } + + /** + * Splits the agent's message list into a Responses-API payload. System + * messages are concatenated (in order) into the top-level `instructions` + * field; user/assistant turns become `input` (as a plain string for the + * common single-user-turn case, otherwise as `{role,content}[]`). Tool-role + * messages are skipped — SA owns its own tool history server-side, so + * re-feeding our tool-result records would only confuse it. + */ + private buildInput(messages: Message[]): { + instructions: string | undefined; + input: ResponseInput; + } { + const instructionsParts: string[] = []; + const turns: Array<{ + role: "user" | "assistant" | "system"; + content: string; + }> = []; + + for (const m of messages) { + if (m.role === "system") instructionsParts.push(m.content); + else if (m.role !== "tool") + turns.push({ role: m.role, content: m.content }); + } + + const instructions = instructionsParts.length + ? 
instructionsParts.join("\n\n") + : undefined; + + if (turns.length === 1 && turns[0].role === "user") { + return { instructions, input: turns[0].content }; + } + return { instructions, input: turns }; + } +} + +type ResponseInput = + | string + | Array<{ role: "user" | "assistant" | "system"; content: string }>; + +/** + * Pulls the final assistant text out of the `response` payload attached to a + * `response.completed` event. SA always materialises the full response there, + * so this is our last-resort recovery path when the stream produced neither + * `output_text.delta` nor an actionable `output_item.done` (observed + * intermittently with tool-enabled SA agents). + */ +function extractTextFromCompletedResponse( + response: + | { + output?: Array<{ + type?: string; + content?: Array<{ type?: string; text?: string }>; + }>; + } + | undefined, +): string { + if (!response?.output) return ""; + let text = ""; + for (const item of response.output) { + if (item?.type !== "message" || !Array.isArray(item.content)) continue; + for (const part of item.content) { + if (part?.type === "output_text" && typeof part.text === "string") { + text += part.text; + } + } + } + return text; +} + +function mapEvent( + eventType: string, + data: Record, + streamedItemIds: Set, +): AgentEvent | null { + // The cast restricts the switch domain to the closed wire-event union + // exported by `shared`, so typos in case clauses (e.g. `response.faled`) + // become compile errors instead of silent string mismatches. Unknown + // event names still fall through to `default` at runtime — we don't + // require exhaustive matching since SA emits more lifecycle events + // than we care to map. + switch (eventType as ResponseStreamEvent["type"]) { + case "response.output_text.delta": { + const itemId = data.item_id as string | undefined; + if (itemId) streamedItemIds.add(itemId); + return { type: "message_delta", content: (data.delta as string) ?? 
"" }; + } + + // `response.completed` is intentionally absent: `streamResponse` holds + // it back so it can synthesise a delta from `response.output[]` when + // the stream produced none, then emits `{status:"complete"}` itself. + + case "response.failed": + return { type: "status", status: "error", error: "Response failed" }; + + case "error": { + const errMsg = + typeof data.error === "string" + ? data.error + : JSON.stringify(data.error ?? "Unknown error"); + return { type: "status", status: "error", error: errMsg }; + } + + case "response.output_item.done": { + const item = data.item as + | { + id?: string; + type?: string; + content?: Array<{ text?: string; type?: string }>; + } + | undefined; + + if (item?.id === "error") { + const errText = item.content?.[0]?.text ?? "Unknown tool error from SA"; + return { type: "status", status: "error", error: errText }; + } + + // Fallback: when SA produces a tool-driven response (e.g. Genie space), + // it often omits `response.output_text.delta` events and only emits the + // final assistant message via `output_item.done`. Surface that text as + // a single delta so the UI sees the answer. + if ( + item?.type === "message" && + item.id && + !streamedItemIds.has(item.id) + ) { + const text = (item.content ?? []) + .map((c) => (c.type === "output_text" ? (c.text ?? "") : "")) + .join(""); + if (text.length > 0) { + streamedItemIds.add(item.id); + return { type: "message_delta", content: text }; + } + } + return null; + } + + // All other event types are intentionally ignored. Notable lifecycle + // events we drop on the floor: `response.created`, `response.in_progress`, + // `response.output_text.done`, `response.output_item.added`, + // `response.content_part.added`, `response.content_part.done`. + default: + return null; + } +} + +/** + * Creates an {@link AgentAdapter} backed by the Databricks AI Gateway + * Responses API (`/ai-gateway/mlflow/v1/responses`). 
+ * + * Auth uses `options.workspaceClient` when provided; otherwise the SDK's + * default credential chain (DATABRICKS_HOST, DATABRICKS_TOKEN, OAuth, etc.). + * + * @example + * ```ts + * import { + * fromSupervisorApi, + * supervisorTools, + * } from "@databricks/appkit/beta"; + * + * const adapter = await fromSupervisorApi({ + * model: "databricks-claude-sonnet-4", + * tools: [ + * supervisorTools.genieSpace( + * "01ABCDEF12345678", + * "NYC taxi trip records and zones", + * ), + * ], + * }); + * ``` + */ +export async function fromSupervisorApi( + options: SupervisorApiAdapterOptions, +): Promise { + let client = options.workspaceClient; + if (!client) { + const sdk = await import("@databricks/sdk-experimental"); + client = new sdk.WorkspaceClient({}) as unknown as WorkspaceClientLike; + } + + await client.config.ensureResolved(); + + // Capture the resolved client so the closure doesn't depend on the outer + // `let` binding being reassigned later. + const resolved = client; + return new SupervisorApiAdapter({ + streamBody: (body, signal) => + streamPath(resolved, "/ai-gateway/mlflow/v1/responses", body, signal), + model: options.model, + tools: options.tools ?? 
[], + }); +} diff --git a/packages/appkit/src/agents/tests/supervisor-api.test.ts b/packages/appkit/src/agents/tests/supervisor-api.test.ts new file mode 100644 index 000000000..9606b1c6a --- /dev/null +++ b/packages/appkit/src/agents/tests/supervisor-api.test.ts @@ -0,0 +1,662 @@ +import type { AgentEvent, AgentInput } from "shared"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + fromSupervisorApi, + SupervisorApiAdapter, + type SupervisorTool, + supervisorTools, +} from "../supervisor-api"; + +function createReadableStream(chunks: string[]): ReadableStream { + const encoder = new TextEncoder(); + let i = 0; + return new ReadableStream({ + pull(controller) { + if (i < chunks.length) { + controller.enqueue(encoder.encode(chunks[i])); + i++; + } else { + controller.close(); + } + }, + }); +} + +function sseEvent(eventName: string, data: Record): string { + return `event: ${eventName}\ndata: ${JSON.stringify(data)}\n\n`; +} + +/** + * Captures the body the adapter posts and returns a fake stream of SSE + * chunks. Mirrors the `streamBody` test pattern used by `DatabricksAdapter`. 
+ */ +function makeStreamBody(chunks: string[]): { + streamBody: ReturnType; + lastBody: () => Record | undefined; +} { + let captured: Record | undefined; + const streamBody = vi.fn(async (body: Record) => { + captured = body; + return createReadableStream(chunks); + }); + return { streamBody, lastBody: () => captured }; +} + +function createInput(): AgentInput { + return { + messages: [ + { id: "1", role: "user", content: "Hello", createdAt: new Date() }, + ], + tools: [], + threadId: "thread-1", + }; +} + +async function collect( + gen: AsyncGenerator, +): Promise { + const out: AgentEvent[] = []; + for await (const e of gen) out.push(e); + return out; +} + +describe("supervisorTools factories", () => { + test("genieSpace produces correct wire shape", () => { + expect(supervisorTools.genieSpace("space123", "NYC taxi data")).toEqual({ + type: "genie_space", + genie_space: { id: "space123", description: "NYC taxi data" }, + }); + }); + + test("ucFunction produces correct wire shape", () => { + expect( + supervisorTools.ucFunction("main.default.add", "Adds two integers."), + ).toEqual({ + type: "uc_function", + uc_function: { + name: "main.default.add", + description: "Adds two integers.", + }, + }); + }); + + test("knowledgeAssistant maps id into knowledge_assistant_id", () => { + expect( + supervisorTools.knowledgeAssistant("ka-1", "Internal docs Q&A"), + ).toEqual({ + type: "knowledge_assistant", + knowledge_assistant: { + knowledge_assistant_id: "ka-1", + description: "Internal docs Q&A", + }, + }); + }); + + test("app produces correct wire shape", () => { + expect(supervisorTools.app("my-app", "Demo Databricks app.")).toEqual({ + type: "app", + app: { name: "my-app", description: "Demo Databricks app." 
}, + }); + }); + + test("ucConnection produces correct wire shape", () => { + expect( + supervisorTools.ucConnection("my-conn", "Connection to external DB."), + ).toEqual({ + type: "uc_connection", + uc_connection: { + name: "my-conn", + description: "Connection to external DB.", + }, + }); + }); +}); + +describe("SupervisorApiAdapter", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("posts model, input, tools, and stream:true through streamBody", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { delta: "Hi" }), + sseEvent("response.completed", {}), + ]); + + const tools: SupervisorTool[] = [ + supervisorTools.genieSpace("g1", "Test genie space"), + ]; + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + tools, + }); + + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + + expect(streamBody).toHaveBeenCalledTimes(1); + expect(lastBody()).toMatchObject({ + model: "databricks-claude-sonnet-4", + input: "Hello", + stream: true, + tools, + }); + }); + + test("omits the tools field entirely when no tools are configured", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + expect(lastBody()).not.toHaveProperty("tools"); + }); + + test("hoists system messages into the top-level instructions field", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect( + adapter.run( + { + messages: [ + { + id: "s", + role: "system", + content: "Be terse.", + createdAt: new Date(), + }, + { id: "u", role: "user", content: "Hi", createdAt: new 
Date() }, + ], + tools: [], + threadId: "t", + }, + { executeTool: vi.fn() }, + ), + ); + const body = lastBody(); + expect(body?.instructions).toBe("Be terse."); + expect(body?.input).toBe("Hi"); + }); + + test("omits instructions when there is no system message", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + expect(lastBody()).not.toHaveProperty("instructions"); + }); + + test("emits message_delta and complete on the happy path", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { delta: "Hello" }), + sseEvent("response.output_text.delta", { delta: " world" }), + sseEvent("response.completed", {}), + ]); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Hello" }, + { type: "message_delta", content: " world" }, + { type: "status", status: "complete" }, + ]); + }); + + test("maps response.failed to a status:error event", async () => { + const { streamBody } = makeStreamBody([sseEvent("response.failed", {})]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Response failed", + }); + }); + + test("maps top-level error events", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("error", { error: "rate limited" }), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: 
"databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "rate limited", + }); + }); + + test("maps response.output_item.done with id:'error' to status:error", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_item.done", { + item: { + id: "error", + content: [{ text: "Tool execution failed" }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Tool execution failed", + }); + }); + + test("falls back to output_item.done text when no deltas streamed (tool-driven SA response)", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_item.added", { + item: { type: "message", id: "msg-1", role: "assistant", content: [] }, + }), + sseEvent("response.output_item.done", { + item: { + type: "message", + id: "msg-1", + status: "completed", + role: "assistant", + content: [{ type: "output_text", text: "Genie says hi." }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Genie says hi." 
}, + { type: "status", status: "complete" }, + ]); + }); + + test("does not double-emit when both deltas and output_item.done arrive for the same item", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { + item_id: "msg-1", + delta: "Hello", + }), + sseEvent("response.output_text.delta", { + item_id: "msg-1", + delta: " world", + }), + sseEvent("response.output_item.done", { + item: { + type: "message", + id: "msg-1", + status: "completed", + role: "assistant", + content: [{ type: "output_text", text: "Hello world" }], + }, + }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Hello" }, + { type: "message_delta", content: " world" }, + { type: "status", status: "complete" }, + ]); + }); + + test("emits status:error when the underlying streamBody throws", async () => { + const streamBody = vi + .fn() + .mockRejectedValue(new Error("Supervisor API error (500): boom")); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "status", + status: "error", + error: "Supervisor API error: Supervisor API error (500): boom", + }); + }); + + test("short-circuits when the signal is already aborted", async () => { + const streamBody = vi.fn(); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + + const controller = new AbortController(); + controller.abort(); + + const events = await collect( + adapter.run(createInput(), { + executeTool: vi.fn(), + signal: controller.signal, + }), + ); + + 
expect(events).toEqual([]); + expect(streamBody).not.toHaveBeenCalled(); + }); + + test("multi-turn input (user + assistant + user) is sent as a structured array", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + + await collect( + adapter.run( + { + messages: [ + { id: "u1", role: "user", content: "Hi", createdAt: new Date() }, + { + id: "a", + role: "assistant", + content: "Hello!", + createdAt: new Date(), + }, + { + id: "u2", + role: "user", + content: "Tell me more", + createdAt: new Date(), + }, + ], + tools: [], + threadId: "t", + }, + { executeTool: vi.fn() }, + ), + ); + + expect(lastBody()?.input).toEqual([ + { role: "user", content: "Hi" }, + { role: "assistant", content: "Hello!" }, + { role: "user", content: "Tell me more" }, + ]); + }); + + test("drops tool-role messages from the request payload", async () => { + const { streamBody, lastBody } = makeStreamBody([ + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + await collect( + adapter.run( + { + messages: [ + { id: "u", role: "user", content: "Hi", createdAt: new Date() }, + { + id: "t1", + role: "tool", + content: "(genie result)", + createdAt: new Date(), + }, + ], + tools: [], + threadId: "t", + }, + { executeTool: vi.fn() }, + ), + ); + expect(lastBody()?.input).toBe("Hi"); + }); + + test("recovers final assistant text from response.completed.output when no deltas streamed", async () => { + // Real-world flake: SA occasionally finishes a turn with zero + // `output_text.delta` events and no `output_item.done` for the message, + // but still mirrors the full assistant text in `response.completed`. + // Without recovery the UI sees a silent empty turn. 
+ const { streamBody } = makeStreamBody([ + sseEvent("response.created", {}), + sseEvent("response.in_progress", {}), + sseEvent("response.completed", { + response: { + status: "completed", + output: [ + { + type: "message", + id: "msg-x", + role: "assistant", + content: [ + { type: "output_text", text: "Recovered " }, + { type: "output_text", text: "answer." }, + ], + }, + ], + }, + }), + ]); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Recovered answer." }, + { type: "status", status: "complete" }, + ]); + }); + + test("does not recover from response.completed when deltas already streamed", async () => { + const { streamBody } = makeStreamBody([ + sseEvent("response.output_text.delta", { + item_id: "msg-x", + delta: "Hi", + }), + sseEvent("response.completed", { + response: { + status: "completed", + output: [ + { + type: "message", + id: "msg-x", + role: "assistant", + content: [{ type: "output_text", text: "Hi" }], + }, + ], + }, + }), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + const deltas = events.filter((e) => e.type === "message_delta"); + expect(deltas).toHaveLength(1); + expect(deltas[0]).toEqual({ type: "message_delta", content: "Hi" }); + }); + + test("treats response.failed as terminal: no events follow the error", async () => { + // SA may keep sending events after `response.failed` (and even a stray + // `response.completed`). The adapter must stop yielding once it has + // surfaced a terminal `status: error` so consumers don't see contradictory + // `message_delta`/`complete` events after the failure. 
+ const { streamBody } = makeStreamBody([ + sseEvent("response.failed", {}), + sseEvent("response.output_text.delta", { delta: "ignored" }), + sseEvent("response.completed", {}), + ]); + const adapter = new SupervisorApiAdapter({ + streamBody, + model: "databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "status", status: "error", error: "Response failed" }, + ]); + }); + + test("does not yield complete when the consumer aborts mid-stream", async () => { + // Stream that yields one delta, then waits forever — the consumer aborts + // after the first event arrives. The adapter must NOT subsequently yield + // a synthesised `complete` from a buffered `response.completed`. + const controller = new AbortController(); + const encoder = new TextEncoder(); + const stream = new ReadableStream({ + start(c) { + c.enqueue( + encoder.encode( + sseEvent("response.output_text.delta", { delta: "Hi" }), + ), + ); + }, + pull() { + return new Promise(() => { + /* never resolves until cancel() */ + }); + }, + }); + + const adapter = new SupervisorApiAdapter({ + streamBody: async () => stream, + model: "databricks-claude-sonnet-4", + }); + + const events: AgentEvent[] = []; + for await (const e of adapter.run(createInput(), { + executeTool: vi.fn(), + signal: controller.signal, + })) { + events.push(e); + if (e.type === "message_delta") controller.abort(); + } + + expect(events).toEqual([ + { type: "status", status: "running" }, + { type: "message_delta", content: "Hi" }, + ]); + }); + + test("recovers when event: and data: lines arrive in separate chunks", async () => { + const { streamBody } = makeStreamBody([ + "event: response.output_text.delta\n", + `data: ${JSON.stringify({ delta: "split" })}\n\n`, + "event: response.completed\ndata: {}\n\n", + ]); + + const adapter = new SupervisorApiAdapter({ + streamBody, + model: 
"databricks-claude-sonnet-4", + }); + const events = await collect( + adapter.run(createInput(), { executeTool: vi.fn() }), + ); + expect(events).toContainEqual({ + type: "message_delta", + content: "split", + }); + expect(events).toContainEqual({ type: "status", status: "complete" }); + }); +}); + +describe("fromSupervisorApi", () => { + test("calls ensureResolved on the supplied workspace client", async () => { + const ensureResolved = vi.fn(async () => {}); + const adapter = await fromSupervisorApi({ + model: "databricks-claude-sonnet-4", + workspaceClient: { + config: { ensureResolved }, + apiClient: { request: vi.fn() }, + }, + }); + expect(ensureResolved).toHaveBeenCalledTimes(1); + expect(adapter).toBeInstanceOf(SupervisorApiAdapter); + }); + + test("routes streaming through apiClient.request with the SA path", async () => { + const encoder = new TextEncoder(); + const contents = new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(sseEvent("response.completed", {}))); + controller.close(); + }, + }); + const request = vi.fn().mockResolvedValue({ contents }); + + const adapter = await fromSupervisorApi({ + model: "databricks-claude-sonnet-4", + workspaceClient: { + config: { ensureResolved: vi.fn(async () => {}) }, + apiClient: { request }, + }, + }); + + await collect(adapter.run(createInput(), { executeTool: vi.fn() })); + + expect(request).toHaveBeenCalledTimes(1); + const [requestArgs] = request.mock.calls[0]; + expect(requestArgs.path).toBe("/ai-gateway/mlflow/v1/responses"); + expect(requestArgs.method).toBe("POST"); + expect(requestArgs.raw).toBe(true); + expect(requestArgs.payload).toMatchObject({ + model: "databricks-claude-sonnet-4", + input: "Hello", + stream: true, + }); + expect(requestArgs.payload).not.toHaveProperty("tools"); + }); +}); diff --git a/packages/appkit/src/beta.ts b/packages/appkit/src/beta.ts index 3f5bba80c..7ccc77c5b 100644 --- a/packages/appkit/src/beta.ts +++ b/packages/appkit/src/beta.ts @@ -19,6 
+19,16 @@ export type { ToolProvider, } from "shared"; export { DatabricksAdapter, parseTextToolCalls } from "./agents/databricks"; +export type { + SupervisorApiAdapterCtorOptions, + SupervisorApiAdapterOptions, + SupervisorTool, +} from "./agents/supervisor-api"; +export { + fromSupervisorApi, + SupervisorApiAdapter, + supervisorTools, +} from "./agents/supervisor-api"; // Agent runtime export { createAgent } from "./core/agent/create-agent"; diff --git a/packages/appkit/src/connectors/serving/client.ts b/packages/appkit/src/connectors/serving/client.ts index 83f065e69..f75993a39 100644 --- a/packages/appkit/src/connectors/serving/client.ts +++ b/packages/appkit/src/connectors/serving/client.ts @@ -41,6 +41,20 @@ function cancellationTokenFromAbortSignal( }; } +/** + * Structural shape of a Databricks SDK client we need for the low-level + * `apiClient.request` call. Lets `streamPath` be reused by adapters that + * don't want a hard dependency on the concrete `WorkspaceClient` type. + */ +export interface ApiClientLike { + apiClient: { + request( + options: Record, + context?: unknown, + ): Promise; + }; +} + /** * Invokes a serving endpoint using the SDK's high-level query API. * Returns a typed QueryEndpointResponse. @@ -62,22 +76,23 @@ export async function invoke( } /** - * Returns the raw SSE byte stream from a serving endpoint. - * No parsing is performed — bytes are passed through as-is. + * POSTs `body` as JSON to an arbitrary workspace API path and returns the raw + * SSE byte stream. No parsing is performed — bytes are passed through as-is. + * + * Uses the SDK's low-level `apiClient.request({ raw: true })` so callers + * inherit URL resolution, the SDK credential chain (PAT/OAuth/OIDC), and + * any future retries/telemetry baked into the SDK transport. * - * Uses the SDK's low-level `apiClient.request({ raw: true })` because - * the high-level `servingEndpoints.query()` returns `Promise` - * and does not support SSE streaming. 
+ * When `signal` is provided it is bridged to the SDK's `Context` / + * `CancellationToken` so aborts cancel the outbound HTTP request. */ -export async function stream( - client: WorkspaceClient, - endpointName: string, +export async function streamPath( + client: ApiClientLike, + path: string, body: Record, signal?: AbortSignal, ): Promise> { - const { stream: _stream, ...cleanBody } = body; - - logger.debug("Streaming from endpoint %s", endpointName); + logger.debug("Streaming from path %s", path); const context = signal ? new Context({ @@ -87,17 +102,17 @@ export async function stream( const response = (await client.apiClient.request( { - path: `/serving-endpoints/${encodeURIComponent(endpointName)}/invocations`, + path, method: "POST", headers: new Headers({ "Content-Type": "application/json", Accept: "text/event-stream", }), - payload: { ...cleanBody, stream: true }, + payload: body, raw: true, }, context, - )) as { contents: ReadableStream }; + )) as { contents: ReadableStream | null }; if (!response.contents) { throw new Error("Response body is null — streaming not supported"); @@ -105,3 +120,23 @@ export async function stream( return response.contents; } + +/** + * Returns the raw SSE byte stream from a serving endpoint. Thin wrapper over + * {@link streamPath} that handles serving-specific URL encoding and forces + * `stream: true` in the payload. 
/**
 * One parsed Server-Sent Event. Field names follow the spec:
 * https://html.spec.whatwg.org/multipage/server-sent-events.html
 *
 * The reader does not interpret `data` (no JSON parsing), so callers control
 * the wire shape they expect.
 */
export interface SseEvent {
  /** Value of the most recent `event:` field, or `""` for an unnamed event. */
  event: string;
  /** `data:` lines of the block joined with `\n` (may be the empty string). */
  data: string;
  /** Value of the most recent `id:` field, or `undefined` if none. */
  id?: string;
}

/**
 * Async-iterates Server-Sent Events from a UTF-8 byte stream.
 *
 * Block-oriented parser: events are delimited by blank lines (`\n\n` after
 * CRLF normalization), so an `event:` line in chunk N pairs correctly with a
 * `data:` line in chunk N+1 — no hoisted state needed.
 *
 * The reader passes through the sentinel string `[DONE]` as `event=""`,
 * `data="[DONE]"`. Callers that care about it should match `data === "[DONE]"`
 * after destructuring.
 *
 * When the stream closes normally, a trailing block that was not followed by
 * a blank line is still emitted. When `signal` aborts, nothing further is
 * yielded — in particular a half-received block is discarded rather than
 * surfaced as a truncated event.
 *
 * Terminates when the stream closes or `signal` aborts; releases the reader
 * lock in either case.
 *
 * @param stream - Byte stream carrying `text/event-stream` content.
 * @param signal - Optional abort signal; aborting cancels the reader.
 * @returns An async generator yielding one {@link SseEvent} per dispatched block.
 */
export async function* readSseEvents(
  stream: ReadableStream<Uint8Array>,
  signal?: AbortSignal,
): AsyncGenerator<SseEvent> {
  const reader = stream.getReader();
  const decoder = new TextDecoder();
  let buffer = "";

  // Cancel the reader on abort so an in-flight `reader.read()` returns
  // immediately instead of waiting for the next chunk. Without this, an
  // aborted consumer would only notice between reads — fine for chatty
  // streams, but unbounded for an idle/heartbeat-less upstream.
  const onAbort = () => {
    reader.cancel().catch(() => {
      // `cancel()` rejects if the stream is already errored/closed; ignore.
    });
  };
  if (signal) {
    if (signal.aborted) onAbort();
    else signal.addEventListener("abort", onAbort, { once: true });
  }

  try {
    while (true) {
      if (signal?.aborted) break;
      const { done, value } = await reader.read();

      // `reader.cancel()` settles the pending read with `done: true`. Re-check
      // the signal here so an abort that landed while we were awaiting does
      // not flush the partially buffered block (or a just-arrived chunk) to a
      // consumer that has already given up.
      if (signal?.aborted) break;

      if (done) {
        const tail = parseSseBlock(buffer);
        if (tail) yield tail;
        break;
      }

      buffer += decoder.decode(value, { stream: true });

      const normalized = buffer.replace(/\r\n/g, "\n");
      const blocks = normalized.split("\n\n");
      // Last entry is either an incomplete block or "" (when the chunk ended
      // exactly on a boundary). Either way, keep it for the next iteration.
      buffer = blocks.pop() ?? "";

      for (const block of blocks) {
        const event = parseSseBlock(block);
        if (event) yield event;
      }
    }
  } finally {
    if (signal) signal.removeEventListener("abort", onAbort);
    reader.releaseLock();
  }
}

/**
 * Parses one blank-line-delimited SSE block into an event.
 *
 * @param block - Block text with `\n` line separators (no trailing blank line).
 * @returns The parsed event, or `null` when the block carries no `data:` line.
 */
function parseSseBlock(block: string): SseEvent | null {
  if (block.length === 0) return null;

  let eventName = "";
  let id: string | undefined;
  const dataLines: string[] = [];

  for (const rawLine of block.split("\n")) {
    // A lone trailing `\r` can survive normalization when a CRLF pair was
    // split across the end of the stream.
    const line = rawLine.replace(/\r$/, "");
    if (line === "" || line.startsWith(":")) continue;

    if (line.startsWith("event:")) {
      eventName = sseFieldValue(line, 6);
    } else if (line.startsWith("data:")) {
      dataLines.push(sseFieldValue(line, 5));
    } else if (line.startsWith("id:")) {
      id = sseFieldValue(line, 3);
    }
    // Other fields (`retry:`, custom) are ignored by design.
  }

  // Per the SSE spec, a block is only dispatched when the data buffer is
  // non-empty. Blocks containing only `event:`/`id:` (or comments) do not
  // surface as events.
  if (dataLines.length === 0) return null;

  return {
    event: eventName,
    data: dataLines.join("\n"),
    id,
  };
}

/**
 * Extracts a field value from an SSE line. Per the spec only a single leading
 * U+0020 SPACE after the colon is removed; any further whitespace belongs to
 * the value and is preserved.
 *
 * @param line - Full field line, e.g. `data: hello`.
 * @param prefixLength - Length of the field name including the colon.
 */
function sseFieldValue(line: string, prefixLength: number): string {
  const value = line.slice(prefixLength);
  return value.startsWith(" ") ? value.slice(1) : value;
}
collect( + readSseEvents(streamOf(["id: abc-123\nevent: ping\ndata: hi\n\n"])), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: "abc-123" }]); + }); + + test("falls back to empty event name when only data: is present", async () => { + const events = await collect(readSseEvents(streamOf(["data: 1\n\n"]))); + expect(events).toEqual([{ event: "", data: "1", id: undefined }]); + }); + + test("joins multi-line data: payloads with \\n", async () => { + const events = await collect( + readSseEvents(streamOf(["data: line1\ndata: line2\n\n"])), + ); + expect(events).toEqual([ + { event: "", data: "line1\nline2", id: undefined }, + ]); + }); + + test("normalises CRLF line endings", async () => { + const events = await collect( + readSseEvents(streamOf(["event: x\r\ndata: y\r\n\r\n"])), + ); + expect(events).toEqual([{ event: "x", data: "y", id: undefined }]); + }); + + test("emits a trailing event when the stream closes without a final blank line", async () => { + const events = await collect( + readSseEvents(streamOf(["event: ping\ndata: hi"])), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: undefined }]); + }); + + test("passes through [DONE] sentinels as data", async () => { + const events = await collect(readSseEvents(streamOf(["data: [DONE]\n\n"]))); + expect(events).toEqual([{ event: "", data: "[DONE]", id: undefined }]); + }); + + test("aborts when the signal fires before the next read", async () => { + const controller = new AbortController(); + let pulls = 0; + const stream = new ReadableStream({ + pull(c) { + pulls++; + if (pulls === 1) { + c.enqueue(new TextEncoder().encode("event: a\ndata: 1\n\n")); + } else { + controller.abort(); + c.enqueue(new TextEncoder().encode("event: b\ndata: 2\n\n")); + } + }, + }); + + const out: SseEvent[] = []; + for await (const e of readSseEvents(stream, controller.signal)) { + out.push(e); + if (out.length === 1) controller.abort(); + } + expect(out.map((e) => e.event)).toEqual(["a"]); + }); + + 
test("aborts an idle reader immediately via reader.cancel()", async () => { + // Stream that sends one event then never resolves further reads — models + // an upstream that has stopped sending data. Without `reader.cancel()` + // the consumer would block forever after aborting. + const controller = new AbortController(); + let cancelled = false; + const stream = new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode("event: a\ndata: 1\n\n")); + }, + pull() { + return new Promise(() => { + /* never resolves */ + }); + }, + cancel() { + cancelled = true; + }, + }); + + const out: SseEvent[] = []; + const iterator = readSseEvents(stream, controller.signal); + const first = await iterator.next(); + if (!first.done) out.push(first.value); + controller.abort(); + const second = await iterator.next(); + expect(second.done).toBe(true); + expect(out.map((e) => e.event)).toEqual(["a"]); + expect(cancelled).toBe(true); + }); + + test("does not dispatch a block whose only field is id: (spec compliance)", async () => { + const events = await collect( + readSseEvents(streamOf(["id: only\n\nevent: ping\ndata: hi\n\n"])), + ); + expect(events).toEqual([{ event: "ping", data: "hi", id: undefined }]); + }); + + test("decodes a multi-byte UTF-8 character split across chunks", async () => { + const checkBytes = new TextEncoder().encode("✓"); + expect(checkBytes.length).toBe(3); + const stream = new ReadableStream({ + start(c) { + c.enqueue(new TextEncoder().encode("data: ")); + c.enqueue(checkBytes.subarray(0, 1)); + c.enqueue(checkBytes.subarray(1)); + c.enqueue(new TextEncoder().encode("\n\n")); + c.close(); + }, + }); + const events = await collect(readSseEvents(stream)); + expect(events).toEqual([{ event: "", data: "✓", id: undefined }]); + }); +});