diff --git a/.changeset/base-toolset.md b/.changeset/base-toolset.md new file mode 100644 index 000000000..805ef6012 --- /dev/null +++ b/.changeset/base-toolset.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": minor +--- + +Add base `Toolset` support: a stateful container for a group of tools with `setup()` / `aclose()` lifecycle hooks. Toolsets can be passed directly into `Agent({ tools: [...] })` alongside individual function tools; their tools are flattened into the agent's `ToolContext` and the runtime drives `setup()` on activity start, `aclose()` on close, and a diff on `updateTools()`. `Toolset.setup()` failures propagate (with rollback of successfully-set-up toolsets) so the agent fails explicitly rather than running with uninitialized resources. The `IGNORE_ON_ENTER` flag is also respected for function tools nested inside a Toolset. Every LLM and realtime plugin tool builder iterates `ToolContext.flatten()` so toolset-contributed tools are correctly advertised. Also exports `ToolCalledEvent` / `ToolCompletedEvent` payload types. diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts index 4837f2cb9..bbb77a2fa 100644 --- a/agents/src/llm/index.ts +++ b/agents/src/llm/index.ts @@ -10,12 +10,15 @@ export { ToolContext, ToolError, ToolFlag, + Toolset, toToolContext, type AgentHandoff, type FunctionTool, type ProviderDefinedTool, type Tool, + type ToolCalledEvent, type ToolChoice, + type ToolCompletedEvent, type ToolContextEntry, type ToolCtxInput, type ToolOptions, diff --git a/agents/src/llm/tool_context.test.ts b/agents/src/llm/tool_context.test.ts index 74ab07683..c4f0aabc5 100644 --- a/agents/src/llm/tool_context.test.ts +++ b/agents/src/llm/tool_context.test.ts @@ -5,7 +5,7 @@ import { describe, expect, it } from 'vitest'; import { z } from 'zod'; import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; -import { ToolContext, type ToolOptions, tool } from './tool_context.js'; +import { ToolContext, type ToolOptions, Toolset, tool } from './tool_context.js'; import { createToolOptions, oaiParams } from './utils.js'; describe('Tool Context', () => { @@ -523,3 +523,56 @@ describe('ToolContext', () => { expect(new ToolContext([a, b]).equals(new ToolContext([a, c]))).toBe(false); }); }); + +describe('Toolset', () => { + const makeFn = (name: string) => + tool({ + name, + description: `${name} tool`, + execute: async () => name, + }); + + it('exposes its id and the tools it was constructed with', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ts = new Toolset({ id: 'set1', tools: [a, b] }); + + expect(ts.id).toBe('set1'); + expect(ts.tools).toEqual([a, b]); + }); + + it('default setup and aclose are no-ops', async () => { + const ts = new Toolset({ id: 'noop', tools: [] }); + await expect(ts.setup()).resolves.toBeUndefined(); + await expect(ts.aclose()).resolves.toBeUndefined(); + }); + + it('lets subclasses override lifecycle hooks', async () => { + const events: string[] = []; + class Recording extends Toolset { + override async setup(): Promise { + events.push(`setup:${this.id}`); + } + override async aclose(): Promise { + events.push(`close:${this.id}`); + } + } + + const ts = new Recording({ id: 'rec', tools: [] }); + await ts.setup(); + await ts.aclose(); + expect(events).toEqual(['setup:rec', 'close:rec']); + }); + + it('is flattened into a ToolContext: function tools merged, toolset tracked', () => { + const a = makeFn('a'); + const b = makeFn('b'); + const ts = new Toolset({ id: 'set', tools: [a, b] }); + const direct = makeFn('direct'); + + const ctx = new ToolContext([direct, ts]); + + expect(Object.keys(ctx.functionTools).sort()).toEqual(['a', 'b', 'direct']); + expect(ctx.toolsets).toEqual([ts]); + }); +}); diff --git a/agents/src/llm/tool_context.ts b/agents/src/llm/tool_context.ts index 138af5250..64f7f380c 100644 --- a/agents/src/llm/tool_context.ts +++ b/agents/src/llm/tool_context.ts @@ -196,6 +196,43 @@ export interface FunctionTool< [FUNCTION_TOOL_SYMBOL]: true; } +export interface ToolCalledEvent { + ctx: RunContext; + arguments: Record; +} + +export interface ToolCompletedEvent { + ctx: RunContext; + output?: { type: 'output'; value: unknown } | { type: 'error'; value: Error }; +} + +/** + * A stateful collection of tools sharing a lifecycle. Tools registered through a `Toolset` are + * flattened into the surrounding `ToolContext`, while the `Toolset` itself is tracked so its + * `setup()` / `aclose()` hooks can be invoked by the agent runtime. + */ +export class Toolset { + readonly #id: string; + readonly #tools: Tool[]; + + constructor({ id, tools }: { id: string; tools: readonly Tool[] }) { + this.#id = id; + this.#tools = [...tools]; + } + + get id(): string { + return this.#id; + } + + get tools(): readonly Tool[] { + return this.#tools; + } + + async setup(): Promise {} + + async aclose(): Promise {} +} + /** * Convenience input shape accepted by APIs that want to take a list of tools directly without * forcing callers to wrap them in `new ToolContext(...)`. @@ -218,23 +255,25 @@ export function toToolContext( } /** - * A flat, addressable view over a heterogeneous list of `FunctionTool` and `ProviderDefinedTool` - * entries. + * A flat, addressable view over a heterogeneous list of `FunctionTool`, `ProviderDefinedTool`, + * and `Toolset` entries. * * Mirrors the Python `ToolContext`: the original input list is preserved on `_tools`, while - * `_functionToolsMap` and `_providerTools` denormalize it for cheap access. When two function - * tools share the same name the later entry overwrites the earlier one. + * `_functionToolsMap`, `_providerTools`, and `_toolsets` denormalize it for cheap access. Tools + * contributed by a `Toolset` are flattened into the function and provider collections; later + * function tools sharing the same name as an earlier one overwrite the earlier entry. */ // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext entries accept any function-tool parameter/result types export type ToolContextEntry = // eslint-disable-next-line @typescript-eslint/no-explicit-any - FunctionTool | ProviderDefinedTool; + FunctionTool | ProviderDefinedTool | Toolset; export class ToolContext { private _tools: ToolContextEntry[] = []; // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext stores generic function tools private _functionToolsMap: Map> = new Map(); private _providerTools: ProviderDefinedTool[] = []; + private _toolsets: Toolset[] = []; constructor(tools: readonly ToolContextEntry[] = []) { this.updateTools(tools); @@ -244,18 +283,23 @@ export class ToolContext { return new ToolContext([]); } - /** A copy of all function tools in the context. */ + /** A copy of all function tools in the context, including tools contributed by toolsets. */ // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Generic registry over any parameter/result types get functionTools(): Record> { return Object.fromEntries(this._functionToolsMap); } - /** A copy of all provider tools in the context. */ + /** A copy of all provider tools in the context, including provider tools from toolsets. */ get providerTools(): readonly ProviderDefinedTool[] { return [...this._providerTools]; } - /** A copy of the raw tool list this context was constructed with. */ + /** A copy of the toolsets registered in the context. */ + get toolsets(): readonly Toolset[] { + return [...this._toolsets]; + } + + /** A copy of the raw tool/toolset list this context was constructed with. */ get tools(): readonly ToolContextEntry[] { return [...this._tools]; } @@ -286,19 +330,31 @@ export class ToolContext { this._tools = [...tools]; this._functionToolsMap = new Map(); this._providerTools = []; - - for (const tool of tools) { + this._toolsets = []; + + const addTool = (tool: ToolContextEntry | Tool): void => { + if (tool instanceof Toolset) { + for (const inner of tool.tools) { + addTool(inner); + } + this._toolsets.push(tool); + return; + } if (isProviderDefinedTool(tool)) { this._providerTools.push(tool); - continue; + return; } if (isFunctionTool(tool)) { // Later tool wins on duplicate names. `tool()` enforces a non-empty name at // construction so we don't re-check here. this._functionToolsMap.set(tool.name, tool); - continue; + return; } throw new Error(`unknown tool type: ${typeof tool}`); + }; + + for (const tool of tools) { + addTool(tool); } } @@ -315,6 +371,14 @@ export class ToolContext { return false; } } + if (this._toolsets.length !== other._toolsets.length) { + return false; + } + for (let i = 0; i < this._toolsets.length; i++) { + if (this._toolsets[i] !== other._toolsets[i]) { + return false; + } + } if (this._providerTools.length !== other._providerTools.length) { return false; } diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index 48a95e553..e659c57ce 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -35,10 +35,12 @@ import { RealtimeModel, type RealtimeModelError, type RealtimeSession, + type Tool, type ToolChoice, ToolContext, type ToolContextEntry, ToolFlag, + Toolset, } from '../llm/index.js'; import type { LLMError } from '../llm/llm.js'; import { isSameToolChoice } from '../llm/tool_context.js'; @@ -215,6 +217,7 @@ export class AgentActivity implements RecognitionHooks { private toolChoice: ToolChoice | null = null; private _preemptiveGeneration?: PreemptiveGeneration; private _preemptiveGenerationCount = 0; + private _toolsetsSetup = false; private interruptionDetector?: AdaptiveInterruptionDetector; private isInterruptionDetectionEnabled: boolean; private isInterruptionByAudioActivityEnabled: boolean; @@ -421,6 +424,8 @@ export class AgentActivity implements RecognitionHooks { this.agent._agentActivity = this; + await this.setupToolsets(); + if (this.llm instanceof RealtimeModel) { const rtReused = reuseResources?.rtSession !== undefined; @@ -765,14 +770,32 @@ export class AgentActivity implements RecognitionHooks { } async updateTools(tools: readonly ToolContextEntry[]): Promise { - const oldToolNames = new Set(Object.keys(this.tools.functionTools)); + const oldToolCtx = this.tools; + const oldToolNames = new Set(Object.keys(oldToolCtx.functionTools)); + const oldToolsets = oldToolCtx.toolsets; const newToolCtx = new ToolContext(tools); const newToolNames = new Set(Object.keys(newToolCtx.functionTools)); + const newToolsets = newToolCtx.toolsets; const toolsAdded = [...newToolNames].filter((name) => !oldToolNames.has(name)); const toolsRemoved = [...oldToolNames].filter((name) => !newToolNames.has(name)); + const addedToolsets = newToolsets.filter((toolset) => !oldToolsets.includes(toolset)); + const removedToolsets = oldToolsets.filter((toolset) => !newToolsets.includes(toolset)); + + // Run lifecycle calls in the order setup → swap toolCtx → close so a `setup()` failure + // leaves the agent pointing at the OLD toolCtx. The removed toolsets stay tracked by + // `agent.toolCtx.toolsets` and will be closed by the activity's normal teardown path — + // no leak even if `setupToolsets` throws and propagates. `setupToolsets()` already rolls + // back its own partial successes internally. + if (this._toolsetsSetup) { + await this.setupToolsets(addedToolsets, false); + } this.agent._toolCtx = newToolCtx; + if (this._toolsetsSetup) { + await this.closeToolsets(removedToolsets, false); + } + if (toolsAdded.length > 0 || toolsRemoved.length > 0) { const configUpdate = new AgentConfigUpdate({ toolsAdded: toolsAdded.length > 0 ? toolsAdded : undefined, @@ -1733,11 +1756,17 @@ export class AgentActivity implements RecognitionHooks { const tools: ToolContext = shouldFilterTools ? new ToolContext( - this.agent.toolCtx.tools.filter((t) => { - if (t.type === 'function') { - return !((t as unknown as { flags: number }).flags & ToolFlag.IGNORE_ON_ENTER); + // Recurse into Toolsets so function tools nested inside a Toolset are also subject + // to IGNORE_ON_ENTER. The Toolset itself is omitted from this short-lived context + // because lifecycle ownership stays with the agent's persistent toolCtx. + this.agent.toolCtx.tools.flatMap((t): ToolContextEntry[] => { + const keepFn = (fn: Tool): boolean => + fn.type !== 'function' || + !((fn as unknown as { flags: number }).flags & ToolFlag.IGNORE_ON_ENTER); + if (t instanceof Toolset) { + return t.tools.filter(keepFn) as ToolContextEntry[]; } - return true; + return keepFn(t) ? [t] : []; }), ) : this.agent.toolCtx; @@ -3726,9 +3755,73 @@ export class AgentActivity implements RecognitionHooks { this.realtimeSpans?.clear(); await this.realtimeSession?.close(); await this.audioRecognition?.close(); + await this.closeToolsets(); this.realtimeSession = undefined; this.audioRecognition = undefined; } + + private async setupToolsets( + toolsets: readonly Toolset[] = this.agent.toolCtx.toolsets, + updateSetupState = true, + ): Promise { + if (updateSetupState && this._toolsetsSetup) { + return; + } + + // Toolset.setup() failures bubble up so the activity (and its agent) fails explicitly + // rather than silently advertising tools whose backing resources never initialized. + // If any setup fails, close the ones that already succeeded to avoid leaking their + // backing resources — `closeToolsets()` won't run them after a throw because + // `_toolsetsSetup` never flipped to true. + const outputs = await Promise.allSettled(toolsets.map((toolset) => toolset.setup())); + + let firstError: unknown; + const succeeded: Toolset[] = []; + for (let i = 0; i < outputs.length; i++) { + const output = outputs[i]!; + if (output.status === 'rejected') { + if (firstError === undefined) firstError = output.reason; + } else { + succeeded.push(toolsets[i]!); + } + } + + if (firstError !== undefined) { + const closeOutputs = await Promise.allSettled(succeeded.map((t) => t.aclose())); + for (const output of closeOutputs) { + if (output.status === 'rejected') { + this.logger.error( + { error: output.reason }, + 'error closing toolset during setup rollback', + ); + } + } + throw firstError; + } + + if (updateSetupState) { + this._toolsetsSetup = true; + } + } + + private async closeToolsets( + toolsets: readonly Toolset[] = this.agent.toolCtx.toolsets, + updateSetupState = true, + ): Promise { + if (updateSetupState && !this._toolsetsSetup) { + return; + } + + const outputs = await Promise.allSettled(toolsets.map((toolset) => toolset.aclose())); + for (const output of outputs) { + if (output.status === 'rejected') { + this.logger.error({ error: output.reason }, 'error closing toolset'); + } + } + if (updateSetupState) { + this._toolsetsSetup = false; + } + } } function toOaiToolChoice(toolChoice: ToolChoice | null): ToolChoice | undefined { diff --git a/examples/src/basic_toolsets.ts b/examples/src/basic_toolsets.ts new file mode 100644 index 000000000..f43b3a3dd --- /dev/null +++ b/examples/src/basic_toolsets.ts @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { + type JobContext, + type JobProcess, + ServerOptions, + cli, + defineAgent, + inference, + llm, + voice, +} from '@livekit/agents'; +import * as livekit from '@livekit/agents-plugin-livekit'; +import * as silero from '@livekit/agents-plugin-silero'; +import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; +import { fileURLToPath } from 'node:url'; +import { z } from 'zod'; + +export default defineAgent({ + prewarm: async (proc: JobProcess) => { + proc.userData.vad = await silero.VAD.load(); + }, + entry: async (ctx: JobContext) => { + const getWeather = llm.tool({ + name: 'getWeather', + description: 'Get the weather for a given location.', + parameters: z.object({ + location: z.string().describe('The location to get the weather for'), + }), + execute: async ({ location }) => { + return `The weather in ${location} is sunny.`; + }, + }); + + const lookupTimezone = llm.tool({ + name: 'lookupTimezone', + description: 'Look up the timezone for a city or region.', + parameters: z.object({ + location: z.string().describe('The city or region to look up'), + }), + execute: async ({ location }) => { + return `${location} is in the America/Los_Angeles timezone.`; + }, + }); + + const locationTools = new llm.Toolset({ + id: 'location_tools', + tools: [getWeather, lookupTimezone], + }); + + const agent = new voice.Agent({ + instructions: + 'You are a helpful assistant. Use the location toolset when users ask about weather or timezones.', + tools: [locationTools], + }); + + const session = new voice.AgentSession({ + vad: ctx.proc.userData.vad! as silero.VAD, + stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), + llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), + tts: new inference.TTS({ + model: 'cartesia/sonic-3', + voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', + }), + turnDetection: new livekit.turnDetector.MultilingualModel(), + }); + + await session.start({ + agent, + room: ctx.room, + inputOptions: { + noiseCancellation: BackgroundVoiceCancellation(), + }, + }); + + session.say('Hello, ask me about the weather or timezone for a location.'); + }, +}); + +cli.runApp(new ServerOptions({ agent: fileURLToPath(import.meta.url) })); diff --git a/plugins/google/src/utils.ts b/plugins/google/src/utils.ts index 5548c076e..f77864390 100644 --- a/plugins/google/src/utils.ts +++ b/plugins/google/src/utils.ts @@ -139,8 +139,13 @@ function isEmptyObjectSchema(jsonSchema: JSONSchema7Definition): boolean { export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclaration[] { const functionDeclarations: FunctionDeclaration[] = []; - for (const [name, tool] of Object.entries(toolCtx.functionTools)) { - const { description, parameters } = tool; + // flatten() yields function tools + provider tools, including any contributed by Toolsets. + for (const tool of toolCtx.flatten()) { + if (!llm.isFunctionTool(tool)) { + // Provider-defined tools are not wired into the Gemini schema yet. + continue; + } + const { name, description, parameters } = tool; const jsonSchema = llm.toJsonSchema(parameters, false); // Create a deep copy to prevent the Google GenAI library from mutating the schema diff --git a/plugins/mistralai/src/llm.ts b/plugins/mistralai/src/llm.ts index 274f57837..55b3e8d51 100644 --- a/plugins/mistralai/src/llm.ts +++ b/plugins/mistralai/src/llm.ts @@ -211,14 +211,19 @@ export class LLMStream extends llm.LLMStream { // eslint-disable-next-line @typescript-eslint/no-explicit-any const toolsList: any[] = []; - if (this.toolCtx && Object.keys(this.toolCtx.functionTools).length > 0) { - for (const [name, func] of Object.entries(this.toolCtx.functionTools)) { + if (this.toolCtx) { + // flatten() yields function tools + provider tools, including any contributed by Toolsets. + for (const t of this.toolCtx.flatten()) { + if (!llm.isFunctionTool(t)) { + // Provider-defined tools are not wired into the Mistral schema yet (AJS-112). + continue; + } toolsList.push({ type: 'function' as const, function: { - name, - description: func.description, - parameters: llm.toJsonSchema(func.parameters, true, false), + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters, true, false), }, }); } diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 8481e0f47..85435f8fe 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -698,11 +698,16 @@ export class RealtimeSession extends llm.RealtimeSession { // TODO(brian): these logics below are noops I think, leaving it here to keep // parity with the python but we should remove them later const retainedToolNames = new Set(ev.session.tools.map((tool) => tool.name)); - const retainedTools = Object.entries(_tools.functionTools) - .filter(([name]) => retainedToolNames.has(name)) - .map(([, tool]) => tool); + // Keep Toolsets and provider tools as-is; only filter out function tools the server didn't + // accept. This preserves toolset references so subsequent updateTools() diffs are accurate. + const retainedEntries = _tools.tools.filter((entry) => { + if (llm.isFunctionTool(entry)) { + return retainedToolNames.has(entry.name); + } + return true; + }); - this._tools = new llm.ToolContext(retainedTools); + this._tools = new llm.ToolContext(retainedEntries); unlock(); } @@ -710,21 +715,28 @@ export class RealtimeSession extends llm.RealtimeSession { private createToolsUpdateEvent(_tools: llm.ToolContext): api_proto.SessionUpdateEvent { const oaiTools: api_proto.Tool[] = []; - for (const [name, tool] of Object.entries(_tools.functionTools)) { - const { parameters: toolParameters, description } = tool; + // flatten() yields function tools + provider tools, including any contributed by Toolsets. + for (const t of _tools.flatten()) { + if (!llm.isFunctionTool(t)) { + // Provider-defined tools aren't wired into the Realtime session-update schema yet. + continue; + } try { const parameters = llm.toJsonSchema( - toolParameters, + t.parameters, ) as unknown as api_proto.Tool['parameters']; oaiTools.push({ - name, - description, + name: t.name, + description: t.description, parameters: parameters, type: 'function', }); } catch (e) { - this.#logger.error({ name, tool }, "OpenAI Realtime API doesn't support this tool type"); + this.#logger.error( + { name: t.name, tool: t }, + "OpenAI Realtime API doesn't support this tool type", + ); continue; } } diff --git a/plugins/openai/src/responses/llm.ts b/plugins/openai/src/responses/llm.ts index c8982bfb8..fe46d4bed 100644 --- a/plugins/openai/src/responses/llm.ts +++ b/plugins/openai/src/responses/llm.ts @@ -186,24 +186,29 @@ class ResponsesHttpLLMStream extends llm.LLMStream { )) as OpenAI.Responses.ResponseInputItem[]; const tools = this.toolCtx - ? Object.entries(this.toolCtx.functionTools).map(([name, func]) => { - const oaiParams = { - type: 'function' as const, - name: name, - description: func.description, - parameters: llm.toJsonSchema( - func.parameters, - true, - this.strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) + ? this.toolCtx + .flatten() + .map((t) => { + if (llm.isFunctionTool(t)) { + const oaiParams = { + type: 'function' as const, + name: t.name, + description: t.description, + parameters: llm.toJsonSchema( + t.parameters, + true, + this.strictToolSchema, + ) as unknown as OpenAI.Responses.FunctionTool['parameters'], + } as OpenAI.Responses.FunctionTool; + if (this.strictToolSchema) { + oaiParams.strict = true; + } + return oaiParams; + } + // Provider-defined tools are not wired up here yet; skip until AJS-112 lands. + return undefined; + }) + .filter((t): t is NonNullable => t !== undefined) : undefined; const requestOptions: Record = { ...this.modelOptions }; diff --git a/plugins/openai/src/ws/llm.ts b/plugins/openai/src/ws/llm.ts index 02ab6d1d3..db6a80123 100644 --- a/plugins/openai/src/ws/llm.ts +++ b/plugins/openai/src/ws/llm.ts @@ -429,24 +429,29 @@ export class WSLLMStream extends llm.LLMStream { )) as OpenAI.Responses.ResponseInputItem[]; const tools = this.toolCtx - ? Object.entries(this.toolCtx.functionTools).map(([name, func]) => { - const oaiParams = { - type: 'function' as const, - name, - description: func.description, - parameters: llm.toJsonSchema( - func.parameters, - true, - this.#strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.#strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) + ? this.toolCtx + .flatten() + .map((t) => { + if (llm.isFunctionTool(t)) { + const oaiParams = { + type: 'function' as const, + name: t.name, + description: t.description, + parameters: llm.toJsonSchema( + t.parameters, + true, + this.#strictToolSchema, + ) as unknown as OpenAI.Responses.FunctionTool['parameters'], + } as OpenAI.Responses.FunctionTool; + if (this.#strictToolSchema) { + oaiParams.strict = true; + } + return oaiParams; + } + // Provider-defined tools are not wired up here yet; skip until AJS-112 lands. + return undefined; + }) + .filter((t): t is NonNullable => t !== undefined) : undefined; const requestOptions: Record = { ...this.#modelOptions }; diff --git a/plugins/phonic/src/realtime/realtime_model.ts b/plugins/phonic/src/realtime/realtime_model.ts index 09933b580..16d9fd52c 100644 --- a/plugins/phonic/src/realtime/realtime_model.ts +++ b/plugins/phonic/src/realtime/realtime_model.ts @@ -368,23 +368,27 @@ export class RealtimeSession extends llm.RealtimeSession { } this._tools = tools.copy(); - this.toolDefinitions = Object.entries(tools.functionTools).map(([name, tool]) => ({ - type: 'custom_websocket', - tool_schema: { - type: 'function', - function: { - name, - description: tool.description, - parameters: llm.toJsonSchema(tool.parameters), - strict: true, + // flatten() yields function tools + provider tools, including any contributed by Toolsets. + this.toolDefinitions = tools + .flatten() + .filter((t): t is llm.FunctionTool => llm.isFunctionTool(t)) + .map((t) => ({ + type: 'custom_websocket' as const, + tool_schema: { + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters), + strict: true, + }, }, - }, - tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, - // Tool chaining and tool calls during speech are not supported at this time - // for ease of implementation within the RealtimeSession generations framework - wait_for_speech_before_tool_call: true, - allow_tool_chaining: false, - })); + tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, + // Tool chaining and tool calls during speech are not supported at this time + // for ease of implementation within the RealtimeSession generations framework + wait_for_speech_before_tool_call: true, + allow_tool_chaining: false, + })); this.toolsReady.resolve(); } @@ -404,21 +408,25 @@ export class RealtimeSession extends llm.RealtimeSession { } if (tools !== undefined) { this._tools = tools.copy(); - this.toolDefinitions = Object.entries(tools.functionTools).map(([name, tool]) => ({ - type: 'custom_websocket', - tool_schema: { - type: 'function', - function: { - name, - description: tool.description, - parameters: llm.toJsonSchema(tool.parameters), - strict: true, + // flatten() yields function tools + provider tools, including any contributed by Toolsets. + this.toolDefinitions = tools + .flatten() + .filter((t): t is llm.FunctionTool => llm.isFunctionTool(t)) + .map((t) => ({ + type: 'custom_websocket' as const, + tool_schema: { + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: llm.toJsonSchema(t.parameters), + strict: true, + }, }, - }, - tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, - wait_for_speech_before_tool_call: true, - allow_tool_chaining: false, - })); + tool_call_output_timeout_ms: TOOL_CALL_OUTPUT_TIMEOUT_MS, + wait_for_speech_before_tool_call: true, + allow_tool_chaining: false, + })); } if (chatCtx !== undefined) { this._chatCtx = chatCtx.copy(); diff --git a/plugins/test/src/llm.ts b/plugins/test/src/llm.ts index 534dce2df..7de794735 100644 --- a/plugins/test/src/llm.ts +++ b/plugins/test/src/llm.ts @@ -200,6 +200,62 @@ export const llm = async (llm: llmlib.LLM, skipOptionalArgs: boolean) => { expect(JSON.parse(calls[0]!.args).address).toBeUndefined(); }); }); + + describe('toolset', async () => { + // Tools registered through a Toolset must reach the underlying LLM exactly like + // top-level tools: the plugin must flatten the toolset's contents before advertising + // tools to the provider, and the model must be able to call them. + const buildToolsetContext = () => { + const weatherToolset = new llmlib.Toolset({ + id: 'weather_toolset', + tools: [ + llmlib.tool({ + name: 'getWeather', + description: 'Get the current weather in a given location', + parameters: z.object({ + location: z.string().describe('The city and state, e.g. San Francisco, CA'), + unit: z.enum(['celsius', 'fahrenheit']).describe('The temperature unit to use'), + }), + execute: async () => {}, + }), + ], + }); + + // Mix a direct top-level tool with the Toolset so we also confirm that direct tools + // continue to work side-by-side. + const directTool = llmlib.tool({ + name: 'playMusic', + description: 'Play music', + parameters: z.object({ + name: z.string().describe('The artist and name of the song'), + }), + execute: async () => {}, + }); + + return new llmlib.ToolContext([weatherToolset, directTool]); + }; + + it('should call a function tool that lives inside a Toolset', async () => { + const ctx = buildToolsetContext(); + const calls = await requestFncCall( + llm, + "What's the weather in San Francisco, in Celsius?", + ctx, + ); + + expect(calls.length).toStrictEqual(1); + expect(calls[0]!.name).toStrictEqual('getWeather'); + expect(JSON.parse(calls[0]!.args).unit).toStrictEqual('celsius'); + }); + + it('should expose direct tools alongside Toolset tools', async () => { + const ctx = buildToolsetContext(); + const calls = await requestFncCall(llm, 'Play the song "Bohemian Rhapsody" by Queen.', ctx); + + expect(calls.length).toStrictEqual(1); + expect(calls[0]!.name).toStrictEqual('playMusic'); + }); + }); }); };