diff --git a/docs/client.md b/docs/client.md index 0946eeec97..bd4e92d4f5 100644 --- a/docs/client.md +++ b/docs/client.md @@ -35,7 +35,7 @@ import { StdioClientTransport } from '@modelcontextprotocol/client/stdio'; ### Streamable HTTP -For remote HTTP servers, use {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}: +For remote HTTP servers, use {@linkcode @modelcontextprotocol/client!client/modernStreamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}: ```ts source="../examples/client/src/clientGuide.examples.ts#connect_streamableHttp" const client = new Client({ name: 'my-client', version: '1.0.0' }); @@ -49,7 +49,7 @@ For a full interactive client over Streamable HTTP, see [`simpleStreamableHttp.t ### stdio -For local, process-spawned servers (Claude Desktop, CLI tools), use {@linkcode @modelcontextprotocol/client!client/stdio.StdioClientTransport | StdioClientTransport}. The transport spawns the server process and communicates over stdin/stdout: +For local, process-spawned servers (Claude Desktop, CLI tools), use {@linkcode @modelcontextprotocol/client!client/modernStdio.StdioClientTransport | StdioClientTransport}. The transport spawns the server process and communicates over stdin/stdout: ```ts source="../examples/client/src/clientGuide.examples.ts#connect_stdio" const client = new Client({ name: 'my-client', version: '1.0.0' }); @@ -64,7 +64,7 @@ await client.connect(transport); ### SSE fallback for legacy servers -To support both modern Streamable HTTP and legacy SSE servers, try {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport} first and fall back to {@linkcode @modelcontextprotocol/client!client/sse.SSEClientTransport | SSEClientTransport} on failure: +To support both modern Streamable HTTP and legacy SSE servers, try {@linkcode @modelcontextprotocol/client!client/modernStreamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport} first and fall back to {@linkcode @modelcontextprotocol/client!client/sse.SSEClientTransport | SSEClientTransport} on failure: ```ts source="../examples/client/src/clientGuide.examples.ts#connect_sseFallback" const baseUrl = new URL(url); @@ -113,7 +113,7 @@ console.log(systemPrompt); ## Authentication -MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an {@linkcode @modelcontextprotocol/client!client/auth.AuthProvider | AuthProvider} to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. The transport calls `token()` before every request and `onUnauthorized()` (if provided) on 401, then retries once. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an {@linkcode @modelcontextprotocol/client!client/auth.AuthProvider | AuthProvider} to {@linkcode @modelcontextprotocol/client!client/modernStreamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. The transport calls `token()` before every request and `onUnauthorized()` (if provided) on 401, then retries once. ### Bearer tokens @@ -162,7 +162,7 @@ For a runnable example supporting both auth methods via environment variables, s ### Full OAuth with user authorization -For user-facing applications, implement the {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface to handle the full authorization code flow (redirects, code verifiers, token storage, dynamic client registration). The {@linkcode @modelcontextprotocol/client!client/client.Client#connect | connect()} call will throw {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} when authorization is needed — catch it, complete the browser flow, call {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport#finishAuth | transport.finishAuth(code)}, and reconnect. +For user-facing applications, implement the {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface to handle the full authorization code flow (redirects, code verifiers, token storage, dynamic client registration). The {@linkcode @modelcontextprotocol/client!client/client.Client#connect | connect()} call will throw {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} when authorization is needed — catch it, complete the browser flow, call {@linkcode @modelcontextprotocol/client!client/modernStreamableHttp.StreamableHTTPClientTransport#finishAuth | transport.finishAuth(code)}, and reconnect. For a complete working OAuth flow, see [`simpleOAuthClient.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleOAuthClient.ts) and [`simpleOAuthClientProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleOAuthClientProvider.ts). @@ -599,14 +599,7 @@ For an end-to-end example of server-initiated SSE disconnection and automatic cl ## Tasks (experimental) > [!WARNING] -> The tasks API is experimental and may change without notice. - -Task-based execution enables "call-now, fetch-later" patterns for long-running operations (see [Tasks](https://modelcontextprotocol.io/specification/latest/basic/utilities/tasks) in the MCP specification). Instead of returning a result immediately, a tool creates a task that can be polled or resumed later. To use tasks: - -- Call {@linkcode @modelcontextprotocol/client!experimental/tasks/client.ExperimentalClientTasks#callToolStream | client.experimental.tasks.callToolStream(...)} to start a tool call that may create a task and emit status updates over time. -- Call {@linkcode @modelcontextprotocol/client!experimental/tasks/client.ExperimentalClientTasks#getTask | client.experimental.tasks.getTask(...)} and {@linkcode @modelcontextprotocol/client!experimental/tasks/client.ExperimentalClientTasks#getTaskResult | getTaskResult(...)} to check status and fetch results after reconnecting. - -For a full runnable example, see [`simpleTaskInteractiveClient.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTaskInteractiveClient.ts). +> The tasks API has been removed from this version of the SDK. See the [Migration guide](./migration.md) for details. ## See also diff --git a/docs/server.md b/docs/server.md index 3b173af4e0..b2a8f3343f 100644 --- a/docs/server.md +++ b/docs/server.md @@ -54,7 +54,7 @@ For a complete server with sessions, logging, and CORS mounted on Express, see [ ### stdio -For local, process-spawned integrations, use {@linkcode @modelcontextprotocol/server!server/stdio.StdioServerTransport | StdioServerTransport}: +For local, process-spawned integrations, use {@linkcode @modelcontextprotocol/server!server/modernStdio.StdioServerTransport | StdioServerTransport}: ```ts source="../examples/server/src/serverGuide.examples.ts#stdio_basic" const server = new McpServer({ name: 'my-server', version: '1.0.0' }); @@ -498,15 +498,7 @@ server.registerTool( ## Tasks (experimental) > [!WARNING] -> The tasks API is experimental and may change without notice. - -Task-based execution enables "call-now, fetch-later" patterns for long-running operations (see [Tasks](https://modelcontextprotocol.io/specification/latest/basic/utilities/tasks) in the MCP specification). Instead of returning a result immediately, a tool creates a task that can be polled or resumed later. To use tasks: - -- Provide a {@linkcode @modelcontextprotocol/server!index.TaskStore | TaskStore} implementation that persists task metadata and results (see {@linkcode @modelcontextprotocol/server!index.InMemoryTaskStore | InMemoryTaskStore} for reference). -- Enable the `tasks` capability when constructing the server. -- Register tools with {@linkcode @modelcontextprotocol/server!experimental/tasks/mcpServer.ExperimentalMcpServerTasks#registerToolTask | server.experimental.tasks.registerToolTask(...)}. - -For a full runnable example, see [`simpleTaskInteractive.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/server/src/simpleTaskInteractive.ts). +> The tasks API has been removed from this version of the SDK. See the [Migration guide](./migration.md) for details. ## Shutdown diff --git a/examples/client/README.md b/examples/client/README.md index 12a2b0d68b..9a408d065e 100644 --- a/examples/client/README.md +++ b/examples/client/README.md @@ -35,7 +35,6 @@ Most clients expect a server to be running. Start one from [`../server/README.md | OAuth provider helper | Demonstrates reusable OAuth providers. | [`src/simpleOAuthClientProvider.ts`](src/simpleOAuthClientProvider.ts) | | Client credentials (M2M) | Machine-to-machine OAuth client credentials example. | [`src/simpleClientCredentials.ts`](src/simpleClientCredentials.ts) | | URL elicitation client | Drives URL-mode elicitation flows (sensitive input in a browser). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | -| Task interactive client | Demonstrates task-based execution + interactive server→client requests. | [`src/simpleTaskInteractiveClient.ts`](src/simpleTaskInteractiveClient.ts) | ## URL elicitation example (server + client) diff --git a/examples/client/src/simpleOAuthClient.ts b/examples/client/src/simpleOAuthClient.ts index c75aea9483..193bc28f0b 100644 --- a/examples/client/src/simpleOAuthClient.ts +++ b/examples/client/src/simpleOAuthClient.ts @@ -4,7 +4,7 @@ import { createServer } from 'node:http'; import { createInterface } from 'node:readline'; import { URL } from 'node:url'; -import type { CallToolResult, ListToolsRequest, OAuthClientMetadata } from '@modelcontextprotocol/client'; +import type { ListToolsRequest, OAuthClientMetadata } from '@modelcontextprotocol/client'; import { Client, StreamableHTTPClientTransport, UnauthorizedError } from '@modelcontextprotocol/client'; import open from 'open'; @@ -209,7 +209,6 @@ class InteractiveOAuthClient { console.log('Commands:'); console.log(' list - List available tools'); console.log(' call [args] - Call a tool'); - console.log(' stream [args] - Call a tool with streaming (shows task status)'); console.log(' quit - Exit the client'); console.log(); @@ -229,10 +228,8 @@ class InteractiveOAuthClient { await this.listTools(); } else if (command.startsWith('call ')) { await this.handleCallTool(command); - } else if (command.startsWith('stream ')) { - await this.handleStreamTool(command); } else { - console.log("❌ Unknown command. Try 'list', 'call ', 'stream ', or 'quit'"); + console.log("❌ Unknown command. Try 'list', 'call ', or 'quit'"); } } catch (error) { if (error instanceof Error && error.message === 'SIGINT') { @@ -328,94 +325,6 @@ class InteractiveOAuthClient { } } - private async handleStreamTool(command: string): Promise { - const parts = command.split(/\s+/); - const toolName = parts[1]; - - if (!toolName) { - console.log('❌ Please specify a tool name'); - return; - } - - // Parse arguments (simple JSON-like format) - let toolArgs: Record = {}; - if (parts.length > 2) { - const argsString = parts.slice(2).join(' '); - try { - toolArgs = JSON.parse(argsString); - } catch { - console.log('❌ Invalid arguments format (expected JSON)'); - return; - } - } - - await this.streamTool(toolName, toolArgs); - } - - private async streamTool(toolName: string, toolArgs: Record): Promise { - if (!this.client) { - console.log('❌ Not connected to server'); - return; - } - - try { - // Using the experimental tasks API - WARNING: may change without notice - console.log(`\n🔧 Streaming tool '${toolName}'...`); - - const stream = this.client.experimental.tasks.callToolStream( - { - name: toolName, - arguments: toolArgs - }, - { - task: { - taskId: `task-${Date.now()}`, - ttl: 60_000 - } - } - ); - - // Iterate through all messages yielded by the generator - for await (const message of stream) { - switch (message.type) { - case 'taskCreated': { - console.log(`✓ Task created: ${message.task.taskId}`); - break; - } - - case 'taskStatus': { - console.log(`⟳ Status: ${message.task.status}`); - if (message.task.statusMessage) { - console.log(` ${message.task.statusMessage}`); - } - break; - } - - case 'result': { - console.log('✓ Completed!'); - const toolResult = message.result as CallToolResult; - for (const content of toolResult.content) { - if (content.type === 'text') { - console.log(content.text); - } else { - console.log(content); - } - } - break; - } - - case 'error': { - console.log('✗ Error:'); - console.log(` ${message.error.message}`); - break; - } - } - } - } catch (error) { - console.error(`❌ Failed to stream tool '${toolName}':`, error); - } - } - close(): void { this.rl.close(); if (this.client) { diff --git a/examples/client/src/simpleStreamableHttp.ts b/examples/client/src/simpleStreamableHttp.ts index f22d16ba4b..ec3077e29d 100644 --- a/examples/client/src/simpleStreamableHttp.ts +++ b/examples/client/src/simpleStreamableHttp.ts @@ -1,7 +1,6 @@ import { createInterface } from 'node:readline'; import type { - CallToolResult, GetPromptRequest, ListPromptsRequest, ListResourcesRequest, @@ -9,15 +8,7 @@ import type { ReadResourceRequest, ResourceLink } from '@modelcontextprotocol/client'; -import { - Client, - getDisplayName, - InMemoryTaskStore, - ProtocolError, - ProtocolErrorCode, - RELATED_TASK_META_KEY, - StreamableHTTPClientTransport -} from '@modelcontextprotocol/client'; +import { Client, getDisplayName, ProtocolError, ProtocolErrorCode, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; import { Ajv } from 'ajv'; // Create readline interface for user input @@ -56,11 +47,9 @@ function printHelp(): void { console.log(' reconnect - Reconnect to the server'); console.log(' list-tools - List available tools'); console.log(' call-tool [args] - Call a tool with optional JSON arguments'); - console.log(' call-tool-task [args] - Call a tool with task-based execution (example: call-tool-task delay {"duration":3000})'); console.log(' greet [name] - Call the greet tool'); console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); console.log(' collect-info [type] - Test form elicitation with collect-user-info tool (contact/preferences/feedback)'); - console.log(' collect-info-task [type] - Test bidirectional task support (server+client tasks) with elicitation'); console.log(' start-notifications [interval] [count] - Start periodic notifications'); console.log(' run-notifications-tool-with-resumability [interval] [count] - Run notification tool with resumability'); console.log(' list-prompts - List available prompts'); @@ -136,11 +125,6 @@ function commandLoop(): void { break; } - case 'collect-info-task': { - await callCollectInfoWithTask(args[1] || 'contact'); - break; - } - case 'start-notifications': { const interval = args[1] ? Number.parseInt(args[1], 10) : 2000; const count = args[2] ? Number.parseInt(args[2], 10) : 10; @@ -155,24 +139,6 @@ function commandLoop(): void { break; } - case 'call-tool-task': { - if (args.length < 2) { - console.log('Usage: call-tool-task [args]'); - } else { - const toolName = args[1]!; - let toolArgs = {}; - if (args.length > 2) { - try { - toolArgs = JSON.parse(args.slice(2).join(' ')); - } catch { - console.log('Invalid JSON arguments. Using empty args.'); - } - } - await callToolTask(toolName, toolArgs); - } - break; - } - case 'list-prompts': { await listPrompts(); break; @@ -250,10 +216,7 @@ async function connect(url?: string): Promise { console.log(`Connecting to ${serverUrl}...`); try { - // Create task store for client-side task support - const clientTaskStore = new InMemoryTaskStore(); - - // Create a new client with form elicitation capability and task support + // Create a new client with form elicitation capability client = new Client( { name: 'example-client', @@ -263,14 +226,6 @@ async function connect(url?: string): Promise { capabilities: { elicitation: { form: {} - }, - tasks: { - taskStore: clientTaskStore, - requests: { - elicitation: { - create: {} - } - } } } } @@ -279,33 +234,16 @@ async function connect(url?: string): Promise { console.error('\u001B[31mClient error:', error, '\u001B[0m'); }; - // Set up elicitation request handler with proper validation and task support - client.setRequestHandler('elicitation/create', async (request, extra) => { + // Set up elicitation request handler with proper validation + client.setRequestHandler('elicitation/create', async request => { if (request.params.mode !== 'form') { throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); } - console.log('\n🔔 Elicitation (form) Request Received:'); + console.log('\n Elicitation (form) Request Received:'); console.log(`Message: ${request.params.message}`); - console.log(`Related Task: ${request.params._meta?.[RELATED_TASK_META_KEY]?.taskId}`); - console.log(`Task Creation Requested: ${request.params.task ? 'yes' : 'no'}`); console.log('Requested Schema:'); console.log(JSON.stringify(request.params.requestedSchema, null, 2)); - // Helper to return result, optionally creating a task if requested - const returnResult = async (result: { - action: 'accept' | 'decline' | 'cancel'; - content?: Record; - }) => { - if (request.params.task && extra.task?.store) { - // Create a task and store the result - const task = await extra.task.store.createTask({ ttl: extra.task.requestedTtl }); - await extra.task.store.storeTaskResult(task.taskId, 'completed', result); - console.log(`📋 Created client-side task: ${task.taskId}`); - return { task }; - } - return result; - }; - const schema = request.params.requestedSchema; const properties = schema.properties; const required = schema.required || []; @@ -439,7 +377,7 @@ async function connect(url?: string): Promise { } if (inputCancelled) { - return returnResult({ action: 'cancel' }); + return { action: 'cancel' }; } // If we didn't complete all fields due to an error, try again @@ -452,7 +390,7 @@ async function connect(url?: string): Promise { continue; } else { console.log('Maximum attempts reached. Declining request.'); - return returnResult({ action: 'decline' }); + return { action: 'decline' }; } } @@ -471,7 +409,7 @@ async function connect(url?: string): Promise { continue; } else { console.log('Maximum attempts reached. Declining request.'); - return returnResult({ action: 'decline' }); + return { action: 'decline' }; } } @@ -488,14 +426,14 @@ async function connect(url?: string): Promise { switch (confirmAnswer) { case 'yes': case 'y': { - return returnResult({ + return { action: 'accept', content - }); + }; } case 'cancel': case 'c': { - return returnResult({ action: 'cancel' }); + return { action: 'cancel' }; } case 'no': case 'n': { @@ -503,7 +441,7 @@ async function connect(url?: string): Promise { console.log('Please re-enter the information...'); continue; } else { - return returnResult({ action: 'decline' }); + return { action: 'decline' }; } break; @@ -513,7 +451,7 @@ async function connect(url?: string): Promise { } console.log('Maximum attempts reached. Declining request.'); - return returnResult({ action: 'decline' }); + return { action: 'decline' }; }); transport = new StreamableHTTPClientTransport(new URL(serverUrl), { @@ -716,12 +654,6 @@ async function callCollectInfoTool(infoType: string): Promise { await callTool('collect-user-info', { infoType }); } -async function callCollectInfoWithTask(infoType: string): Promise { - console.log(`\n🔄 Testing bidirectional task support with collect-user-info-task tool (${infoType})...`); - console.log('This will create a task on the server, which will elicit input and create a task on the client.\n'); - await callToolTask('collect-user-info-task', { infoType }); -} - async function startNotifications(interval: number, count: number): Promise { console.log(`Starting notification stream: interval=${interval}ms, count=${count || 'unlimited'}`); await callTool('start-notification-stream', { interval, count }); @@ -880,70 +812,6 @@ async function readResource(uri: string): Promise { } } -async function callToolTask(name: string, args: Record): Promise { - if (!client) { - console.log('Not connected to server.'); - return; - } - - console.log(`Calling tool '${name}' with task-based execution...`); - console.log('Arguments:', args); - - // Use task-based execution - call now, fetch later - // Using the experimental tasks API - WARNING: may change without notice - console.log('This will return immediately while processing continues in the background...'); - - try { - // Call the tool with task metadata using streaming API - const stream = client.experimental.tasks.callToolStream( - { - name, - arguments: args - }, - { - task: { - ttl: 60_000 // Keep results for 60 seconds - } - } - ); - - console.log('Waiting for task completion...'); - - let lastStatus = ''; - for await (const message of stream) { - switch (message.type) { - case 'taskCreated': { - console.log('Task created successfully with ID:', message.task.taskId); - break; - } - case 'taskStatus': { - if (lastStatus !== message.task.status) { - console.log(` ${message.task.status}${message.task.statusMessage ? ` - ${message.task.statusMessage}` : ''}`); - } - lastStatus = message.task.status; - break; - } - case 'result': { - console.log('Task completed!'); - console.log('Tool result:'); - const toolResult = message.result as CallToolResult; - for (const item of toolResult.content) { - if (item.type === 'text') { - console.log(` ${item.text}`); - } - } - break; - } - case 'error': { - throw message.error; - } - } - } - } catch (error) { - console.log(`Error with task-based execution: ${error}`); - } -} - async function cleanup(): Promise { if (client && transport) { try { diff --git a/examples/client/src/simpleTaskInteractiveClient.ts b/examples/client/src/simpleTaskInteractiveClient.ts deleted file mode 100644 index 0a35faba24..0000000000 --- a/examples/client/src/simpleTaskInteractiveClient.ts +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Simple interactive task client demonstrating elicitation and sampling responses. - * - * This client connects to simpleTaskInteractive.ts server and demonstrates: - * - Handling elicitation requests (y/n confirmation) - * - Handling sampling requests (returns a hardcoded haiku) - * - Using task-based tool execution with streaming - */ - -import { createInterface } from 'node:readline'; - -import type { CallToolResult, CreateMessageRequest, CreateMessageResult, TextContent } from '@modelcontextprotocol/client'; -import { Client, ProtocolError, ProtocolErrorCode, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; - -// Create readline interface for user input -const readline = createInterface({ - input: process.stdin, - output: process.stdout -}); - -function question(prompt: string): Promise { - return new Promise(resolve => { - readline.question(prompt, answer => { - resolve(answer.trim()); - }); - }); -} - -function getTextContent(result: { content: Array<{ type: string; text?: string }> }): string { - const textContent = result.content.find((c): c is TextContent => c.type === 'text'); - return textContent?.text ?? '(no text)'; -} - -async function elicitationCallback(params: { - mode?: string; - message: string; - requestedSchema?: object; -}): Promise<{ action: 'accept' | 'cancel' | 'decline'; content?: Record }> { - console.log(`\n[Elicitation] Server asks: ${params.message}`); - - // Simple terminal prompt for y/n - const response = await question('Your response (y/n): '); - const confirmed = ['y', 'yes', 'true', '1'].includes(response.toLowerCase()); - - console.log(`[Elicitation] Responding with: confirm=${confirmed}`); - return { action: 'accept', content: { confirm: confirmed } }; -} - -async function samplingCallback(params: CreateMessageRequest['params']): Promise { - // Get the prompt from the first message - let prompt = 'unknown'; - if (params.messages && params.messages.length > 0) { - const firstMessage = params.messages[0]!; - const content = firstMessage.content; - if (typeof content === 'object' && !Array.isArray(content) && content.type === 'text' && 'text' in content) { - prompt = content.text; - } else if (Array.isArray(content)) { - const textPart = content.find(c => c.type === 'text' && 'text' in c); - if (textPart && 'text' in textPart) { - prompt = textPart.text; - } - } - } - - console.log(`\n[Sampling] Server requests LLM completion for: ${prompt}`); - - // Return a hardcoded haiku (in real use, call your LLM here) - const haiku = `Cherry blossoms fall -Softly on the quiet pond -Spring whispers goodbye`; - - console.log('[Sampling] Responding with haiku'); - return { - model: 'mock-haiku-model', - role: 'assistant', - content: { type: 'text', text: haiku } - }; -} - -async function run(url: string): Promise { - console.log('Simple Task Interactive Client'); - console.log('=============================='); - console.log(`Connecting to ${url}...`); - - // Create client with elicitation and sampling capabilities - const client = new Client( - { name: 'simple-task-interactive-client', version: '1.0.0' }, - { - capabilities: { - elicitation: { form: {} }, - sampling: {} - } - } - ); - - // Set up elicitation request handler - client.setRequestHandler('elicitation/create', async request => { - if (request.params.mode && request.params.mode !== 'form') { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Unsupported elicitation mode: ${request.params.mode}`); - } - return elicitationCallback(request.params); - }); - - // Set up sampling request handler - client.setRequestHandler('sampling/createMessage', async request => { - return samplingCallback(request.params) as unknown as ReturnType; - }); - - // Connect to server - const transport = new StreamableHTTPClientTransport(new URL(url)); - await client.connect(transport); - console.log('Connected!\n'); - - // List tools - const toolsResult = await client.listTools(); - console.log(`Available tools: ${toolsResult.tools.map(t => t.name).join(', ')}`); - - // Demo 1: Elicitation (confirm_delete) - console.log('\n--- Demo 1: Elicitation ---'); - console.log('Calling confirm_delete tool...'); - - const confirmStream = client.experimental.tasks.callToolStream( - { name: 'confirm_delete', arguments: { filename: 'important.txt' } }, - { task: { ttl: 60_000 } } - ); - - for await (const message of confirmStream) { - switch (message.type) { - case 'taskCreated': { - console.log(`Task created: ${message.task.taskId}`); - break; - } - case 'taskStatus': { - console.log(`Task status: ${message.task.status}`); - break; - } - case 'result': { - const toolResult = message.result as CallToolResult; - console.log(`Result: ${getTextContent(toolResult)}`); - break; - } - case 'error': { - console.error(`Error: ${message.error}`); - break; - } - } - } - - // Demo 2: Sampling (write_haiku) - console.log('\n--- Demo 2: Sampling ---'); - console.log('Calling write_haiku tool...'); - - const haikuStream = client.experimental.tasks.callToolStream( - { name: 'write_haiku', arguments: { topic: 'autumn leaves' } }, - { task: { ttl: 60_000 } } - ); - - for await (const message of haikuStream) { - switch (message.type) { - case 'taskCreated': { - console.log(`Task created: ${message.task.taskId}`); - break; - } - case 'taskStatus': { - console.log(`Task status: ${message.task.status}`); - break; - } - case 'result': { - const toolResult = message.result as CallToolResult; - console.log(`Result:\n${getTextContent(toolResult)}`); - break; - } - case 'error': { - console.error(`Error: ${message.error}`); - break; - } - } - } - - // Cleanup - console.log('\nDemo complete. Closing connection...'); - await transport.close(); - readline.close(); -} - -// Parse command line arguments -const args = process.argv.slice(2); -let url = 'http://localhost:8000/mcp'; - -for (let i = 0; i < args.length; i++) { - if (args[i] === '--url' && args[i + 1]) { - url = args[i + 1]!; - i++; - } -} - -// Run the client -try { - await run(url); -} catch (error) { - console.error('Error running client:', error); - // eslint-disable-next-line unicorn/no-process-exit - process.exit(1); -} diff --git a/examples/server/README.md b/examples/server/README.md index 0f684bec7e..3151cf7f20 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -36,7 +36,6 @@ pnpm tsx src/simpleStreamableHttp.ts | Form elicitation server | Collects **non-sensitive** user input via schema-driven forms. | [`src/elicitationFormExample.ts`](src/elicitationFormExample.ts) | | URL elicitation server | Secure browser-based flows for **sensitive** input (API keys, OAuth, payments). | [`src/elicitationUrlExample.ts`](src/elicitationUrlExample.ts) | | Sampling + tasks server | Demonstrates sampling and experimental task-based execution. | [`src/toolWithSampleServer.ts`](src/toolWithSampleServer.ts) | -| Task interactive server | Task-based execution with interactive server→client requests. | [`src/simpleTaskInteractive.ts`](src/simpleTaskInteractive.ts) | | Hono Streamable HTTP server | Streamable HTTP server built with Hono instead of Express. | [`src/honoWebStandardStreamableHttp.ts`](src/honoWebStandardStreamableHttp.ts) | | SSE polling demo server | Legacy SSE server intended for polling demos. | [`src/ssePollingExample.ts`](src/ssePollingExample.ts) | diff --git a/examples/server/src/simpleStreamableHttp.ts b/examples/server/src/simpleStreamableHttp.ts index 6da0841ec1..1f0998cca9 100644 --- a/examples/server/src/simpleStreamableHttp.ts +++ b/examples/server/src/simpleStreamableHttp.ts @@ -5,13 +5,12 @@ import { createMcpExpressApp, getOAuthProtectedResourceMetadataUrl, requireBeare import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; import type { CallToolResult, - ElicitResult, GetPromptResult, PrimitiveSchemaDefinition, ReadResourceResult, ResourceLink } from '@modelcontextprotocol/server'; -import { InMemoryTaskMessageQueue, InMemoryTaskStore, isInitializeRequest, McpServer } from '@modelcontextprotocol/server'; +import { isInitializeRequest, McpServer } from '@modelcontextprotocol/server'; import cors from 'cors'; import type { Request, Response } from 'express'; import * as z from 'zod/v4'; @@ -22,9 +21,6 @@ import { InMemoryEventStore } from './inMemoryEventStore.js'; const useOAuth = process.argv.includes('--oauth'); const dangerousLoggingEnabled = process.argv.includes('--dangerous-logging-enabled'); -// Create shared task store for demonstration -const taskStore = new InMemoryTaskStore(); - // Create an MCP server with implementation details const getServer = () => { const server = new McpServer( @@ -36,12 +32,7 @@ const getServer = () => { }, { capabilities: { - logging: {}, - tasks: { - requests: { tools: { call: {} } }, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() - } + logging: {} } } ); @@ -439,160 +430,6 @@ const getServer = () => { } ); - // Register a long-running tool that demonstrates task execution - // Using the experimental tasks API - WARNING: may change without notice - server.experimental.tasks.registerToolTask( - 'delay', - { - title: 'Delay', - description: 'A simple tool that delays for a specified duration, useful for testing task execution', - inputSchema: z.object({ - duration: z.number().describe('Duration in milliseconds').default(5000) - }) - }, - { - async createTask({ duration }, ctx) { - // Create the task - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - // Simulate out-of-band work - (async () => { - await new Promise(resolve => setTimeout(resolve, duration)); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [ - { - type: 'text', - text: `Completed ${duration}ms delay` - } - ] - }); - })(); - - // Return CreateTaskResult with the created task - return { - task - }; - }, - async getTask(_args, ctx) { - return await ctx.task.store.getTask(ctx.task.id); - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - - // Register a tool that demonstrates bidirectional task support: - // Server creates a task, then elicits input from client using elicitInputStream - // Using the experimental tasks API - WARNING: may change without notice - server.experimental.tasks.registerToolTask( - 'collect-user-info-task', - { - title: 'Collect Info with Task', - description: 'Collects user info via elicitation with task support using elicitInputStream', - inputSchema: z.object({ - infoType: z.enum(['contact', 'preferences']).describe('Type of information to collect').default('contact') - }) - }, - { - async createTask({ infoType }, ctx) { - // Create the server-side task - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - // Perform async work that makes a nested elicitation request using elicitInputStream - (async () => { - try { - const message = infoType === 'contact' ? 'Please provide your contact information' : 'Please set your preferences'; - - // Define schemas with proper typing for PrimitiveSchemaDefinition - const contactSchema: { - type: 'object'; - properties: Record; - required: string[]; - } = { - type: 'object', - properties: { - name: { type: 'string', title: 'Full Name', description: 'Your full name' }, - email: { type: 'string', title: 'Email', description: 'Your email address' } - }, - required: ['name', 'email'] - }; - - const preferencesSchema: { - type: 'object'; - properties: Record; - required: string[]; - } = { - type: 'object', - properties: { - theme: { type: 'string', title: 'Theme', enum: ['light', 'dark', 'auto'] }, - notifications: { type: 'boolean', title: 'Enable Notifications', default: true } - }, - required: ['theme'] - }; - - const requestedSchema = infoType === 'contact' ? contactSchema : preferencesSchema; - - // Use elicitInputStream to elicit input from client - // This demonstrates the streaming elicitation API - // Access via server.server to get the underlying Server instance - const stream = server.server.experimental.tasks.elicitInputStream({ - mode: 'form', - message, - requestedSchema - }); - - let elicitResult: ElicitResult | undefined; - for await (const msg of stream) { - if (msg.type === 'result') { - elicitResult = msg.result as ElicitResult; - } else if (msg.type === 'error') { - throw msg.error; - } - } - - if (!elicitResult) { - throw new Error('No result received from elicitation'); - } - - let resultText: string; - if (elicitResult.action === 'accept') { - resultText = `Collected ${infoType} info: ${JSON.stringify(elicitResult.content, null, 2)}`; - } else if (elicitResult.action === 'decline') { - resultText = `User declined to provide ${infoType} information`; - } else { - resultText = 'User cancelled the request'; - } - - await taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: resultText }] - }); - } catch (error) { - console.error('Error in collect-user-info-task:', error); - await taskStore.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text', text: `Error: ${error}` }], - isError: true - }); - } - })(); - - return { task }; - }, - async getTask(_args, ctx) { - return await ctx.task.store.getTask(ctx.task.id); - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - return server; }; diff --git a/examples/server/src/simpleTaskInteractive.ts b/examples/server/src/simpleTaskInteractive.ts deleted file mode 100644 index fc0d7280c8..0000000000 --- a/examples/server/src/simpleTaskInteractive.ts +++ /dev/null @@ -1,758 +0,0 @@ -/** - * Simple interactive task server demonstrating elicitation and sampling. - * - * This server demonstrates the task message queue pattern from the MCP Tasks spec: - * - confirm_delete: Uses elicitation to ask the user for confirmation - * - write_haiku: Uses sampling to request an LLM to generate content - * - * Both tools use the "call-now, fetch-later" pattern where the initial call - * creates a task, and the result is fetched via tasks/result endpoint. - */ - -import { randomUUID } from 'node:crypto'; - -import { createMcpExpressApp } from '@modelcontextprotocol/express'; -import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import type { - CallToolResult, - CreateMessageRequest, - CreateMessageResult, - CreateTaskOptions, - CreateTaskResult, - ElicitRequestFormParams, - ElicitResult, - GetTaskPayloadResult, - GetTaskResult, - JSONRPCRequest, - PrimitiveSchemaDefinition, - QueuedMessage, - QueuedRequest, - RequestId, - Result, - SamplingMessage, - Task, - TaskMessageQueue, - TextContent, - Tool -} from '@modelcontextprotocol/server'; -import { InMemoryTaskStore, isTerminal, RELATED_TASK_META_KEY, Server } from '@modelcontextprotocol/server'; -import type { Request, Response } from 'express'; - -// ============================================================================ -// Resolver - Promise-like for passing results between async operations -// ============================================================================ - -class Resolver { - private _resolve!: (value: T) => void; - private _reject!: (error: Error) => void; - private _promise: Promise; - private _done = false; - - constructor() { - this._promise = new Promise((resolve, reject) => { - this._resolve = resolve; - this._reject = reject; - }); - } - - setResult(value: T): void { - if (this._done) return; - this._done = true; - this._resolve(value); - } - - setException(error: Error): void { - if (this._done) return; - this._done = true; - this._reject(error); - } - - wait(): Promise { - return this._promise; - } - - done(): boolean { - return this._done; - } -} - -// ============================================================================ -// Extended message queue with resolver support and wait functionality -// ============================================================================ - -interface QueuedRequestWithResolver extends QueuedRequest { - resolver?: Resolver>; - originalRequestId?: RequestId; -} - -type QueuedMessageWithResolver = QueuedRequestWithResolver | QueuedMessage; - -class TaskMessageQueueWithResolvers implements TaskMessageQueue { - private queues = new Map(); - private waitResolvers = new Map void)[]>(); - - private getQueue(taskId: string): QueuedMessageWithResolver[] { - let queue = this.queues.get(taskId); - if (!queue) { - queue = []; - this.queues.set(taskId, queue); - } - return queue; - } - - async enqueue(taskId: string, message: QueuedMessage, _sessionId?: string, maxSize?: number): Promise { - const queue = this.getQueue(taskId); - if (maxSize !== undefined && queue.length >= maxSize) { - throw new Error(`Task message queue overflow: queue size (${queue.length}) exceeds maximum (${maxSize})`); - } - queue.push(message); - // Notify any waiters - this.notifyWaiters(taskId); - } - - async enqueueWithResolver( - taskId: string, - message: JSONRPCRequest, - resolver: Resolver>, - originalRequestId: RequestId - ): Promise { - const queue = this.getQueue(taskId); - const queuedMessage: QueuedRequestWithResolver = { - type: 'request', - message, - timestamp: Date.now(), - resolver, - originalRequestId - }; - queue.push(queuedMessage); - this.notifyWaiters(taskId); - } - - async dequeue(taskId: string, _sessionId?: string): Promise { - const queue = this.getQueue(taskId); - return queue.shift(); - } - - async dequeueAll(taskId: string, _sessionId?: string): Promise { - const queue = this.queues.get(taskId) ?? []; - this.queues.delete(taskId); - return queue; - } - - async waitForMessage(taskId: string): Promise { - // Check if there are already messages - const queue = this.getQueue(taskId); - if (queue.length > 0) return; - - // Wait for a message to be added - return new Promise(resolve => { - let waiters = this.waitResolvers.get(taskId); - if (!waiters) { - waiters = []; - this.waitResolvers.set(taskId, waiters); - } - waiters.push(resolve); - }); - } - - private notifyWaiters(taskId: string): void { - const waiters = this.waitResolvers.get(taskId); - if (waiters) { - this.waitResolvers.delete(taskId); - for (const resolve of waiters) { - resolve(); - } - } - } - - cleanup(): void { - this.queues.clear(); - this.waitResolvers.clear(); - } -} - -// ============================================================================ -// Extended task store with wait functionality -// ============================================================================ - -class TaskStoreWithNotifications extends InMemoryTaskStore { - private updateResolvers = new Map void)[]>(); - - override async updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, sessionId?: string): Promise { - await super.updateTaskStatus(taskId, status, statusMessage, sessionId); - this.notifyUpdate(taskId); - } - - override async storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, sessionId?: string): Promise { - await super.storeTaskResult(taskId, status, result, sessionId); - this.notifyUpdate(taskId); - } - - async waitForUpdate(taskId: string): Promise { - return new Promise(resolve => { - let waiters = this.updateResolvers.get(taskId); - if (!waiters) { - waiters = []; - this.updateResolvers.set(taskId, waiters); - } - waiters.push(resolve); - }); - } - - private notifyUpdate(taskId: string): void { - const waiters = this.updateResolvers.get(taskId); - if (waiters) { - this.updateResolvers.delete(taskId); - for (const resolve of waiters) { - resolve(); - } - } - } -} - -// ============================================================================ -// Task Result Handler - delivers queued messages and routes responses -// ============================================================================ - -class TaskResultHandler { - private pendingRequests = new Map>>(); - - constructor( - private store: TaskStoreWithNotifications, - private queue: TaskMessageQueueWithResolvers - ) {} - - async handle(taskId: string, server: Server, _sessionId: string): Promise { - while (true) { - // Get fresh task state - const task = await this.store.getTask(taskId); - if (!task) { - throw new Error(`Task not found: ${taskId}`); - } - - // Dequeue and send all pending messages - await this.deliverQueuedMessages(taskId, server, _sessionId); - - // If task is terminal, return result - if (isTerminal(task.status)) { - const result = await this.store.getTaskResult(taskId); - // Add related-task metadata per spec - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { taskId } - } - }; - } - - // Wait for task update or new message - await this.waitForUpdate(taskId); - } - } - - private async deliverQueuedMessages(taskId: string, server: Server, _sessionId: string): Promise { - while (true) { - const message = await this.queue.dequeue(taskId); - if (!message) break; - - console.log(`[Server] Delivering queued ${message.type} message for task ${taskId}`); - - if (message.type === 'request') { - const reqMessage = message as QueuedRequestWithResolver; - // Send the request via the server - // Store the resolver so we can route the response back - if (reqMessage.resolver && reqMessage.originalRequestId) { - this.pendingRequests.set(reqMessage.originalRequestId, reqMessage.resolver); - } - - // Send the message - for elicitation/sampling, we use the server's methods - // But since we're in tasks/result context, we need to send via transport - // This is simplified - in production you'd use proper message routing - try { - const request = reqMessage.message; - let response: ElicitResult | CreateMessageResult; - - if (request.method === 'elicitation/create') { - // Send elicitation request to client - const params = request.params as ElicitRequestFormParams; - response = await server.elicitInput(params); - } else if (request.method === 'sampling/createMessage') { - // Send sampling request to client - const params = request.params as CreateMessageRequest['params']; - response = await server.createMessage(params); - } else { - throw new Error(`Unknown request method: ${request.method}`); - } - - // Route response back to resolver - if (reqMessage.resolver) { - reqMessage.resolver.setResult(response as unknown as Record); - } - } catch (error) { - if (reqMessage.resolver) { - reqMessage.resolver.setException(error instanceof Error ? error : new Error(String(error))); - } - } - } - // For notifications, we'd send them too but this example focuses on requests - } - } - - private async waitForUpdate(taskId: string): Promise { - // Race between store update and queue message - await Promise.race([this.store.waitForUpdate(taskId), this.queue.waitForMessage(taskId)]); - } - - routeResponse(requestId: RequestId, response: Record): boolean { - const resolver = this.pendingRequests.get(requestId); - if (resolver && !resolver.done()) { - this.pendingRequests.delete(requestId); - resolver.setResult(response); - return true; - } - return false; - } - - routeError(requestId: RequestId, error: Error): boolean { - const resolver = this.pendingRequests.get(requestId); - if (resolver && !resolver.done()) { - this.pendingRequests.delete(requestId); - resolver.setException(error); - return true; - } - return false; - } -} - -// ============================================================================ -// Task Session - wraps server to enqueue requests during task execution -// ============================================================================ - -class TaskSession { - private requestCounter = 0; - - constructor( - private server: Server, - private taskId: string, - private store: TaskStoreWithNotifications, - private queue: TaskMessageQueueWithResolvers - ) {} - - private nextRequestId(): string { - return `task-${this.taskId}-${++this.requestCounter}`; - } - - async elicit( - message: string, - requestedSchema: { - type: 'object'; - properties: Record; - required?: string[]; - } - ): Promise<{ action: string; content?: Record }> { - // Update task status to input_required - await this.store.updateTaskStatus(this.taskId, 'input_required'); - - const requestId = this.nextRequestId(); - - // Build the elicitation request with related-task metadata - const params: ElicitRequestFormParams = { - message, - requestedSchema, - mode: 'form', - _meta: { - [RELATED_TASK_META_KEY]: { taskId: this.taskId } - } - }; - - const jsonrpcRequest: JSONRPCRequest = { - jsonrpc: '2.0', - id: requestId, - method: 'elicitation/create', - params - }; - - // Create resolver to wait for response - const resolver = new Resolver>(); - - // Enqueue the request - await this.queue.enqueueWithResolver(this.taskId, jsonrpcRequest, resolver, requestId); - - try { - // Wait for response - const response = await resolver.wait(); - - // Update status back to working - await this.store.updateTaskStatus(this.taskId, 'working'); - - return response as { action: string; content?: Record }; - } catch (error) { - await this.store.updateTaskStatus(this.taskId, 'working'); - throw error; - } - } - - async createMessage( - messages: SamplingMessage[], - maxTokens: number - ): Promise<{ role: string; content: TextContent | { type: string } }> { - // Update task status to input_required - await this.store.updateTaskStatus(this.taskId, 'input_required'); - - const requestId = this.nextRequestId(); - - // Build the sampling request with related-task metadata - const params = { - messages, - maxTokens, - _meta: { - [RELATED_TASK_META_KEY]: { taskId: this.taskId } - } - }; - - const jsonrpcRequest: JSONRPCRequest = { - jsonrpc: '2.0', - id: requestId, - method: 'sampling/createMessage', - params - }; - - // Create resolver to wait for response - const resolver = new Resolver>(); - - // Enqueue the request - await this.queue.enqueueWithResolver(this.taskId, jsonrpcRequest, resolver, requestId); - - try { - // Wait for response - const response = await resolver.wait(); - - // Update status back to working - await this.store.updateTaskStatus(this.taskId, 'working'); - - return response as { role: string; content: TextContent | { type: string } }; - } catch (error) { - await this.store.updateTaskStatus(this.taskId, 'working'); - throw error; - } - } -} - -// ============================================================================ -// Server Setup -// ============================================================================ - -const PORT = process.env.PORT ? Number.parseInt(process.env.PORT, 10) : 8000; - -// Create shared stores -const taskStore = new TaskStoreWithNotifications(); -const messageQueue = new TaskMessageQueueWithResolvers(); -const taskResultHandler = new TaskResultHandler(taskStore, messageQueue); - -// Track active task executions -const activeTaskExecutions = new Map< - string, - { - promise: Promise; - server: Server; - sessionId: string; - } ->(); - -// Create the server -const createServer = (): Server => { - const server = new Server( - { name: 'simple-task-interactive', version: '1.0.0' }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { call: {} } - } - } - } - } - ); - - // Register tools - server.setRequestHandler('tools/list', async (): Promise<{ tools: Tool[] }> => { - return { - tools: [ - { - name: 'confirm_delete', - description: 'Asks for confirmation before deleting (demonstrates elicitation)', - inputSchema: { - type: 'object', - properties: { - filename: { type: 'string' } - } - }, - execution: { taskSupport: 'required' } - }, - { - name: 'write_haiku', - description: 'Asks LLM to write a haiku (demonstrates sampling)', - inputSchema: { - type: 'object', - properties: { - topic: { type: 'string' } - } - }, - execution: { taskSupport: 'required' } - } - ] - }; - }); - - // Handle tool calls - server.setRequestHandler('tools/call', async (request, ctx): Promise => { - const { name, arguments: args } = request.params; - const taskParams = (request.params._meta?.task || request.params.task) as { ttl?: number; pollInterval?: number } | undefined; - - // Validate task mode - these tools require tasks - if (!taskParams) { - throw new Error(`Tool ${name} requires task mode`); - } - - // Create task - const taskOptions: CreateTaskOptions = { - ttl: taskParams.ttl, - pollInterval: taskParams.pollInterval ?? 1000 - }; - - const task = await taskStore.createTask(taskOptions, ctx.mcpReq.id, request, ctx.sessionId); - - console.log(`\n[Server] ${name} called, task created: ${task.taskId}`); - - // Start background task execution - const taskExecution = (async () => { - try { - const taskSession = new TaskSession(server, task.taskId, taskStore, messageQueue); - - if (name === 'confirm_delete') { - const filename = args?.filename ?? 'unknown.txt'; - console.log(`[Server] confirm_delete: asking about '${filename}'`); - - console.log('[Server] Sending elicitation request to client...'); - const result = await taskSession.elicit(`Are you sure you want to delete '${filename}'?`, { - type: 'object', - properties: { - confirm: { type: 'boolean' } - }, - required: ['confirm'] - }); - - console.log( - `[Server] Received elicitation response: action=${result.action}, content=${JSON.stringify(result.content)}` - ); - - let text: string; - if (result.action === 'accept' && result.content) { - const confirmed = result.content.confirm; - text = confirmed ? `Deleted '${filename}'` : 'Deletion cancelled'; - } else { - text = 'Deletion cancelled'; - } - - console.log(`[Server] Completing task with result: ${text}`); - await taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text }] - }); - } else if (name === 'write_haiku') { - const topic = args?.topic ?? 'nature'; - console.log(`[Server] write_haiku: topic '${topic}'`); - - console.log('[Server] Sending sampling request to client...'); - const result = await taskSession.createMessage( - [ - { - role: 'user', - content: { type: 'text', text: `Write a haiku about ${topic}` } - } - ], - 50 - ); - - let haiku = 'No response'; - if (result.content && 'text' in result.content) { - haiku = (result.content as TextContent).text; - } - - console.log(`[Server] Received sampling response: ${haiku.slice(0, 50)}...`); - console.log('[Server] Completing task with haiku'); - await taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Haiku:\n${haiku}` }] - }); - } - } catch (error) { - console.error(`[Server] Task ${task.taskId} failed:`, error); - await taskStore.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text', text: `Error: ${error}` }], - isError: true - }); - } finally { - activeTaskExecutions.delete(task.taskId); - } - })(); - - activeTaskExecutions.set(task.taskId, { - promise: taskExecution, - server, - sessionId: ctx.sessionId ?? '' - }); - - return { task }; - }); - - // Handle tasks/get - server.setRequestHandler('tasks/get', async (request): Promise => { - const { taskId } = request.params; - const task = await taskStore.getTask(taskId); - if (!task) { - throw new Error(`Task ${taskId} not found`); - } - return task; - }); - - // Handle tasks/result - server.setRequestHandler('tasks/result', async (request, ctx): Promise => { - const { taskId } = request.params; - console.log(`[Server] tasks/result called for task ${taskId}`); - return taskResultHandler.handle(taskId, server, ctx.sessionId ?? ''); - }); - - return server; -}; - -// ============================================================================ -// Express App Setup -// ============================================================================ - -const app = createMcpExpressApp(); - -// Map to store transports by session ID -const transports: { [sessionId: string]: NodeStreamableHTTPServerTransport } = {}; - -// Helper to check if request is initialize -const isInitializeRequest = (body: unknown): boolean => { - return typeof body === 'object' && body !== null && 'method' in body && (body as { method: string }).method === 'initialize'; -}; - -// MCP POST endpoint -app.post('/mcp', async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - - try { - let transport: NodeStreamableHTTPServerTransport; - - if (sessionId && transports[sessionId]) { - transport = transports[sessionId]; - } else if (!sessionId && isInitializeRequest(req.body)) { - transport = new NodeStreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - onsessioninitialized: sid => { - console.log(`Session initialized: ${sid}`); - transports[sid] = transport; - } - }); - - transport.onclose = () => { - const sid = transport.sessionId; - if (sid && transports[sid]) { - console.log(`Transport closed for session ${sid}`); - delete transports[sid]; - } - }; - - const server = createServer(); - await server.connect(transport); - await transport.handleRequest(req, res, req.body); - return; - } else if (sessionId) { - res.status(404).json({ - jsonrpc: '2.0', - error: { code: -32_001, message: 'Session not found' }, - id: null - }); - return; - } else { - res.status(400).json({ - jsonrpc: '2.0', - error: { code: -32_000, message: 'Bad Request: Session ID required' }, - id: null - }); - return; - } - - await transport.handleRequest(req, res, req.body); - } catch (error) { - console.error('Error handling MCP request:', error); - if (!res.headersSent) { - res.status(500).json({ - jsonrpc: '2.0', - error: { code: -32_603, message: 'Internal server error' }, - id: null - }); - } - } -}); - -// Handle GET requests for SSE streams -app.get('/mcp', async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId) { - res.status(400).send('Missing session ID'); - return; - } - if (!transports[sessionId]) { - res.status(404).send('Session not found'); - return; - } - - const transport = transports[sessionId]; - await transport.handleRequest(req, res); -}); - -// Handle DELETE requests for session termination -app.delete('/mcp', async (req: Request, res: Response) => { - const sessionId = req.headers['mcp-session-id'] as string | undefined; - if (!sessionId) { - res.status(400).send('Missing session ID'); - return; - } - if (!transports[sessionId]) { - res.status(404).send('Session not found'); - return; - } - - console.log(`Session termination request: ${sessionId}`); - const transport = transports[sessionId]; - await transport.handleRequest(req, res); -}); - -// Start server -app.listen(PORT, () => { - console.log(`Starting server on http://localhost:${PORT}/mcp`); - console.log('\nAvailable tools:'); - console.log(' - confirm_delete: Demonstrates elicitation (asks user y/n)'); - console.log(' - write_haiku: Demonstrates sampling (requests LLM completion)'); -}); - -// Handle shutdown -process.on('SIGINT', async () => { - console.log('\nShutting down server...'); - for (const sessionId of Object.keys(transports)) { - try { - await transports[sessionId]!.close(); - delete transports[sessionId]; - } catch (error) { - console.error(`Error closing session ${sessionId}:`, error); - } - } - taskStore.cleanup(); - messageQueue.cleanup(); - console.log('Server shutdown complete'); - process.exit(0); -}); diff --git a/packages/client/src/client/authExtensions.examples.ts b/packages/client/src/client/authExtensions.examples.ts index bcb26a3d41..5c58661453 100644 --- a/packages/client/src/client/authExtensions.examples.ts +++ b/packages/client/src/client/authExtensions.examples.ts @@ -8,7 +8,7 @@ */ import { ClientCredentialsProvider, createPrivateKeyJwtAuth, PrivateKeyJwtProvider } from './authExtensions.js'; -import { StreamableHTTPClientTransport } from './streamableHttp.js'; +import { StreamableHTTPClientTransport } from './modernStreamableHttp.js'; /** * Example: Creating a private key JWT authentication function. diff --git a/packages/client/src/client/client.examples.ts b/packages/client/src/client/client.examples.ts index b08694cfbd..c7abdc7863 100644 --- a/packages/client/src/client/client.examples.ts +++ b/packages/client/src/client/client.examples.ts @@ -10,9 +10,9 @@ import type { Prompt, Resource, Tool } from '@modelcontextprotocol/core'; import { Client } from './client.js'; +import { StdioClientTransport } from './modernStdio.js'; +import { StreamableHTTPClientTransport } from './modernStreamableHttp.js'; import { SSEClientTransport } from './sse.js'; -import { StdioClientTransport } from './stdio.js'; -import { StreamableHTTPClientTransport } from './streamableHttp.js'; /** * Example: Using listChanged to automatically track tool and prompt updates. diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 5fa2e14d94..925d74da6a 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -21,34 +21,37 @@ import type { ListToolsRequest, LoggingLevel, MessageExtraInfo, + Notification, NotificationMethod, + NotificationOptions, + NotificationTypeMap, ProtocolOptions, ReadResourceRequest, RequestMethod, RequestOptions, + RequestTypeMap, Result, + ResultTypeMap, ServerCapabilities, + StandardSchemaV1, SubscribeRequest, - TaskManagerOptions, Tool, Transport, UnsubscribeRequest } from '@modelcontextprotocol/core'; import { - assertClientRequestTaskCapability, - assertToolsCallTaskCapability, CallToolResultSchema, CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, - CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, EmptyResultSchema, - extractTaskManagerOptions, GetPromptResultSchema, + HandlerRegistry, InitializeResultSchema, + isStandardSchema, LATEST_PROTOCOL_VERSION, ListChangedOptionsBaseSchema, ListPromptsResultSchema, @@ -65,7 +68,8 @@ import { SdkErrorCode } from '@modelcontextprotocol/core'; -import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import { ModernClientImpl } from './modernClientImpl.js'; +import { isVersionProbingTransport } from './versionProbing.js'; /** * Elicitation default application helper. Applies defaults to the `data` based on the `schema`. @@ -141,19 +145,152 @@ export function getSupportedElicitationModes(capabilities: ClientCapabilities['e return { supportsFormMode, supportsUrlMode }; } +// --------------------------------------------------------------------------- +// Standalone functions for HandlerRegistry callbacks (mirrors server.ts pattern) +// --------------------------------------------------------------------------- + +function assertClientHandlerCapability(method: string, capabilities: ClientCapabilities): void { + switch (method) { + case 'sampling/createMessage': { + if (!capabilities.sampling) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Client does not support sampling capability (required for ${method})` + ); + } + break; + } + + case 'elicitation/create': { + if (!capabilities.elicitation) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Client does not support elicitation capability (required for ${method})` + ); + } + break; + } + + case 'roots/list': { + if (!capabilities.roots) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Client does not support roots capability (required for ${method})` + ); + } + break; + } + + case 'ping': { + break; + } + } +} + +function clientWrapHandler( + method: string, + handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise, + getCapabilities: () => ClientCapabilities +): (request: JSONRPCRequest, ctx: ClientContext) => Promise { + if (method === 'elicitation/create') { + return async (request, ctx) => { + const validatedRequest = parseSchema(ElicitRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); + } + + const { params } = validatedRequest.data; + params.mode = params.mode ?? 'form'; + const capabilities = getCapabilities(); + const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(capabilities.elicitation); + + if (params.mode === 'form' && !supportsFormMode) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); + } + + if (params.mode === 'url' && !supportsUrlMode) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); + } + + const result = await handler(request, ctx); + + const validationResult = parseSchema(ElicitResultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); + } + + const validatedResult = validationResult.data; + const requestedSchema = params.mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined; + + if ( + params.mode === 'form' && + validatedResult.action === 'accept' && + validatedResult.content && + requestedSchema && + capabilities.elicitation?.form?.applyDefaults + ) { + try { + applyElicitationDefaults(requestedSchema, validatedResult.content); + } catch { + // gracefully ignore errors in default application + } + } + + return validatedResult; + }; + } + + if (method === 'sampling/createMessage') { + return async (request, ctx) => { + const validatedRequest = parseSchema(CreateMessageRequestSchema, request); + if (!validatedRequest.success) { + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + } + + const { params } = validatedRequest.data; + + const result = await handler(request, ctx); + + const hasTools = params.tools || params.toolChoice; + const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; + const validationResult = parseSchema(resultSchema, result); + if (!validationResult.success) { + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + } + + return validationResult.data; + }; + } + + return handler; +} + /** - * Extended tasks capability that includes runtime configuration (store, messageQueue). - * The runtime-only fields are stripped before advertising capabilities to servers. + * Creates a client HandlerRegistry with client-specific callbacks. + * @internal */ -export type ClientTasksCapabilityWithRuntime = NonNullable & TaskManagerOptions; +export function createClientRegistry(capabilities?: ClientCapabilities): HandlerRegistry { + const registry: HandlerRegistry = new HandlerRegistry({ + capabilities, + assertRequestHandlerCapability: method => assertClientHandlerCapability(method, registry.getCapabilities()), + wrapHandler: (method, handler) => clientWrapHandler(method, handler, () => registry.getCapabilities()) + }); + return registry; +} export type ClientOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this client. */ - capabilities?: Omit & { - tasks?: ClientTasksCapabilityWithRuntime; - }; + capabilities?: ClientCapabilities; /** * JSON Schema validator for tool output validation. @@ -192,37 +329,18 @@ export type ClientOptions = ProtocolOptions & { * ``` */ listChanged?: ListChangedHandlers; + + /** @internal */ + registry?: HandlerRegistry; }; /** - * An MCP client on top of a pluggable transport. - * - * The client will automatically begin the initialization flow with the server when {@linkcode connect} is called. - * - * To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}. - * The client must declare the corresponding capability for the handler to be accepted. For - * `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with - * schema validation for both the incoming request and the returned result. + * The Protocol-based MCP client implementation. Handles JSON-RPC dispatch, + * request/response correlation, and bidirectional session management. * - * @example Handling a sampling request - * ```ts source="./client.examples.ts#Client_setRequestHandler_sampling" - * client.setRequestHandler('sampling/createMessage', async request => { - * const lastMessage = request.params.messages.at(-1); - * console.log('Sampling request:', lastMessage); - * - * // In production, send messages to your LLM here - * return { - * model: 'my-model', - * role: 'assistant' as const, - * content: { - * type: 'text' as const, - * text: 'Response from the model' - * } - * }; - * }); - * ``` + * Used internally by {@linkcode Client} for transport connections. */ -export class Client extends Protocol { +export class LegacyClient extends Protocol { private _serverCapabilities?: ServerCapabilities; private _serverVersion?: Implementation; private _negotiatedProtocolVersion?: string; @@ -230,9 +348,6 @@ export class Client extends Protocol { private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; private _cachedToolOutputValidators: Map> = new Map(); - private _cachedKnownTaskTools: Set = new Set(); - private _cachedRequiredTaskTools: Set = new Set(); - private _experimental?: { tasks: ExperimentalClientTasks }; private _listChangedDebounceTimers: Map> = new Map(); private _pendingListChangedConfig?: ListChangedHandlers; private _enforceStrictCapabilities: boolean; @@ -244,22 +359,16 @@ export class Client extends Protocol { private _clientInfo: Implementation, options?: ClientOptions ) { - super({ - ...options, - tasks: extractTaskManagerOptions(options?.capabilities?.tasks) - }); + const registry = options?.registry ?? createClientRegistry(options?.capabilities); + super(registry, options); + if (!options?.registry) { + registry.assertRequestHandlerCapability = method => this._assertClientRequestHandlerCapability(method); + registry.wrapHandler = (method, handler) => this._clientWrapHandler(method, handler); + } this._capabilities = options?.capabilities ? { ...options.capabilities } : {}; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); this._enforceStrictCapabilities = options?.enforceStrictCapabilities ?? false; - // Strip runtime-only fields from advertised capabilities - if (options?.capabilities?.tasks) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize, ...wireCapabilities } = - options.capabilities.tasks; - this._capabilities.tasks = wireCapabilities; - } - // Store list changed config for setup after connection (when we know server capabilities) if (options?.listChanged) { this._pendingListChangedConfig = options.listChanged; @@ -299,22 +408,6 @@ export class Client extends Protocol { } } - /** - * Access experimental features. - * - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - get experimental(): { tasks: ExperimentalClientTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalClientTasks(this) - }; - } - return this._experimental; - } - /** * Registers new capabilities. This can only be called before connecting to a transport. * @@ -326,13 +419,14 @@ export class Client extends Protocol { } this._capabilities = mergeCapabilities(this._capabilities, capabilities); + this._registry.registerCapabilities(capabilities); } /** * Enforces client-side validation for `elicitation/create` and `sampling/createMessage` * regardless of how the handler was registered. */ - protected override _wrapHandler( + private _clientWrapHandler( method: string, handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise ): (request: JSONRPCRequest, ctx: ClientContext) => Promise { @@ -360,20 +454,6 @@ export class Client extends Protocol { const result = await handler(request, ctx); - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } - - // For non-task requests, validate against ElicitResultSchema const validationResult = parseSchema(ElicitResultSchema, result); if (!validationResult.success) { // Type guard: if success is false, error is guaranteed to exist @@ -416,20 +496,7 @@ export class Client extends Protocol { const result = await handler(request, ctx); - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } - - // For non-task requests, validate against appropriate schema based on tools presence + // Validate against appropriate schema based on tools presence const hasTools = params.tools || params.toolChoice; const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; const validationResult = parseSchema(resultSchema, result); @@ -662,7 +729,7 @@ export class Client extends Protocol { } } - protected assertRequestHandlerCapability(method: string): void { + private _assertClientRequestHandlerCapability(method: string): void { switch (method) { case 'sampling/createMessage': { if (!this._capabilities.sampling) { @@ -701,14 +768,6 @@ export class Client extends Protocol { } } - protected assertTaskCapability(method: string): void { - assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server'); - } - - protected assertTaskHandlerCapability(method: string): void { - assertClientRequestTaskCapability(this._capabilities?.tasks?.requests, method, 'Client'); - } - async ping(options?: RequestOptions) { return this._requestWithSchema({ method: 'ping' }, EmptyResultSchema, options); } @@ -828,8 +887,6 @@ export class Client extends Protocol { * a problem), and thrown {@linkcode ProtocolError} for protocol-level failures or {@linkcode SdkError} for * SDK-level issues (timeouts, missing capabilities). * - * For task-based execution with streaming behavior, use {@linkcode ExperimentalClientTasks.callToolStream | client.experimental.tasks.callToolStream()} instead. - * * @example Basic usage * ```ts source="./client.examples.ts#Client_callTool_basic" * const result = await client.callTool({ @@ -860,14 +917,6 @@ export class Client extends Protocol { * ``` */ async callTool(params: CallToolRequest['params'], options?: RequestOptions) { - // Guard: required-task tools need experimental API - if (this.isToolTaskRequired(params.name)) { - throw new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Tool "${params.name}" requires task-based execution. Use client.experimental.tasks.callToolStream() instead.` - ); - } - const result = await this._requestWithSchema({ method: 'tools/call', params }, CallToolResultSchema, options); // Check if the tool has an outputSchema @@ -908,30 +957,12 @@ export class Client extends Protocol { return result; } - private isToolTask(toolName: string): boolean { - if (!this._serverCapabilities?.tasks?.requests?.tools?.call) { - return false; - } - - return this._cachedKnownTaskTools.has(toolName); - } - - /** - * Check if a tool requires task-based execution. - * Unlike {@linkcode isToolTask} which includes `'optional'` tools, this only checks for `'required'`. - */ - private isToolTaskRequired(toolName: string): boolean { - return this._cachedRequiredTaskTools.has(toolName); - } - /** * Cache validators for tool output schemas. * Called after {@linkcode listTools | listTools()} to pre-compile validators for better performance. */ private cacheToolMetadata(tools: Tool[]): void { this._cachedToolOutputValidators.clear(); - this._cachedKnownTaskTools.clear(); - this._cachedRequiredTaskTools.clear(); for (const tool of tools) { // If the tool has an outputSchema, create and cache the validator @@ -939,15 +970,6 @@ export class Client extends Protocol { const toolValidator = this._jsonSchemaValidator.getValidator(tool.outputSchema as JsonSchemaType); this._cachedToolOutputValidators.set(tool.name, toolValidator); } - - // If the tool supports task-based execution, cache that information - const taskSupport = tool.execution?.taskSupport; - if (taskSupport === 'required' || taskSupport === 'optional') { - this._cachedKnownTaskTools.add(tool.name); - } - if (taskSupport === 'required') { - this._cachedRequiredTaskTools.add(tool.name); - } } } @@ -1058,3 +1080,343 @@ export class Client extends Protocol { return this.notification({ method: 'notifications/roots/list_changed' }); } } + +/** + * An MCP client on top of a pluggable transport. + * + * The client will automatically begin the initialization flow with the server when {@linkcode connect} is called. + * + * To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}. + * The client must declare the corresponding capability for the handler to be accepted. For + * `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with + * schema validation for both the incoming request and the returned result. + * + * Owns a {@linkcode LegacyClient} internally for protocol communication. + * For now always creates the legacy implementation; a modern client + * implementation will be added in a future phase. + * + * @example Handling a sampling request + * ```ts source="./client.examples.ts#Client_setRequestHandler_sampling" + * client.setRequestHandler('sampling/createMessage', async request => { + * const lastMessage = request.params.messages.at(-1); + * console.log('Sampling request:', lastMessage); + * + * // In production, send messages to your LLM here + * return { + * model: 'my-model', + * role: 'assistant' as const, + * content: { + * type: 'text' as const, + * text: 'Response from the model' + * } + * }; + * }); + * ``` + */ +export class Client { + private _registry: HandlerRegistry; + private _legacyImpl: LegacyClient; + private _modernImpl?: ModernClientImpl; + + get onclose() { + return this._legacyImpl.onclose; + } + set onclose(h) { + this._legacyImpl.onclose = h; + if (this._modernImpl) { + this._modernImpl.onclose = h; + } + } + + get onerror() { + return this._legacyImpl.onerror; + } + set onerror(h) { + this._legacyImpl.onerror = h; + if (this._modernImpl) { + this._modernImpl.onerror = h; + } + } + + get fallbackRequestHandler() { + return this._registry.fallbackRequestHandler; + } + set fallbackRequestHandler(h) { + this._registry.fallbackRequestHandler = h; + } + + get fallbackNotificationHandler() { + return this._registry.fallbackNotificationHandler; + } + set fallbackNotificationHandler(h) { + this._registry.fallbackNotificationHandler = h; + } + + /** + * Initializes this client with the given name and version information. + */ + constructor( + private _clientInfo: Implementation, + private _options?: ClientOptions + ) { + this._registry = createClientRegistry(_options?.capabilities); + this._legacyImpl = new LegacyClient(_clientInfo, { + ..._options, + registry: this._registry + }); + } + + /** + * Connects to a server via the given transport. + * + * If the transport implements version probing and detected modern (2026-06) + * protocol support, a {@linkcode ModernClientImpl} is used instead of the + * legacy Protocol-based implementation. Otherwise, the legacy path (with + * full initialize handshake) is used. + */ + async connect(transport: Transport, options?: RequestOptions): Promise { + if (isVersionProbingTransport(transport)) { + await transport.start(); + if (transport.mode === 'modern') { + const modern = new ModernClientImpl( + this._clientInfo, + this._registry.getCapabilities(), + transport.getDiscoverResult()!, + this._registry + ); + modern.onclose = this._legacyImpl.onclose; + modern.onerror = this._legacyImpl.onerror; + await modern.connect(transport); + this._modernImpl = modern; + return; + } + } + return this._legacyImpl.connect(transport, options); + } + + async close(): Promise { + if (this._modernImpl) { + return this._modernImpl.close(); + } + return this._legacyImpl.close(); + } + + get transport(): Transport | undefined { + if (this._modernImpl) { + return this._modernImpl.transport; + } + return this._legacyImpl.transport; + } + + // --------------------------------------------------------------------------- + // Handler registration — delegates to shared registry + // --------------------------------------------------------------------------- + + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise + ): void; + setRequestHandler

( + method: string, + schemas: { params: P; result?: R }, + handler: ( + params: StandardSchemaV1.InferOutput

, + ctx: ClientContext + ) => + | (R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result) + | Promise : Result> + ): void; + setRequestHandler(method: string, ...args: unknown[]): void { + (this._registry.setRequestHandler as (...a: unknown[]) => void).call(this._registry, method, ...args); + } + + setNotificationHandler( + method: M, + handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void; + setNotificationHandler

( + method: string, + schemas: { params: P }, + handler: (params: StandardSchemaV1.InferOutput

, notification: Notification) => void | Promise + ): void; + setNotificationHandler(method: string, ...args: unknown[]): void { + (this._registry.setNotificationHandler as (...a: unknown[]) => void).call(this._registry, method, ...args); + } + + removeRequestHandler(method: RequestMethod | string): void { + this._registry.removeRequestHandler(method); + } + + removeNotificationHandler(method: NotificationMethod | string): void { + this._registry.removeNotificationHandler(method); + } + + assertCanSetRequestHandler(method: RequestMethod | string): void { + this._registry.assertCanSetRequestHandler(method); + } + + // --------------------------------------------------------------------------- + // Capability and state accessors + // --------------------------------------------------------------------------- + + registerCapabilities(capabilities: ClientCapabilities): void { + this._legacyImpl.registerCapabilities(capabilities); + } + + getServerCapabilities(): ServerCapabilities | undefined { + if (this._modernImpl) { + return this._modernImpl.getServerCapabilities(); + } + return this._legacyImpl.getServerCapabilities(); + } + + getServerVersion(): Implementation | undefined { + if (this._modernImpl) { + return this._modernImpl.getServerVersion(); + } + return this._legacyImpl.getServerVersion(); + } + + getNegotiatedProtocolVersion(): string | undefined { + if (this._modernImpl) { + return '2026-06-30'; + } + return this._legacyImpl.getNegotiatedProtocolVersion(); + } + + getInstructions(): string | undefined { + if (this._modernImpl) { + return this._modernImpl.getInstructions(); + } + return this._legacyImpl.getInstructions(); + } + + // --------------------------------------------------------------------------- + // High-level request methods — delegate to modern or legacy impl + // --------------------------------------------------------------------------- + + async ping(options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.ping(options); + } + return this._legacyImpl.ping(options); + } + + async complete(params: CompleteRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.complete(params, options); + } + return this._legacyImpl.complete(params, options); + } + + async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.setLoggingLevel(level, options); + } + return this._legacyImpl.setLoggingLevel(level, options); + } + + async getPrompt(params: GetPromptRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.getPrompt(params, options); + } + return this._legacyImpl.getPrompt(params, options); + } + + async listPrompts(params?: ListPromptsRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.listPrompts(params, options); + } + return this._legacyImpl.listPrompts(params, options); + } + + async listResources(params?: ListResourcesRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.listResources(params, options); + } + return this._legacyImpl.listResources(params, options); + } + + async listResourceTemplates(params?: ListResourceTemplatesRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.listResourceTemplates(params, options); + } + return this._legacyImpl.listResourceTemplates(params, options); + } + + async readResource(params: ReadResourceRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.readResource(params, options); + } + return this._legacyImpl.readResource(params, options); + } + + async subscribeResource(params: SubscribeRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.subscribeResource(params, options); + } + return this._legacyImpl.subscribeResource(params, options); + } + + async unsubscribeResource(params: UnsubscribeRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.unsubscribeResource(params, options); + } + return this._legacyImpl.unsubscribeResource(params, options); + } + + async callTool(params: CallToolRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.callTool(params, options); + } + return this._legacyImpl.callTool(params, options); + } + + async listTools(params?: ListToolsRequest['params'], options?: RequestOptions) { + if (this._modernImpl) { + return this._modernImpl.listTools(params, options); + } + return this._legacyImpl.listTools(params, options); + } + + async sendRootsListChanged() { + if (this._modernImpl) { + throw new SdkError( + SdkErrorCode.UnsupportedOperation, + 'Client-to-server notifications are not supported on the modern (2026-06) protocol path' + ); + } + return this._legacyImpl.sendRootsListChanged(); + } + + // --------------------------------------------------------------------------- + // Low-level protocol methods — for backward compat (tests call directly) + // --------------------------------------------------------------------------- + + request( + request: { method: M; params?: Record }, + options?: RequestOptions + ): Promise; + request( + request: { method: string; params?: Record }, + resultSchema: T, + options?: RequestOptions + ): Promise>; + request(request: { method: string; params?: Record }, ...args: unknown[]): Promise { + if (this._modernImpl) { + const opts = isStandardSchema(args[0]) ? (args[1] as RequestOptions | undefined) : (args[0] as RequestOptions | undefined); + return this._modernImpl.request(request, opts); + } + return (this._legacyImpl.request as (...a: unknown[]) => Promise).call(this._legacyImpl, request, ...args); + } + + async notification(notification: Notification, options?: NotificationOptions): Promise { + if (this._modernImpl) { + throw new SdkError( + SdkErrorCode.UnsupportedOperation, + 'Client-to-server notifications are not supported on the modern (2026-06) protocol path' + ); + } + return this._legacyImpl.notification(notification, options); + } +} diff --git a/packages/client/src/client/modernClientImpl.ts b/packages/client/src/client/modernClientImpl.ts new file mode 100644 index 0000000000..18bc15795d --- /dev/null +++ b/packages/client/src/client/modernClientImpl.ts @@ -0,0 +1,420 @@ +import type { + CallToolRequest, + CallToolResult, + ClientCapabilities, + ClientContext, + CompleteRequest, + CompleteResult, + GetPromptRequest, + GetPromptResult, + HandlerRegistry, + Implementation, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListResourceTemplatesRequest, + ListResourceTemplatesResult, + ListToolsRequest, + ListToolsResult, + LoggingLevel, + ReadResourceRequest, + ReadResourceResult, + RequestOptions, + Result, + ServerCapabilities, + SubscribeRequest, + Transport, + UnsubscribeRequest +} from '@modelcontextprotocol/core'; +import { + DEFAULT_REQUEST_TIMEOUT_MSEC, + isJSONRPCErrorResponse, + isJSONRPCNotification, + isJSONRPCRequest, + isJSONRPCResultResponse, + ProtocolError, + SdkError, + SdkErrorCode +} from '@modelcontextprotocol/core'; + +/** + * The result returned by the `server/discover` endpoint on modern (2026-06) servers. + */ +export interface DiscoverResult { + supportedVersions: string[]; + capabilities: ServerCapabilities; + serverInfo: Implementation; + instructions?: string; +} + +/** + * Pending request entry in the correlator map. + */ +interface PendingRequest { + resolve: (result: unknown) => void; + reject: (error: Error) => void; + timer: ReturnType; +} + +/** + * A lightweight MCP client for the modern (2026-06) protocol. + * + * Unlike `LegacyClient`, this class does NOT extend Protocol. + * It manages its own request/response correlation, injects `_meta` with protocol + * version and client info into every request, and delegates HTTP-level concerns + * (like the `Mcp-Method` header) to the transport layer. + * + * Server state (capabilities, version, instructions) is populated from a + * {@linkcode DiscoverResult} passed to the constructor rather than from an + * initialize handshake. + */ +export class ModernClientImpl { + private _transport?: Transport; + private _nextId = 0; + private _pending: Map = new Map(); + private _clientInfo: Implementation; + private _clientCapabilities: ClientCapabilities; + private _serverCapabilities: ServerCapabilities; + private _serverVersion: Implementation; + private _instructions?: string; + private _registry: HandlerRegistry; + + /** + * Callback for when the connection is closed. + */ + onclose?: () => void; + + /** + * Callback for when an error occurs. + */ + onerror?: (error: Error) => void; + + constructor( + clientInfo: Implementation, + clientCapabilities: ClientCapabilities, + discoverResult: DiscoverResult, + registry: HandlerRegistry + ) { + this._clientInfo = clientInfo; + this._clientCapabilities = clientCapabilities; + this._serverCapabilities = discoverResult.capabilities; + this._serverVersion = discoverResult.serverInfo; + this._instructions = discoverResult.instructions; + this._registry = registry; + } + + /** + * Connects to a transport. Wires `transport.onmessage` to dispatch + * responses, notifications, and server-to-client requests. + * + * Unlike the legacy path, no initialize handshake is performed -- + * server state was already obtained via `server/discover`. + */ + async connect(transport: Transport): Promise { + this._transport = transport; + + transport.onmessage = (message: JSONRPCMessage) => { + if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { + this._onResponse(message); + } else if (isJSONRPCNotification(message)) { + this._onNotification(message); + } else if (isJSONRPCRequest(message)) { + this._onRequest(message); + } + }; + + transport.onclose = () => { + this._rejectAll(new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed')); + this.onclose?.(); + }; + + transport.onerror = (error: Error) => { + this.onerror?.(error); + }; + + // Transport is already started by StreamableHTTPClientTransport.start() + } + + /** + * Closes the connection and rejects all pending requests. + */ + async close(): Promise { + this._rejectAll(new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed')); + await this._transport?.close(); + this._transport = undefined; + } + + get transport(): Transport | undefined { + return this._transport; + } + + // --------------------------------------------------------------------------- + // Server state accessors + // --------------------------------------------------------------------------- + + getServerCapabilities(): ServerCapabilities { + return this._serverCapabilities; + } + + getServerVersion(): Implementation { + return this._serverVersion; + } + + getInstructions(): string | undefined { + return this._instructions; + } + + // --------------------------------------------------------------------------- + // High-level request methods + // --------------------------------------------------------------------------- + + async ping(options?: RequestOptions): Promise { + return this._request('ping', undefined, options); + } + + async complete(params: CompleteRequest['params'], options?: RequestOptions): Promise { + this._assertCapability('completions', 'completion/complete'); + return this._request('completion/complete', params, options); + } + + async setLoggingLevel(level: LoggingLevel, options?: RequestOptions): Promise { + this._assertCapability('logging', 'logging/setLevel'); + return this._request('logging/setLevel', { level }, options); + } + + async getPrompt(params: GetPromptRequest['params'], options?: RequestOptions): Promise { + this._assertCapability('prompts', 'prompts/get'); + return this._request('prompts/get', params, options); + } + + async listPrompts(params?: ListPromptsRequest['params'], options?: RequestOptions): Promise { + if (!this._serverCapabilities.prompts) { + return { prompts: [] }; + } + return this._request('prompts/list', params, options); + } + + async listResources(params?: ListResourcesRequest['params'], options?: RequestOptions): Promise { + if (!this._serverCapabilities.resources) { + return { resources: [] }; + } + return this._request('resources/list', params, options); + } + + async listResourceTemplates( + params?: ListResourceTemplatesRequest['params'], + options?: RequestOptions + ): Promise { + if (!this._serverCapabilities.resources) { + return { resourceTemplates: [] }; + } + return this._request('resources/templates/list', params, options); + } + + async readResource(params: ReadResourceRequest['params'], options?: RequestOptions): Promise { + this._assertCapability('resources', 'resources/read'); + return this._request('resources/read', params, options); + } + + async subscribeResource(params: SubscribeRequest['params'], options?: RequestOptions): Promise { + this._assertCapability('resources', 'resources/subscribe'); + if (!this._serverCapabilities.resources?.subscribe) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + 'Server does not support resource subscriptions (required for resources/subscribe)' + ); + } + return this._request('resources/subscribe', params, options); + } + + async unsubscribeResource(params: UnsubscribeRequest['params'], options?: RequestOptions): Promise { + this._assertCapability('resources', 'resources/unsubscribe'); + return this._request('resources/unsubscribe', params, options); + } + + async callTool(params: CallToolRequest['params'], options?: RequestOptions): Promise { + this._assertCapability('tools', 'tools/call'); + return this._request('tools/call', params, options); + } + + async listTools(params?: ListToolsRequest['params'], options?: RequestOptions): Promise { + if (!this._serverCapabilities.tools) { + return { tools: [] }; + } + return this._request('tools/list', params, options); + } + + async request(request: { method: string; params?: Record }, options?: RequestOptions): Promise { + return this._request(request.method, request.params, options); + } + + // --------------------------------------------------------------------------- + // Internal + // --------------------------------------------------------------------------- + + private _assertCapability(capability: keyof ServerCapabilities, method: string): void { + if (!this._serverCapabilities[capability]) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support ${capability} (required for ${method})`); + } + } + + /** + * Sends a JSON-RPC request with `_meta` injection containing protocol version, + * client capabilities, and client info. Returns a promise that resolves when + * the server responds. + */ + private _request(method: string, params?: Record, options?: RequestOptions): Promise { + const id = this._nextId++; + const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; + + const message: JSONRPCRequest = { + jsonrpc: '2.0', + id, + method, + params: { + ...params, + _meta: { + ...(params?._meta as Record | undefined), + protocolVersion: '2026-06-30', + clientCapabilities: this._clientCapabilities, + clientInfo: this._clientInfo + } + } + }; + + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + this._pending.delete(id); + reject(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })); + }, timeout); + + this._pending.set(id, { + resolve: result => { + clearTimeout(timer); + resolve(result as T); + }, + reject: error => { + clearTimeout(timer); + reject(error); + }, + timer + }); + + this._transport!.send(message).catch(sendError => { + clearTimeout(timer); + this._pending.delete(id); + reject(sendError); + }); + }); + } + + /** + * Dispatches a JSON-RPC response to the matching pending request. + */ + private _onResponse(response: JSONRPCMessage): void { + const id = Number((response as { id?: unknown }).id); + const pending = this._pending.get(id); + if (!pending) { + this.onerror?.(new Error(`Received response for unknown request ID: ${id}`)); + return; + } + this._pending.delete(id); + + if (isJSONRPCResultResponse(response)) { + pending.resolve(response.result); + } else if (isJSONRPCErrorResponse(response)) { + pending.reject(ProtocolError.fromError(response.error.code, response.error.message, response.error.data)); + } + } + + /** + * Dispatches a server-to-client notification to a registered handler. + */ + private _onNotification(notification: JSONRPCNotification): void { + const handler = this._registry.notificationHandlers.get(notification.method) ?? this._registry.fallbackNotificationHandler; + if (handler) { + Promise.resolve() + .then(() => handler(notification)) + .catch(error => this.onerror?.(new Error(`Uncaught error in notification handler: ${error}`))); + } + } + + /** + * Dispatches a server-to-client request to a registered handler. + */ + private _onRequest(request: JSONRPCRequest): void { + const handler = this._registry.requestHandlers.get(request.method) ?? this._registry.fallbackRequestHandler; + if (!handler) { + this._transport + ?.send({ + jsonrpc: '2.0', + id: request.id, + error: { code: -32_601, message: 'Method not found' } + }) + .catch(error => this.onerror?.(new Error(`Failed to send error response: ${error}`))); + return; + } + + const abortController = new AbortController(); + const ctx: ClientContext = { + mcpReq: { + id: request.id, + method: request.method, + _meta: request.params?._meta as ClientContext['mcpReq']['_meta'], + signal: abortController.signal, + send: (() => { + throw new SdkError( + SdkErrorCode.UnsupportedOperation, + 'Bidirectional requests not supported on the modern (2026-06) client path' + ); + }) as ClientContext['mcpReq']['send'], + notify: () => { + throw new SdkError( + SdkErrorCode.UnsupportedOperation, + 'Bidirectional notifications not supported on the modern (2026-06) client path' + ); + } + } + }; + + Promise.resolve() + .then(() => handler(request, ctx)) + .then(result => { + this._transport + ?.send({ + jsonrpc: '2.0', + id: request.id, + result + }) + .catch(error => this.onerror?.(new Error(`Failed to send response: ${error}`))); + }) + .catch(error => { + const errorRecord = error as Record; + this._transport + ?.send({ + jsonrpc: '2.0', + id: request.id, + error: { + code: Number.isSafeInteger(errorRecord['code']) ? (errorRecord['code'] as number) : -32_603, + message: (error as Error).message ?? 'Internal error' + } + }) + .catch(error_ => this.onerror?.(new Error(`Failed to send error response: ${error_}`))); + }); + } + + /** + * Rejects all pending requests with the given error. + */ + private _rejectAll(error: Error): void { + const pending = this._pending; + this._pending = new Map(); + for (const entry of pending.values()) { + entry.reject(error); + } + } +} diff --git a/packages/client/src/client/modernStdio.ts b/packages/client/src/client/modernStdio.ts new file mode 100644 index 0000000000..1dcc8f25ad --- /dev/null +++ b/packages/client/src/client/modernStdio.ts @@ -0,0 +1,181 @@ +import type { Stream } from 'node:stream'; + +import type { JSONRPCMessage, MessageExtraInfo, TransportSendOptions } from '@modelcontextprotocol/core'; +import { isJSONRPCResultResponse } from '@modelcontextprotocol/core'; + +import type { DiscoverResult } from './modernClientImpl.js'; +import type { StdioServerParameters } from './stdio.js'; +import { LegacyStdioClientTransport } from './stdio.js'; +import type { VersionProbingTransport } from './versionProbing.js'; + +const DEFAULT_PROBE_TIMEOUT_MS = 5000; + +export type StdioClientTransportOptions = StdioServerParameters & { + /** Skip version probing and always use legacy mode. */ + forceLegacy?: boolean; + /** Timeout for the server/discover probe in milliseconds. Default: 5000. */ + probeTimeoutMs?: number; +}; + +/** + * Dual-protocol stdio client transport with automatic version probing. + * + * During {@linkcode start | start()}, spawns the server process and sends a + * `server/discover` probe. If the server responds with a valid DiscoverResult, + * the transport operates in modern mode. Otherwise, falls back to legacy mode. + */ +export class StdioClientTransport implements VersionProbingTransport { + private _inner: LegacyStdioClientTransport; + private _mode: 'modern' | 'legacy' = 'legacy'; + private _discoverResult?: DiscoverResult; + private _started = false; + private _forceLegacy: boolean; + private _probeTimeoutMs: number; + + private _probeResolve?: (result: DiscoverResult | null) => void; + private _probeTimeout?: ReturnType; + private _pendingMessages: JSONRPCMessage[] = []; + private _probeId: string = crypto.randomUUID(); + + private _onclose?: (() => void) | undefined; + private _onerror?: ((error: Error) => void) | undefined; + private _onmessage?: ((message: T, extra?: MessageExtraInfo) => void) | undefined; + + constructor(options: StdioClientTransportOptions) { + this._forceLegacy = options.forceLegacy ?? false; + this._probeTimeoutMs = options.probeTimeoutMs ?? DEFAULT_PROBE_TIMEOUT_MS; + this._inner = new LegacyStdioClientTransport(options); + } + + async start(): Promise { + if (this._started) { + return; + } + this._started = true; + + await this._inner.start(); + + if (this._forceLegacy) { + return; + } + + try { + const result = await this._probeDiscover(); + if (result) { + this._mode = 'modern'; + this._discoverResult = result; + } + } catch { + // Any failure = legacy mode + } + } + + get mode(): 'modern' | 'legacy' { + return this._mode; + } + + getDiscoverResult(): DiscoverResult | undefined { + return this._discoverResult; + } + + get stderr(): Stream | null { + return this._inner.stderr; + } + + get pid(): number | null { + return this._inner.pid; + } + + // ------------------------------------------------------------------- + // Transport interface delegation + // ------------------------------------------------------------------- + + async send(message: JSONRPCMessage, _options?: TransportSendOptions): Promise { + await this._inner.send(message); + } + + async close(): Promise { + if (this._probeResolve) { + clearTimeout(this._probeTimeout); + this._probeResolve(null); + this._probeResolve = undefined; + } + await this._inner.close(); + } + + set onclose(handler: (() => void) | undefined) { + this._onclose = handler; + this._inner.onclose = handler; + } + get onclose(): (() => void) | undefined { + return this._onclose; + } + + set onerror(handler: ((error: Error) => void) | undefined) { + this._onerror = handler; + this._inner.onerror = handler; + } + get onerror(): ((error: Error) => void) | undefined { + return this._onerror; + } + + set onmessage(handler: ((message: T, extra?: MessageExtraInfo) => void) | undefined) { + this._onmessage = handler; + this._inner.onmessage = handler; + for (const msg of this._pendingMessages) { + handler?.(msg); + } + this._pendingMessages = []; + } + get onmessage(): ((message: T, extra?: MessageExtraInfo) => void) | undefined { + return this._onmessage; + } + + // ------------------------------------------------------------------- + // Probe + // ------------------------------------------------------------------- + + private async _probeDiscover(): Promise { + return new Promise(resolve => { + let resolved = false; + this._probeResolve = resolve; + + const finish = (result: DiscoverResult | null) => { + if (resolved) return; + resolved = true; + clearTimeout(this._probeTimeout); + this._probeResolve = undefined; + this._inner.onmessage = m => this._pendingMessages.push(m); + resolve(result); + }; + + this._probeTimeout = setTimeout(() => finish(null), this._probeTimeoutMs); + + this._inner.onmessage = (msg: JSONRPCMessage) => { + const id = (msg as { id?: unknown }).id; + if (id !== this._probeId) { + this._pendingMessages.push(msg); + return; + } + + if (isJSONRPCResultResponse(msg)) { + const result = msg.result as Record; + if (Array.isArray(result.supportedVersions) && result.capabilities && result.serverInfo) { + finish(result as unknown as DiscoverResult); + return; + } + } + finish(null); + }; + + this._inner + .send({ + jsonrpc: '2.0', + id: this._probeId, + method: 'server/discover', + params: {} + }) + .catch(() => finish(null)); + }); + } +} diff --git a/packages/client/src/client/modernStreamableHttp.ts b/packages/client/src/client/modernStreamableHttp.ts new file mode 100644 index 0000000000..f0eb9bab56 --- /dev/null +++ b/packages/client/src/client/modernStreamableHttp.ts @@ -0,0 +1,213 @@ +import type { JSONRPCMessage, JSONRPCRequest, MessageExtraInfo, TransportSendOptions } from '@modelcontextprotocol/core'; +import { isJSONRPCRequest } from '@modelcontextprotocol/core'; + +import type { DiscoverResult } from './modernClientImpl.js'; +import type { StreamableHTTPClientTransportOptions } from './streamableHttp.js'; +import { LegacyStreamableHTTPClientTransport } from './streamableHttp.js'; +import type { VersionProbingTransport } from './versionProbing.js'; + +/** + * Dual-protocol HTTP client transport with automatic version probing. + * + * During {@linkcode start | start()}, sends a `server/discover` probe to detect whether + * the server supports the modern (2026-06) MCP protocol. If the probe succeeds, the + * transport operates in `modern` mode and automatically adds the `Mcp-Method` header + * to every outgoing request. If the probe fails, falls back to `legacy` mode. + * + * Use {@linkcode getDiscoverResult | getDiscoverResult()} after {@linkcode start | start()} to + * retrieve the server's capabilities when in modern mode. + */ +export class StreamableHTTPClientTransport implements VersionProbingTransport { + private _inner: LegacyStreamableHTTPClientTransport; + private _mode: 'modern' | 'legacy' = 'legacy'; + private _discoverResult?: DiscoverResult; + private _started = false; + private _forceLegacy: boolean; + + private _onclose?: (() => void) | undefined; + private _onerror?: ((error: Error) => void) | undefined; + private _onmessage?: ((message: T, extra?: MessageExtraInfo) => void) | undefined; + + constructor(url: URL, options?: StreamableHTTPClientTransportOptions) { + this._forceLegacy = options?.forceLegacy ?? false; + this._inner = new LegacyStreamableHTTPClientTransport(url, { + ...options, + getExtraHeaders: (message: JSONRPCMessage | JSONRPCMessage[]) => { + // Merge user-provided extra headers first + const userExtras = options?.getExtraHeaders?.(message) ?? {}; + + if (this._mode === 'modern' && !Array.isArray(message) && isJSONRPCRequest(message)) { + return { + ...userExtras, + 'mcp-method': (message as JSONRPCRequest).method + }; + } + return userExtras; + } + }); + } + + /** + * Starts the inner transport, then probes `server/discover` to detect + * whether the server supports the modern protocol. + */ + async start(): Promise { + if (this._started) { + return; + } + this._started = true; + + await this._inner.start(); + + if (this._forceLegacy) { + return; + } + + try { + const result = await this._probeFetch(); + if (result) { + this._mode = 'modern'; + this._discoverResult = result; + } + } catch { + // Any failure means legacy mode -- no action needed + } + } + + /** + * Sends a raw `server/discover` request to the server endpoint to probe + * for modern protocol support. + * + * This bypasses the transport's `send()` to avoid triggering `onmessage` + * callbacks before the client is fully wired. + */ + private async _probeFetch(): Promise { + const headers = await this._inner.commonHeaders(); + headers.set('content-type', 'application/json'); + headers.set('accept', 'application/json'); + headers.set('mcp-method', 'server/discover'); + + const body: JSONRPCRequest = { + jsonrpc: '2.0', + id: 0, + method: 'server/discover', + params: {} + }; + + const response = await this._inner.fetchFn(this._inner.url, { + ...this._inner.requestInit, + method: 'POST', + headers, + body: JSON.stringify(body) + }); + + if (!response.ok) { + return null; + } + + const contentType = response.headers.get('content-type'); + if (!contentType?.includes('application/json')) { + await response.text?.().catch(() => {}); + return null; + } + + const data = (await response.json()) as Record; + + // The response should be a JSON-RPC result containing the discover info + if (data?.jsonrpc === '2.0' && data?.result) { + const result = data.result as Record; + if (Array.isArray(result.supportedVersions) && result.capabilities && result.serverInfo) { + return result as unknown as DiscoverResult; + } + } + + return null; + } + + /** + * Returns the discover result if the server supports the modern protocol, + * or `undefined` if the server is legacy. + */ + getDiscoverResult(): DiscoverResult | undefined { + return this._discoverResult; + } + + /** + * Whether the transport is operating in modern or legacy mode. + */ + get mode(): 'modern' | 'legacy' { + return this._mode; + } + + // --------------------------------------------------------------------------- + // Transport interface delegation + // --------------------------------------------------------------------------- + + async send(message: JSONRPCMessage, options?: TransportSendOptions): Promise { + await this._inner.send(message, options); + } + + async close(): Promise { + await this._inner.close(); + } + + get sessionId(): string | undefined { + if (this._mode === 'modern') { + return undefined; + } + return this._inner.sessionId; + } + + set onclose(handler: (() => void) | undefined) { + this._onclose = handler; + this._inner.onclose = handler; + } + + get onclose(): (() => void) | undefined { + return this._onclose; + } + + set onerror(handler: ((error: Error) => void) | undefined) { + this._onerror = handler; + this._inner.onerror = handler; + } + + get onerror(): ((error: Error) => void) | undefined { + return this._onerror; + } + + set onmessage(handler: ((message: T, extra?: MessageExtraInfo) => void) | undefined) { + this._onmessage = handler; + this._inner.onmessage = handler; + } + + get onmessage(): ((message: T, extra?: MessageExtraInfo) => void) | undefined { + return this._onmessage; + } + + setProtocolVersion(version: string): void { + this._inner.setProtocolVersion(version); + } + + get protocolVersion(): string | undefined { + return this._inner.protocolVersion; + } + + async finishAuth(authorizationCode: string): Promise { + return this._inner.finishAuth(authorizationCode); + } + + async terminateSession(): Promise { + if (this._mode === 'modern') { + throw new Error('terminateSession() is not available in modern protocol mode (no session to terminate)'); + } + return this._inner.terminateSession(); + } + + async resumeStream(lastEventId: string, options?: { onresumptiontoken?: (token: string) => void }): Promise { + if (this._mode === 'modern') { + throw new Error('resumeStream() is not available in modern protocol mode (no SSE stream to resume)'); + } + return this._inner.resumeStream(lastEventId, options); + } +} diff --git a/packages/client/src/client/stdio.ts b/packages/client/src/client/stdio.ts index 5dcb8ef9a6..b03ed8412b 100644 --- a/packages/client/src/client/stdio.ts +++ b/packages/client/src/client/stdio.ts @@ -90,7 +90,7 @@ export function getDefaultEnvironment(): Record { * * This transport is only available in Node.js environments. */ -export class StdioClientTransport implements Transport { +export class LegacyStdioClientTransport implements Transport { private _process?: ChildProcess; private _readBuffer: ReadBuffer = new ReadBuffer(); private _serverParams: StdioServerParameters; @@ -113,7 +113,7 @@ export class StdioClientTransport implements Transport { async start(): Promise { if (this._process) { throw new Error( - 'StdioClientTransport already started! If using Client class, note that connect() calls start() automatically.' + 'LegacyStdioClientTransport already started! If using Client class, note that connect() calls start() automatically.' ); } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index cd643c96dc..94d503a6dc 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -51,7 +51,7 @@ export interface StartSSEOptions { } /** - * Configuration options for reconnection behavior of the {@linkcode StreamableHTTPClientTransport}. + * Configuration options for reconnection behavior of the {@linkcode LegacyStreamableHTTPClientTransport}. */ export interface StreamableHTTPReconnectionOptions { /** @@ -91,7 +91,7 @@ export interface StreamableHTTPReconnectionOptions { * @param delay - Suggested delay in milliseconds (from backoff calculation). * @param attemptCount - Zero-indexed retry attempt number. * @returns An optional cancel function. If returned, it will be called on - * {@linkcode StreamableHTTPClientTransport.close | transport.close()} to abort the + * {@linkcode LegacyStreamableHTTPClientTransport.close | transport.close()} to abort the * pending reconnection. * * @example @@ -105,7 +105,7 @@ export interface StreamableHTTPReconnectionOptions { export type ReconnectionScheduler = (reconnect: () => void, delay: number, attemptCount: number) => (() => void) | void; /** - * Configuration options for the {@linkcode StreamableHTTPClientTransport}. + * Configuration options for the {@linkcode LegacyStreamableHTTPClientTransport}. */ export type StreamableHTTPClientTransportOptions = { /** @@ -122,7 +122,7 @@ export type StreamableHTTPClientTransportOptions = { * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation * directly — the transport adapts it to `AuthProvider` internally. Interactive flows: after * {@linkcode UnauthorizedError}, redirect the user, then call - * {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization code before + * {@linkcode LegacyStreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization code before * reconnecting. */ authProvider?: AuthProvider | OAuthClientProvider; @@ -160,6 +160,22 @@ export type StreamableHTTPClientTransportOptions = { * handshake so the reconnected transport continues sending the required header. */ protocolVersion?: string; + + /** + * When `true`, skip the `server/discover` probe and always use the legacy (2025-11) + * protocol path. Useful for connecting to known legacy servers or for testing + * legacy protocol behavior explicitly. + */ + forceLegacy?: boolean; + + /** + * Optional callback to inject extra HTTP headers into every outgoing POST request. + * Called after the standard headers are built, so returned headers can override them. + * + * Used internally by the probing transport to add the `Mcp-Method` header + * for modern (2026-06) protocol requests. + */ + getExtraHeaders?: (message: JSONRPCMessage | JSONRPCMessage[]) => Record; }; /** @@ -167,7 +183,7 @@ export type StreamableHTTPClientTransportOptions = { * It will connect to a server using HTTP `POST` for sending messages and HTTP `GET` with Server-Sent Events * for receiving messages. */ -export class StreamableHTTPClientTransport implements Transport { +export class LegacyStreamableHTTPClientTransport implements Transport { private _abortController?: AbortController; private _url: URL; private _resourceMetadataUrl?: URL; @@ -184,6 +200,7 @@ export class StreamableHTTPClientTransport implements Transport { private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private readonly _reconnectionScheduler?: ReconnectionScheduler; private _cancelReconnection?: () => void; + private _getExtraHeaders?: (message: JSONRPCMessage | JSONRPCMessage[]) => Record; onclose?: () => void; onerror?: (error: Error) => void; @@ -206,6 +223,7 @@ export class StreamableHTTPClientTransport implements Transport { this._protocolVersion = opts?.protocolVersion; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; this._reconnectionScheduler = opts?.reconnectionScheduler; + this._getExtraHeaders = opts?.getExtraHeaders; } private async _commonHeaders(): Promise { @@ -539,6 +557,12 @@ export class StreamableHTTPClientTransport implements Transport { } const headers = await this._commonHeaders(); + if (this._getExtraHeaders) { + const extra = this._getExtraHeaders(message); + for (const [k, v] of Object.entries(extra)) { + headers.set(k, v); + } + } headers.set('content-type', 'application/json'); const userAccept = headers.get('accept'); const types = [...(userAccept?.split(',').map(s => s.trim().toLowerCase()) ?? []), 'application/json', 'text/event-stream']; @@ -745,6 +769,26 @@ export class StreamableHTTPClientTransport implements Transport { return this._protocolVersion; } + /** @internal Exposes the endpoint URL for use by the probing transport wrapper. */ + get url(): URL { + return this._url; + } + + /** @internal Builds common headers. Exposed for wrapping transports to use during probing. */ + async commonHeaders(): Promise { + return this._commonHeaders(); + } + + /** @internal Exposes the fetch function for use by wrapping transports. */ + get fetchFn(): (url: string | URL, init?: RequestInit) => Promise { + return this._fetch ?? fetch; + } + + /** @internal Exposes the base requestInit for use by wrapping transports. */ + get requestInit(): RequestInit | undefined { + return this._requestInit; + } + /** * Resume an SSE stream from a previous event ID. * Opens a `GET` SSE connection with `Last-Event-ID` header to replay missed events. diff --git a/packages/client/src/client/versionProbing.ts b/packages/client/src/client/versionProbing.ts new file mode 100644 index 0000000000..bb6002c166 --- /dev/null +++ b/packages/client/src/client/versionProbing.ts @@ -0,0 +1,23 @@ +import type { Transport } from '@modelcontextprotocol/core'; + +import type { DiscoverResult } from './modernClientImpl.js'; + +/** + * A transport that detects the server's protocol version during start(). + * + * Implemented by StreamableHTTPClientTransport (HTTP) and + * StdioClientTransport (stdio). Used by Client to decide between + * ModernClientImpl and LegacyClient. + */ +export interface VersionProbingTransport extends Transport { + readonly mode: 'modern' | 'legacy'; + getDiscoverResult(): DiscoverResult | undefined; +} + +export function isVersionProbingTransport(transport: Transport): transport is VersionProbingTransport { + return ( + 'mode' in transport && + 'getDiscoverResult' in transport && + typeof (transport as VersionProbingTransport).getDiscoverResult === 'function' + ); +} diff --git a/packages/client/src/experimental/index.ts b/packages/client/src/experimental/index.ts deleted file mode 100644 index 926369f994..0000000000 --- a/packages/client/src/experimental/index.ts +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Experimental MCP SDK features. - * WARNING: These APIs are experimental and may change without notice. - * - * Import experimental features from this module: - * ```typescript - * import { TaskStore, InMemoryTaskStore } from '@modelcontextprotocol/sdk/experimental'; - * ``` - * - * @experimental - */ - -export * from './tasks/client.js'; diff --git a/packages/client/src/experimental/tasks/client.examples.ts b/packages/client/src/experimental/tasks/client.examples.ts deleted file mode 100644 index 5652062758..0000000000 --- a/packages/client/src/experimental/tasks/client.examples.ts +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Type-checked examples for `client.ts`. - * - * These examples are synced into JSDoc comments via the sync-snippets script. - * Each function's region markers define the code snippet that appears in the docs. - * - * @module - */ - -import type { RequestOptions } from '@modelcontextprotocol/core'; - -import type { Client } from '../../client/client.js'; - -/** - * Example: Using callToolStream to execute a tool with task lifecycle events. - */ -async function ExperimentalClientTasks_callToolStream(client: Client) { - //#region ExperimentalClientTasks_callToolStream - const stream = client.experimental.tasks.callToolStream({ name: 'myTool', arguments: {} }); - for await (const message of stream) { - switch (message.type) { - case 'taskCreated': { - console.log('Tool execution started:', message.task.taskId); - break; - } - case 'taskStatus': { - console.log('Tool status:', message.task.status); - break; - } - case 'result': { - console.log('Tool result:', message.result); - break; - } - case 'error': { - console.error('Tool error:', message.error); - break; - } - } - } - //#endregion ExperimentalClientTasks_callToolStream -} - -/** - * Example: Using requestStream to consume task lifecycle events for any request type. - */ -async function ExperimentalClientTasks_requestStream(client: Client, options: RequestOptions) { - //#region ExperimentalClientTasks_requestStream - const stream = client.experimental.tasks.requestStream({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, options); - for await (const message of stream) { - switch (message.type) { - case 'taskCreated': { - console.log('Task created:', message.task.taskId); - break; - } - case 'taskStatus': { - console.log('Task status:', message.task.status); - break; - } - case 'result': { - console.log('Final result:', message.result); - break; - } - case 'error': { - console.error('Error:', message.error); - break; - } - } - } - //#endregion ExperimentalClientTasks_requestStream -} diff --git a/packages/client/src/experimental/tasks/client.ts b/packages/client/src/experimental/tasks/client.ts deleted file mode 100644 index 75ba873c97..0000000000 --- a/packages/client/src/experimental/tasks/client.ts +++ /dev/null @@ -1,277 +0,0 @@ -/** - * Experimental client task features for MCP SDK. - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - -import type { - AnyObjectSchema, - CallToolRequest, - CallToolResult, - CancelTaskResult, - CreateTaskResult, - GetTaskPayloadResult, - GetTaskResult, - ListTasksResult, - Request, - RequestMethod, - RequestOptions, - ResponseMessage, - ResultTypeMap -} from '@modelcontextprotocol/core'; -import { - CallToolResultSchema, - getResultSchema, - GetTaskPayloadResultSchema, - ProtocolError, - ProtocolErrorCode -} from '@modelcontextprotocol/core'; - -import type { Client } from '../../client/client.js'; - -/** - * Internal interface for accessing {@linkcode Client}'s private methods. - * @internal - */ -interface ClientInternal { - isToolTask(toolName: string): boolean; - getToolOutputValidator(toolName: string): ((data: unknown) => { valid: boolean; errorMessage?: string }) | undefined; -} - -/** - * Experimental task features for MCP clients. - * - * Access via `client.experimental.tasks`: - * ```typescript - * const stream = client.experimental.tasks.callToolStream({ name: 'tool', arguments: {} }); - * const task = await client.experimental.tasks.getTask(taskId); - * ``` - * - * @experimental - */ -export class ExperimentalClientTasks { - constructor(private readonly _client: Client) {} - - private get _module() { - return this._client.taskManager; - } - - /** - * Calls a tool and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a `'result'` or `'error'` message. - * - * This method provides streaming access to tool execution, allowing you to - * observe intermediate task status updates for long-running tool calls. - * Automatically validates structured output if the tool has an `outputSchema`. - * - * @example - * ```ts source="./client.examples.ts#ExperimentalClientTasks_callToolStream" - * const stream = client.experimental.tasks.callToolStream({ name: 'myTool', arguments: {} }); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': { - * console.log('Tool execution started:', message.task.taskId); - * break; - * } - * case 'taskStatus': { - * console.log('Tool status:', message.task.status); - * break; - * } - * case 'result': { - * console.log('Tool result:', message.result); - * break; - * } - * case 'error': { - * console.error('Tool error:', message.error); - * break; - * } - * } - * } - * ``` - * - * @param params - Tool call parameters (name and arguments) - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields {@linkcode ResponseMessage} objects - * - * @experimental - */ - async *callToolStream( - params: CallToolRequest['params'], - options?: RequestOptions - ): AsyncGenerator, void, void> { - // Access Client's internal methods - const clientInternal = this._client as unknown as ClientInternal; - - // Add task creation parameters if server supports it and not explicitly provided - const optionsWithTask = { - ...options, - // We check if the tool is known to be a task during auto-configuration, but assume - // the caller knows what they're doing if they pass this explicitly - task: options?.task ?? (clientInternal.isToolTask(params.name) ? {} : undefined) - }; - - const stream = this._module.requestStream({ method: 'tools/call', params }, CallToolResultSchema, optionsWithTask); - - // Get the validator for this tool (if it has an output schema) - const validator = clientInternal.getToolOutputValidator(params.name); - - // Iterate through the stream and validate the final result if needed - for await (const message of stream) { - // If this is a result message and the tool has an output schema, validate it - // Only validate CallToolResult (has 'content'), not CreateTaskResult (has 'task') - if (message.type === 'result' && validator && 'content' in message.result) { - const result = message.result as CallToolResult; - - // If tool has outputSchema, it MUST return structuredContent (unless it's an error) - if (!result.structuredContent && !result.isError) { - yield { - type: 'error', - error: new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ) - }; - return; - } - - // Only validate structured content if present (not when there's an error) - if (result.structuredContent) { - try { - // Validate the structured content against the schema - const validationResult = validator(result.structuredContent); - - if (!validationResult.valid) { - yield { - type: 'error', - error: new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` - ) - }; - return; - } - } catch (error) { - if (error instanceof ProtocolError) { - yield { type: 'error', error }; - return; - } - yield { - type: 'error', - error: new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` - ) - }; - return; - } - } - } - - // Yield the message (either validated result or any other message type) - yield message; - } - } - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @param options - Optional request options - * @returns The task status - * - * @experimental - */ - async getTask(taskId: string, options?: RequestOptions): Promise { - return this._module.getTask({ taskId }, options); - } - - /** - * Retrieves the result of a completed task. - * - * @param taskId - The task identifier - * @param options - Optional request options - * @returns The task result. The payload structure matches the result type of the - * original request (e.g., a `tools/call` task returns a `CallToolResult`). - * - * @experimental - */ - async getTaskResult(taskId: string, options?: RequestOptions): Promise { - return this._module.getTaskResult({ taskId }, GetTaskPayloadResultSchema, options); - } - - /** - * Lists tasks with optional pagination. - * - * @param cursor - Optional pagination cursor - * @param options - Optional request options - * @returns List of tasks with optional next cursor - * - * @experimental - */ - async listTasks(cursor?: string, options?: RequestOptions): Promise { - return this._module.listTasks(cursor ? { cursor } : undefined, options); - } - - /** - * Cancels a running task. - * - * @param taskId - The task identifier - * @param options - Optional request options - * - * @experimental - */ - async cancelTask(taskId: string, options?: RequestOptions): Promise { - return this._module.cancelTask({ taskId }, options); - } - - /** - * Sends a request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a `'result'` or `'error'` message. - * - * This method provides streaming access to request processing, allowing you to - * observe intermediate task status updates for task-augmented requests. - * - * @example - * ```ts source="./client.examples.ts#ExperimentalClientTasks_requestStream" - * const stream = client.experimental.tasks.requestStream({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, options); - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': { - * console.log('Task created:', message.task.taskId); - * break; - * } - * case 'taskStatus': { - * console.log('Task status:', message.task.status); - * break; - * } - * case 'result': { - * console.log('Final result:', message.result); - * break; - * } - * case 'error': { - * console.error('Error:', message.error); - * break; - * } - * } - * } - * ``` - * - * @param request - The request to send - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields {@linkcode ResponseMessage} objects - * - * @experimental - */ - requestStream( - request: { method: M; params?: Record }, - options?: RequestOptions - ): AsyncGenerator, void, void> { - const resultSchema = getResultSchema(request.method) as unknown as AnyObjectSchema; - return this._module.requestStream(request as Request, resultSchema, options) as AsyncGenerator< - ResponseMessage, - void, - void - >; - } -} diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 06ca1141b2..a2a2322ebd 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -53,7 +53,7 @@ export { StaticPrivateKeyJwtProvider } from './client/authExtensions.js'; export type { ClientOptions } from './client/client.js'; -export { Client } from './client/client.js'; +export { Client, LegacyClient } from './client/client.js'; export { getSupportedElicitationModes } from './client/client.js'; export type { DiscoverAndRequestJwtAuthGrantOptions, JwtAuthGrantResult, RequestJwtAuthGrantOptions } from './client/crossAppAccess.js'; export { discoverAndRequestJwtAuthGrant, exchangeJwtAuthGrant, requestJwtAuthorizationGrant } from './client/crossAppAccess.js'; @@ -63,16 +63,17 @@ export type { SSEClientTransportOptions } from './client/sse.js'; export { SSEClientTransport, SseError } from './client/sse.js'; // StdioClientTransport, getDefaultEnvironment, DEFAULT_INHERITED_ENV_VARS, StdioServerParameters are exported from // the './stdio' subpath to keep the root entry free of process-spawning runtime dependencies (child_process, cross-spawn). +export type { DiscoverResult } from './client/modernClientImpl.js'; +export { ModernClientImpl } from './client/modernClientImpl.js'; +export { StreamableHTTPClientTransport } from './client/modernStreamableHttp.js'; export type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPClientTransportOptions, StreamableHTTPReconnectionOptions } from './client/streamableHttp.js'; -export { StreamableHTTPClientTransport } from './client/streamableHttp.js'; - -// experimental exports -export { ExperimentalClientTasks } from './experimental/tasks/client.js'; +export type { VersionProbingTransport } from './client/versionProbing.js'; +export { isVersionProbingTransport } from './client/versionProbing.js'; // runtime-aware wrapper (shadows core/public's fromJsonSchema with optional validator) export { fromJsonSchema } from './fromJsonSchema.js'; diff --git a/packages/client/src/stdio.ts b/packages/client/src/stdio.ts index a6ecd1697e..805747e407 100644 --- a/packages/client/src/stdio.ts +++ b/packages/client/src/stdio.ts @@ -4,5 +4,7 @@ // Cloudflare Workers targets does not pull in `node:child_process`, `node:stream`, or `cross-spawn`. Import // from `@modelcontextprotocol/client/stdio` only in process-spawning runtimes (Node.js, Bun, Deno). +export type { StdioClientTransportOptions } from './client/modernStdio.js'; +export { StdioClientTransport } from './client/modernStdio.js'; export type { StdioServerParameters } from './client/stdio.js'; -export { DEFAULT_INHERITED_ENV_VARS, getDefaultEnvironment, StdioClientTransport } from './client/stdio.js'; +export { DEFAULT_INHERITED_ENV_VARS, getDefaultEnvironment, LegacyStdioClientTransport } from './client/stdio.js'; diff --git a/packages/client/test/client/__fixtures__/legacyServer.mjs b/packages/client/test/client/__fixtures__/legacyServer.mjs new file mode 100644 index 0000000000..964ac52c55 --- /dev/null +++ b/packages/client/test/client/__fixtures__/legacyServer.mjs @@ -0,0 +1,23 @@ +// Fixture: responds to server/discover with a JSON-RPC error (simulating a legacy server). +import { createInterface } from 'node:readline'; + +const rl = createInterface({ input: process.stdin }); + +rl.on('line', line => { + if (!line.trim()) return; + let msg; + try { + msg = JSON.parse(line); + } catch { + return; + } + if (msg.method === 'server/discover') { + process.stdout.write( + JSON.stringify({ + jsonrpc: '2.0', + id: msg.id, + error: { code: -32601, message: 'Method not found' } + }) + '\n' + ); + } +}); diff --git a/packages/client/test/client/__fixtures__/modernServer.mjs b/packages/client/test/client/__fixtures__/modernServer.mjs new file mode 100644 index 0000000000..1ea9a0cca5 --- /dev/null +++ b/packages/client/test/client/__fixtures__/modernServer.mjs @@ -0,0 +1,27 @@ +// Fixture: responds to server/discover with a valid DiscoverResult. +import { createInterface } from 'node:readline'; + +const rl = createInterface({ input: process.stdin }); + +rl.on('line', line => { + if (!line.trim()) return; + let msg; + try { + msg = JSON.parse(line); + } catch { + return; + } + if (msg.method === 'server/discover') { + process.stdout.write( + JSON.stringify({ + jsonrpc: '2.0', + id: msg.id, + result: { + supportedVersions: ['2026-06-30'], + capabilities: { tools: {} }, + serverInfo: { name: 'modern-fixture', version: '1.0.0' } + } + }) + '\n' + ); + } +}); diff --git a/packages/client/test/client/__fixtures__/modernServerWithExtra.mjs b/packages/client/test/client/__fixtures__/modernServerWithExtra.mjs new file mode 100644 index 0000000000..8d1eb83dc2 --- /dev/null +++ b/packages/client/test/client/__fixtures__/modernServerWithExtra.mjs @@ -0,0 +1,35 @@ +// Fixture: responds to server/discover with a valid DiscoverResult, then +// immediately sends an unsolicited notification (to test message buffering). +import { createInterface } from 'node:readline'; + +const rl = createInterface({ input: process.stdin }); + +rl.on('line', line => { + if (!line.trim()) return; + let msg; + try { + msg = JSON.parse(line); + } catch { + return; + } + if (msg.method === 'server/discover') { + process.stdout.write( + JSON.stringify({ + jsonrpc: '2.0', + id: msg.id, + result: { + supportedVersions: ['2026-06-30'], + capabilities: { tools: {} }, + serverInfo: { name: 'modern-extra-fixture', version: '1.0.0' } + } + }) + '\n' + ); + process.stdout.write( + JSON.stringify({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'buffered-message' } + }) + '\n' + ); + } +}); diff --git a/packages/client/test/client/__fixtures__/silentServer.mjs b/packages/client/test/client/__fixtures__/silentServer.mjs new file mode 100644 index 0000000000..53e7f58dfb --- /dev/null +++ b/packages/client/test/client/__fixtures__/silentServer.mjs @@ -0,0 +1,3 @@ +// Fixture: reads stdin but never responds (for timeout tests). +process.stdin.resume(); +setInterval(() => {}, 60_000); diff --git a/packages/client/test/client/crossSpawn.test.ts b/packages/client/test/client/crossSpawn.test.ts index a6d0272a4c..654b536ebd 100644 --- a/packages/client/test/client/crossSpawn.test.ts +++ b/packages/client/test/client/crossSpawn.test.ts @@ -4,13 +4,13 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; import spawn from 'cross-spawn'; import type { Mock, MockedFunction } from 'vitest'; -import { getDefaultEnvironment, StdioClientTransport } from '../../src/client/stdio.js'; +import { getDefaultEnvironment, LegacyStdioClientTransport } from '../../src/client/stdio.js'; // mock cross-spawn vi.mock('cross-spawn'); const mockSpawn = spawn as unknown as MockedFunction; -describe('StdioClientTransport using cross-spawn', () => { +describe('LegacyStdioClientTransport using cross-spawn', () => { beforeEach(() => { // mock cross-spawn's return value mockSpawn.mockImplementation(() => { @@ -44,7 +44,7 @@ describe('StdioClientTransport using cross-spawn', () => { }); test('should call cross-spawn correctly', async () => { - const transport = new StdioClientTransport({ + const transport = new LegacyStdioClientTransport({ command: 'test-command', args: ['arg1', 'arg2'] }); @@ -63,7 +63,7 @@ describe('StdioClientTransport using cross-spawn', () => { test('should pass environment variables correctly', async () => { const customEnv = { TEST_VAR: 'test-value' }; - const transport = new StdioClientTransport({ + const transport = new LegacyStdioClientTransport({ command: 'test-command', env: customEnv }); @@ -84,7 +84,7 @@ describe('StdioClientTransport using cross-spawn', () => { }); test('should use default environment when env is undefined', async () => { - const transport = new StdioClientTransport({ + const transport = new LegacyStdioClientTransport({ command: 'test-command', env: undefined }); @@ -102,7 +102,7 @@ describe('StdioClientTransport using cross-spawn', () => { }); test('should send messages correctly', async () => { - const transport = new StdioClientTransport({ + const transport = new LegacyStdioClientTransport({ command: 'test-command' }); @@ -167,7 +167,7 @@ describe('StdioClientTransport using cross-spawn', () => { value: 'win32' }); - const transport = new StdioClientTransport({ + const transport = new LegacyStdioClientTransport({ command: 'test-command' }); @@ -187,7 +187,7 @@ describe('StdioClientTransport using cross-spawn', () => { value: 'linux' }); - const transport = new StdioClientTransport({ + const transport = new LegacyStdioClientTransport({ command: 'test-command' }); diff --git a/packages/client/test/client/stdio.test.ts b/packages/client/test/client/stdio.test.ts index 28a7834bcb..0a0c0bb28e 100644 --- a/packages/client/test/client/stdio.test.ts +++ b/packages/client/test/client/stdio.test.ts @@ -1,7 +1,7 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; import type { StdioServerParameters } from '../../src/client/stdio.js'; -import { StdioClientTransport } from '../../src/client/stdio.js'; +import { LegacyStdioClientTransport } from '../../src/client/stdio.js'; // Configure default server parameters based on OS // Uses 'more' command for Windows and 'tee' command for Unix/Linux @@ -15,7 +15,7 @@ const getDefaultServerParameters = (): StdioServerParameters => { const serverParameters = getDefaultServerParameters(); test('should start then close cleanly', async () => { - const client = new StdioClientTransport(serverParameters); + const client = new LegacyStdioClientTransport(serverParameters); client.onerror = error => { throw error; }; @@ -32,7 +32,7 @@ test('should start then close cleanly', async () => { }); test('should read messages', async () => { - const client = new StdioClientTransport(serverParameters); + const client = new LegacyStdioClientTransport(serverParameters); client.onerror = error => { throw error; }; @@ -70,7 +70,7 @@ test('should read messages', async () => { }); test('should return child process pid', async () => { - const client = new StdioClientTransport(serverParameters); + const client = new LegacyStdioClientTransport(serverParameters); await client.start(); expect(client.pid).not.toBeNull(); diff --git a/packages/client/test/client/stdioVersionProbing.test.ts b/packages/client/test/client/stdioVersionProbing.test.ts new file mode 100644 index 0000000000..92906f3dbd --- /dev/null +++ b/packages/client/test/client/stdioVersionProbing.test.ts @@ -0,0 +1,128 @@ +import path from 'node:path'; +import url from 'node:url'; + +import type { JSONRPCMessage } from '@modelcontextprotocol/core'; + +import { StdioClientTransport } from '../../src/client/modernStdio.js'; + +const __dirname = path.dirname(url.fileURLToPath(import.meta.url)); +const fixture = (name: string) => path.resolve(__dirname, '__fixtures__', name); + +describe('StdioClientTransport (version probing)', () => { + vi.setConfig({ testTimeout: 10_000 }); + + let transport: StdioClientTransport; + + afterEach(async () => { + await transport?.close(); + }); + + it('probe succeeds and enters modern mode', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('modernServer.mjs')] + }); + await transport.start(); + + expect(transport.mode).toBe('modern'); + expect(transport.getDiscoverResult()).toBeDefined(); + expect(transport.getDiscoverResult()!.serverInfo.name).toBe('modern-fixture'); + expect(transport.getDiscoverResult()!.supportedVersions).toContain('2026-06-30'); + }); + + it('probe returns error and falls back to legacy mode', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('legacyServer.mjs')] + }); + await transport.start(); + + expect(transport.mode).toBe('legacy'); + expect(transport.getDiscoverResult()).toBeUndefined(); + }); + + it('probe times out and falls back to legacy mode', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('silentServer.mjs')], + probeTimeoutMs: 200 + }); + await transport.start(); + + expect(transport.mode).toBe('legacy'); + expect(transport.getDiscoverResult()).toBeUndefined(); + }); + + it('forceLegacy skips probe even when server supports modern', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('modernServer.mjs')], + forceLegacy: true + }); + await transport.start(); + + expect(transport.mode).toBe('legacy'); + expect(transport.getDiscoverResult()).toBeUndefined(); + }); + + it('pid is available after start', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('modernServer.mjs')] + }); + await transport.start(); + + expect(transport.pid).toBeGreaterThan(0); + }); + + it('close during probe resolves start with legacy mode', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('silentServer.mjs')], + probeTimeoutMs: 10_000 + }); + + const startPromise = transport.start(); + setTimeout(() => transport.close(), 50); + await startPromise; + + expect(transport.mode).toBe('legacy'); + }); + + it('buffered messages during probe are flushed when onmessage is set', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('modernServerWithExtra.mjs')] + }); + await transport.start(); + expect(transport.mode).toBe('modern'); + + const flushed: JSONRPCMessage[] = []; + await new Promise(resolve => { + transport.onmessage = (msg: JSONRPCMessage) => { + flushed.push(msg); + resolve(); + }; + // If nothing arrives within a short window, resolve anyway + setTimeout(resolve, 500); + }); + + expect(flushed.length).toBeGreaterThanOrEqual(1); + const notification = flushed[0] as { method?: string; params?: { data?: string } }; + expect(notification.method).toBe('notifications/message'); + expect(notification.params?.data).toBe('buffered-message'); + }); + + it('second start call is a no-op', async () => { + transport = new StdioClientTransport({ + command: 'node', + args: [fixture('modernServer.mjs')] + }); + await transport.start(); + expect(transport.mode).toBe('modern'); + + // Calling start again should not throw or change mode + await transport.start(); + expect(transport.mode).toBe('modern'); + }); +}); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index b2138b3fa8..4b77a1d726 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -5,10 +5,10 @@ import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; import type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; -import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import { LegacyStreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; -describe('StreamableHTTPClientTransport', () => { - let transport: StreamableHTTPClientTransport; +describe('LegacyStreamableHTTPClientTransport', () => { + let transport: LegacyStreamableHTTPClientTransport; let mockAuthProvider: Mocked; beforeEach(() => { @@ -27,7 +27,7 @@ describe('StreamableHTTPClientTransport', () => { codeVerifier: vi.fn(), invalidateCredentials: vi.fn() }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); vi.spyOn(globalThis, 'fetch'); }); @@ -125,7 +125,7 @@ describe('StreamableHTTPClientTransport', () => { it('should accept protocolVersion constructor option and include it in request headers', async () => { // When reconnecting with a preserved sessionId, users need to also preserve the // negotiated protocol version so the required mcp-protocol-version header is sent. - const reconnectTransport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + const reconnectTransport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { sessionId: 'preserved-session-id', protocolVersion: '2025-11-25' }); @@ -405,7 +405,7 @@ describe('StreamableHTTPClientTransport', () => { it('should support custom reconnection options', () => { // Create a transport with custom reconnection options - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 500, maxReconnectionDelay: 10_000, @@ -425,7 +425,7 @@ describe('StreamableHTTPClientTransport', () => { it('should pass lastEventId when reconnecting', async () => { // Create a fresh transport - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); // Mock fetch to verify headers sent const fetchSpy = globalThis.fetch as Mock; @@ -457,7 +457,7 @@ describe('StreamableHTTPClientTransport', () => { // GET SSE request did not, so non-header options like credentials were dropped. vi.clearAllMocks(); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: { credentials: 'include', mode: 'cors' } }); @@ -485,7 +485,7 @@ describe('StreamableHTTPClientTransport', () => { vi.clearAllMocks(); // Create a fresh transport instance - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); const message: JSONRPCMessage = { jsonrpc: '2.0', @@ -524,7 +524,7 @@ describe('StreamableHTTPClientTransport', () => { .mockResolvedValueOnce(new Response(null, { status: 202 })); // Create transport instance - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { fetch: customFetch }); @@ -547,7 +547,7 @@ describe('StreamableHTTPClientTransport', () => { 'X-Custom-Header': 'CustomValue' } }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: requestInit }); @@ -579,7 +579,7 @@ describe('StreamableHTTPClientTransport', () => { 'X-Custom-Header': 'CustomValue' }) }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: requestInit }); @@ -605,7 +605,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should always send specified custom headers (array of tuples)', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: { headers: [ ['Authorization', 'Bearer test-token'], @@ -629,7 +629,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should append custom Accept header to required types on POST requests', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: { headers: { Accept: 'application/vnd.example.v1+json' @@ -656,7 +656,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should append custom Accept header to required types on GET SSE requests', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: { headers: { Accept: 'application/json' @@ -678,7 +678,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should set default Accept header when none provided', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); let actualReqInit: RequestInit = {}; @@ -697,7 +697,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should not duplicate Accept media types when user-provided value overlaps required types', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { requestInit: { headers: { Accept: 'application/json' @@ -725,7 +725,7 @@ describe('StreamableHTTPClientTransport', () => { // This test verifies the maxRetries and backoff calculation directly // Create transport with specific options for testing - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 100, maxReconnectionDelay: 5000, @@ -876,7 +876,7 @@ describe('StreamableHTTPClientTransport', () => { }); describe('Reconnection Logic', () => { - let transport: StreamableHTTPClientTransport; + let transport: LegacyStreamableHTTPClientTransport; // Use fake timers to control setTimeout and make the test instant. beforeEach(() => vi.useFakeTimers()); @@ -884,7 +884,7 @@ describe('StreamableHTTPClientTransport', () => { it('should reconnect a GET-initiated notification stream that fails', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 1, @@ -938,7 +938,7 @@ describe('StreamableHTTPClientTransport', () => { it('should NOT reconnect a POST-initiated stream that fails', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 1, @@ -987,7 +987,7 @@ describe('StreamableHTTPClientTransport', () => { it('should reconnect a POST-initiated stream after receiving a priming event', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 1, @@ -1048,7 +1048,7 @@ describe('StreamableHTTPClientTransport', () => { it('should NOT reconnect a POST stream when response was received', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 1, @@ -1103,7 +1103,7 @@ describe('StreamableHTTPClientTransport', () => { it('should NOT reconnect a POST stream when error response was received', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 1, @@ -1175,7 +1175,7 @@ describe('StreamableHTTPClientTransport', () => { it('should not attempt reconnection after close() is called', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 100, maxRetries: 3, @@ -1225,7 +1225,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should not throw JSON parse error on priming events with empty data', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); const errorSpy = vi.fn(); transport.onerror = errorSpy; @@ -1501,7 +1501,7 @@ describe('StreamableHTTPClientTransport', () => { }); // Create transport instance - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider, fetch: customFetch }); @@ -1570,7 +1570,7 @@ describe('StreamableHTTPClientTransport', () => { }); // Create transport instance - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider, fetch: customFetch }); @@ -1616,7 +1616,7 @@ describe('StreamableHTTPClientTransport', () => { afterEach(() => vi.useRealTimers()); it('should use server-provided retry value for reconnection delay', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 100, maxReconnectionDelay: 5000, @@ -1671,7 +1671,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should fall back to exponential backoff when no server retry value', () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 100, maxReconnectionDelay: 5000, @@ -1693,7 +1693,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('should reconnect on graceful stream close', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxReconnectionDelay: 1000, @@ -1751,7 +1751,7 @@ describe('StreamableHTTPClientTransport', () => { }); describe('Reconnection Logic with maxRetries 0', () => { - let transport: StreamableHTTPClientTransport; + let transport: LegacyStreamableHTTPClientTransport; // Use fake timers to control setTimeout and make the test instant. beforeEach(() => vi.useFakeTimers()); @@ -1759,7 +1759,7 @@ describe('StreamableHTTPClientTransport', () => { it('should not schedule any reconnection attempts when maxRetries is 0', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 0, // This should disable retries completely @@ -1788,7 +1788,7 @@ describe('StreamableHTTPClientTransport', () => { it('should schedule reconnection when maxRetries is greater than 0', async () => { // ARRANGE - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions: { initialReconnectionDelay: 10, maxRetries: 1, // Allow 1 retry @@ -1890,7 +1890,7 @@ describe('StreamableHTTPClientTransport', () => { maxRetries: 3 }; - function triggerReconnection(t: StreamableHTTPClientTransport): void { + function triggerReconnection(t: LegacyStreamableHTTPClientTransport): void { (t as unknown as { _scheduleReconnection(opts: StartSSEOptions, attempt?: number): void })._scheduleReconnection({}, 0); } @@ -1904,7 +1904,7 @@ describe('StreamableHTTPClientTransport', () => { it('invokes the custom scheduler with reconnect, delay, and attemptCount', () => { const scheduler = vi.fn(); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions, reconnectionScheduler: scheduler }); @@ -1917,7 +1917,7 @@ describe('StreamableHTTPClientTransport', () => { it('falls back to setTimeout when no scheduler is provided', () => { const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions }); @@ -1928,7 +1928,7 @@ describe('StreamableHTTPClientTransport', () => { it('does not use setTimeout when a custom scheduler is provided', () => { const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions, reconnectionScheduler: vi.fn() }); @@ -1941,7 +1941,7 @@ describe('StreamableHTTPClientTransport', () => { it('calls the returned cancel function on close()', async () => { const cancel = vi.fn(); const scheduler: ReconnectionScheduler = vi.fn(() => cancel); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions, reconnectionScheduler: scheduler }); @@ -1954,7 +1954,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('tolerates schedulers that return void (no cancel function)', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions, reconnectionScheduler: () => { /* no return */ @@ -1967,7 +1967,7 @@ describe('StreamableHTTPClientTransport', () => { it('clears the default setTimeout on close() when no scheduler is provided', async () => { const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions }); @@ -1979,7 +1979,7 @@ describe('StreamableHTTPClientTransport', () => { it('ignores a late-firing reconnect after close()', async () => { let capturedReconnect: (() => void) | undefined; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions, reconnectionScheduler: reconnect => { capturedReconnect = reconnect; @@ -1999,7 +1999,7 @@ describe('StreamableHTTPClientTransport', () => { }); it('still aborts and fires onclose if the cancel function throws', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { reconnectionOptions, reconnectionScheduler: () => () => { throw new Error('cancel failed'); diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts index d6ef35bdee..72e229e553 100644 --- a/packages/client/test/client/tokenProvider.test.ts +++ b/packages/client/test/client/tokenProvider.test.ts @@ -8,10 +8,10 @@ import type { Mock } from 'vitest'; import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; -import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import { LegacyStreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; -describe('StreamableHTTPClientTransport with AuthProvider', () => { - let transport: StreamableHTTPClientTransport; +describe('LegacyStreamableHTTPClientTransport with AuthProvider', () => { + let transport: LegacyStreamableHTTPClientTransport; afterEach(async () => { await transport?.close().catch(() => {}); @@ -22,7 +22,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { it('should set Authorization header from AuthProvider.token()', async () => { const authProvider: AuthProvider = { token: vi.fn(async () => 'my-bearer-token') }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); @@ -36,7 +36,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { it('should not set Authorization header when token() returns undefined', async () => { const authProvider: AuthProvider = { token: vi.fn(async () => undefined) }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); @@ -49,7 +49,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { it('should throw UnauthorizedError on 401 when onUnauthorized is not provided', async () => { const authProvider: AuthProvider = { token: vi.fn(async () => 'rejected-token') }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock).mockResolvedValueOnce({ @@ -71,7 +71,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { currentToken = 'new-token'; }) }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock) @@ -91,7 +91,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { token: vi.fn(async () => 'still-bad'), onUnauthorized: vi.fn(async () => {}) }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock) @@ -109,7 +109,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { token: vi.fn(async () => 'token'), onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock) @@ -127,7 +127,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { }); it('should work with no authProvider at all', async () => { - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); @@ -140,14 +140,14 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { it('should throw when finishAuth is called with a non-OAuth AuthProvider', async () => { const authProvider: AuthProvider = { token: async () => 'api-key' }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); }); it('should throw UnauthorizedError on GET-SSE 401 with no onUnauthorized (via resumeStream)', async () => { const authProvider: AuthProvider = { token: async () => 'api-key' }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); (globalThis.fetch as Mock).mockResolvedValueOnce({ @@ -168,7 +168,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { currentToken = 'new-token'; }) }; - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); // First GET: 401. Second GET (retry): 405 (server doesn't offer SSE — clean exit) @@ -189,7 +189,7 @@ describe('AuthProvider integration — both modes against a real server', () => let server: Server; let serverUrl: URL; let capturedRequests: IncomingMessage[]; - let transport: StreamableHTTPClientTransport; + let transport: LegacyStreamableHTTPClientTransport; const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'ping', params: {}, id: '1' }; @@ -216,7 +216,7 @@ describe('AuthProvider integration — both modes against a real server', () => it('MODE A: minimal AuthProvider { token } sends Authorization header', async () => { const authProvider: AuthProvider = { token: async () => 'mode-a-token' }; - transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(serverUrl, { authProvider }); await transport.send(message); @@ -243,7 +243,7 @@ describe('AuthProvider integration — both modes against a real server', () => }); serverUrl = await listenOnRandomPort(server); - transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); + transport = new LegacyStreamableHTTPClientTransport(serverUrl, { authProvider }); await expect(transport.send(message)).rejects.toThrow('user action required'); expect(uiSignal).toHaveBeenCalledWith('show-reauth-prompt'); @@ -274,7 +274,7 @@ describe('AuthProvider integration — both modes against a real server', () => } }; - transport = new StreamableHTTPClientTransport(serverUrl, { authProvider: oauthProvider }); + transport = new LegacyStreamableHTTPClientTransport(serverUrl, { authProvider: oauthProvider }); await transport.send(message); @@ -284,14 +284,14 @@ describe('AuthProvider integration — both modes against a real server', () => it('both modes use the same option slot and same send() call', async () => { // Mode A - const transportA = new StreamableHTTPClientTransport(serverUrl, { + const transportA = new LegacyStreamableHTTPClientTransport(serverUrl, { authProvider: { token: async () => 'a-token' } }); await transportA.send(message); await transportA.close(); // Mode B — same constructor, same option name, different shape - const transportB = new StreamableHTTPClientTransport(serverUrl, { + const transportB = new LegacyStreamableHTTPClientTransport(serverUrl, { authProvider: { get redirectUrl() { return undefined; diff --git a/packages/client/test/client/versionProbing.test.ts b/packages/client/test/client/versionProbing.test.ts new file mode 100644 index 0000000000..a5ecaa63bb --- /dev/null +++ b/packages/client/test/client/versionProbing.test.ts @@ -0,0 +1,768 @@ +/** + * Integration tests for StreamableHTTPClientTransport. + * + * These tests spin up lightweight HTTP mock servers (node:http) that emulate + * the MCP server-side behaviour -- both the modern (2026-06) routing path and + * the legacy (2025-11) streamable-HTTP path -- and verify that the client-side + * probing, fallback, and tool-call flows work end-to-end over real HTTP. + */ +import { randomUUID } from 'node:crypto'; +import type { IncomingMessage, Server, ServerResponse } from 'node:http'; +import { createServer } from 'node:http'; +import type { AddressInfo } from 'node:net'; + +import type { JSONRPCMessage, JSONRPCRequest, JSONRPCResponse } from '@modelcontextprotocol/core'; + +import { Client } from '../../src/client/client.js'; +import { LegacyStreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import { StreamableHTTPClientTransport } from '../../src/client/modernStreamableHttp.js'; + +// --------------------------------------------------------------------------- +// Shared constants +// --------------------------------------------------------------------------- + +const SERVER_NAME = 'test-server'; +const SERVER_VERSION = '1.0.0'; +const TOOL_NAME = 'greet'; + +const SERVER_CAPABILITIES = { + tools: {}, + resources: {}, + prompts: {} +}; + +/** + * The greet tool handler -- shared by both modern and legacy mock servers so + * that the "content equivalence" test can rely on identical output. + */ +function greetToolResult(name: string) { + return { + content: [{ type: 'text' as const, text: `Hello, ${name}!` }] + }; +} + +// --------------------------------------------------------------------------- +// Mock server helpers +// --------------------------------------------------------------------------- + +/** Reads the full body of a Node IncomingMessage and JSON-parses it. */ +async function readJsonBody(req: IncomingMessage): Promise { + return new Promise((resolve, reject) => { + const chunks: Buffer[] = []; + req.on('data', (c: Buffer) => chunks.push(c)); + req.on('end', () => { + try { + resolve(JSON.parse(Buffer.concat(chunks).toString())); + } catch (error) { + reject(error); + } + }); + req.on('error', reject); + }); +} + +/** Sends a JSON-RPC response object as HTTP 200 application/json. */ +function sendJson(res: ServerResponse, body: unknown, status = 200, extraHeaders?: Record): void { + const payload = JSON.stringify(body); + res.writeHead(status, { + 'Content-Type': 'application/json', + 'Content-Length': String(Buffer.byteLength(payload)), + ...extraHeaders + }); + res.end(payload); +} + +/** Sends a JSON-RPC error response. */ +function sendJsonRpcError(res: ServerResponse, id: unknown, code: number, message: string, httpStatus = 200): void { + sendJson(res, { jsonrpc: '2.0', id, error: { code, message } }, httpStatus); +} + +/** Listen on a random port and return the base URL. */ +function listenOnRandomPort(server: Server): Promise { + return new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}/mcp`)); + }); + }); +} + +// --------------------------------------------------------------------------- +// Mock "routing" server (modern + legacy) +// --------------------------------------------------------------------------- + +/** + * Creates a mock HTTP server that supports both modern (2026-06, via + * `Mcp-Method` header) and legacy (2025-11, via initialize handshake) paths. + * + * The modern path responds to `server/discover`, `tools/call`, and `tools/list` + * using the `Mcp-Method` header to route. The legacy path performs a stateful + * initialize handshake and tracks sessions via `Mcp-Session-Id`. + */ +function createRoutingServer(): Server { + const sessions = new Map(); + + return createServer(async (req, res) => { + try { + const mcpMethod = req.headers['mcp-method'] as string | undefined; + + if (mcpMethod) { + // ---- Modern path ---- + if (req.method !== 'POST') { + res.writeHead(405, { Allow: 'POST' }); + res.end(); + return; + } + + const body = (await readJsonBody(req)) as JSONRPCRequest; + + if (mcpMethod === 'server/discover') { + sendJson(res, { + jsonrpc: '2.0', + id: body.id, + result: { + supportedVersions: ['2026-06-30'], + capabilities: SERVER_CAPABILITIES, + serverInfo: { name: SERVER_NAME, version: SERVER_VERSION } + } + } satisfies JSONRPCResponse); + return; + } + + if (mcpMethod === 'tools/list') { + sendJson(res, { + jsonrpc: '2.0', + id: body.id, + result: { + result_type: 'complete', + tools: [ + { + name: TOOL_NAME, + description: 'Greet someone', + inputSchema: { + type: 'object', + properties: { name: { type: 'string' } }, + required: ['name'] + } + } + ] + } + } satisfies JSONRPCResponse); + return; + } + + if (mcpMethod === 'tools/call') { + const args = body.params?.arguments as { name: string }; + sendJson(res, { + jsonrpc: '2.0', + id: body.id, + result: { + result_type: 'complete', + ...greetToolResult(args.name) + } + } satisfies JSONRPCResponse); + return; + } + + sendJsonRpcError(res, body.id, -32_601, `Method not found: ${mcpMethod}`); + return; + } + + // ---- Legacy path ---- + if (req.method === 'GET') { + // SSE stream endpoint -- return 405 (optional per spec) + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (!sessionId || !sessions.has(sessionId)) { + res.writeHead(405); + res.end(); + return; + } + // Keep alive SSE (just open and hold) + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive' + }); + // Don't end -- let the client close + req.on('close', () => res.end()); + return; + } + + if (req.method === 'DELETE') { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) sessions.delete(sessionId); + res.writeHead(200); + res.end(); + return; + } + + if (req.method !== 'POST') { + res.writeHead(405, { Allow: 'POST, GET, DELETE' }); + res.end(); + return; + } + + const body = (await readJsonBody(req)) as JSONRPCMessage; + const sessionId = req.headers['mcp-session-id'] as string | undefined; + + // Check if this is a notification (no id) + if (!('id' in body) || body.id === undefined) { + // Notification -- accept silently + res.writeHead(202); + res.end(); + return; + } + + const rpcReq = body as JSONRPCRequest; + + if (rpcReq.method === 'initialize') { + const newSessionId = randomUUID(); + const params = rpcReq.params as { protocolVersion: string }; + sessions.set(newSessionId, { protocolVersion: params.protocolVersion }); + + sendJson( + res, + { + jsonrpc: '2.0', + id: rpcReq.id, + result: { + protocolVersion: '2025-11-25', + capabilities: SERVER_CAPABILITIES, + serverInfo: { name: SERVER_NAME, version: SERVER_VERSION } + } + } satisfies JSONRPCResponse, + 200, + { 'mcp-session-id': newSessionId } + ); + return; + } + + // All other requests require a session + if (!sessionId || !sessions.has(sessionId)) { + sendJsonRpcError(res, rpcReq.id, -32_000, 'Missing or invalid session', 400); + return; + } + + if (rpcReq.method === 'tools/list') { + sendJson(res, { + jsonrpc: '2.0', + id: rpcReq.id, + result: { + tools: [ + { + name: TOOL_NAME, + description: 'Greet someone', + inputSchema: { + type: 'object', + properties: { name: { type: 'string' } }, + required: ['name'] + } + } + ] + } + } satisfies JSONRPCResponse); + return; + } + + if (rpcReq.method === 'tools/call') { + const args = rpcReq.params?.arguments as { name: string }; + sendJson(res, { + jsonrpc: '2.0', + id: rpcReq.id, + result: greetToolResult(args.name) + } satisfies JSONRPCResponse); + return; + } + + if (rpcReq.method === 'ping') { + sendJson(res, { + jsonrpc: '2.0', + id: rpcReq.id, + result: {} + } satisfies JSONRPCResponse); + return; + } + + sendJsonRpcError(res, rpcReq.id, -32_601, `Method not found: ${rpcReq.method}`); + } catch (error) { + console.error('Mock routing server error:', error); + if (!res.headersSent) { + res.writeHead(500); + res.end(); + } + } + }); +} + +// --------------------------------------------------------------------------- +// Mock "legacy-only" server (no server/discover support) +// --------------------------------------------------------------------------- + +/** + * Creates a mock HTTP server that ONLY supports the legacy (2025-11) path. + * A `server/discover` probe (identified by the `Mcp-Method` header) will + * receive a 404, causing the client to fall back to legacy mode. + */ +function createLegacyOnlyServer(): Server { + const sessions = new Map(); + + return createServer(async (req, res) => { + try { + // If the request has an Mcp-Method header, it's a modern probe -- + // this legacy-only server doesn't support it. + const mcpMethod = req.headers['mcp-method'] as string | undefined; + if (mcpMethod) { + res.writeHead(404); + res.end(); + return; + } + + if (req.method === 'GET') { + // SSE stream endpoint -- return 405 (optional per spec) + res.writeHead(405); + res.end(); + return; + } + + if (req.method === 'DELETE') { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) sessions.delete(sessionId); + res.writeHead(200); + res.end(); + return; + } + + if (req.method !== 'POST') { + res.writeHead(405, { Allow: 'POST, GET, DELETE' }); + res.end(); + return; + } + + const body = (await readJsonBody(req)) as JSONRPCMessage; + const sessionId = req.headers['mcp-session-id'] as string | undefined; + + // Notification + if (!('id' in body) || body.id === undefined) { + res.writeHead(202); + res.end(); + return; + } + + const rpcReq = body as JSONRPCRequest; + + if (rpcReq.method === 'initialize') { + const newSessionId = randomUUID(); + const params = rpcReq.params as { protocolVersion: string }; + sessions.set(newSessionId, { protocolVersion: params.protocolVersion }); + + const payload = JSON.stringify({ + jsonrpc: '2.0', + id: rpcReq.id, + result: { + protocolVersion: '2025-11-25', + capabilities: SERVER_CAPABILITIES, + serverInfo: { name: SERVER_NAME, version: SERVER_VERSION } + } + } satisfies JSONRPCResponse); + + res.writeHead(200, { + 'Content-Type': 'application/json', + 'Content-Length': Buffer.byteLength(payload), + 'mcp-session-id': newSessionId + }); + res.end(payload); + return; + } + + if (!sessionId || !sessions.has(sessionId)) { + sendJsonRpcError(res, rpcReq.id, -32_000, 'Missing or invalid session', 400); + return; + } + + if (rpcReq.method === 'tools/list') { + sendJson(res, { + jsonrpc: '2.0', + id: rpcReq.id, + result: { + tools: [ + { + name: TOOL_NAME, + description: 'Greet someone', + inputSchema: { + type: 'object', + properties: { name: { type: 'string' } }, + required: ['name'] + } + } + ] + } + } satisfies JSONRPCResponse); + return; + } + + if (rpcReq.method === 'tools/call') { + const args = rpcReq.params?.arguments as { name: string }; + sendJson(res, { + jsonrpc: '2.0', + id: rpcReq.id, + result: greetToolResult(args.name) + } satisfies JSONRPCResponse); + return; + } + + if (rpcReq.method === 'ping') { + sendJson(res, { + jsonrpc: '2.0', + id: rpcReq.id, + result: {} + } satisfies JSONRPCResponse); + return; + } + + sendJsonRpcError(res, rpcReq.id, -32_601, `Method not found: ${rpcReq.method}`); + } catch (error) { + console.error('Mock legacy server error:', error); + if (!res.headersSent) { + res.writeHead(500); + res.end(); + } + } + }); +} + +// --------------------------------------------------------------------------- +// Helper to close a server gracefully +// --------------------------------------------------------------------------- + +function closeServer(server: Server): Promise { + return new Promise((resolve, reject) => { + server.close(err => (err ? reject(err) : resolve())); + }); +} + +// =========================================================================== +// Tests +// =========================================================================== + +describe('StreamableHTTPClientTransport', () => { + // ----------------------------------------------------------------------- + // 1. Modern client + routing server + // ----------------------------------------------------------------------- + describe('modern client + routing server', () => { + let server: Server; + let baseUrl: URL; + + beforeAll(async () => { + server = createRoutingServer(); + baseUrl = await listenOnRandomPort(server); + }); + + afterAll(async () => { + await closeServer(server); + }); + + it('probes server/discover and enters modern mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + try { + await transport.start(); + + expect(transport.mode).toBe('modern'); + expect(transport.getDiscoverResult()).toBeDefined(); + expect(transport.getDiscoverResult()!.supportedVersions).toContain('2026-06-30'); + expect(transport.getDiscoverResult()!.serverInfo.name).toBe(SERVER_NAME); + } finally { + await transport.close(); + } + }); + + it('callTool works via Client in modern mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'test-client', version: '1.0.0' }); + try { + await transport.start(); + expect(transport.mode).toBe('modern'); + + await client.connect(transport); + + const result = await client.callTool({ name: TOOL_NAME, arguments: { name: 'World' } }); + expect(result.content).toEqual([{ type: 'text', text: 'Hello, World!' }]); + } finally { + await client.close(); + } + }); + + it('listTools works via Client in modern mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'test-client', version: '1.0.0' }); + try { + await transport.start(); + await client.connect(transport); + + const result = await client.listTools(); + expect(result.tools).toHaveLength(1); + expect(result.tools[0]!.name).toBe(TOOL_NAME); + } finally { + await client.close(); + } + }); + + it('getServerCapabilities returns capabilities from discover', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'test-client', version: '1.0.0' }); + try { + await transport.start(); + await client.connect(transport); + + const caps = client.getServerCapabilities(); + expect(caps).toBeDefined(); + expect(caps!.tools).toBeDefined(); + expect(caps!.resources).toBeDefined(); + expect(caps!.prompts).toBeDefined(); + } finally { + await client.close(); + } + }); + }); + + // ----------------------------------------------------------------------- + // 2. Modern client + legacy-only server + // ----------------------------------------------------------------------- + describe('modern client + legacy-only server', () => { + let server: Server; + let baseUrl: URL; + + beforeAll(async () => { + server = createLegacyOnlyServer(); + baseUrl = await listenOnRandomPort(server); + }); + + afterAll(async () => { + await closeServer(server); + }); + + it('probe fails gracefully and falls back to legacy mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + try { + await transport.start(); + + expect(transport.mode).toBe('legacy'); + expect(transport.getDiscoverResult()).toBeUndefined(); + } finally { + await transport.close(); + } + }); + + it('callTool works via Client in legacy fallback mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'test-client', version: '1.0.0' }); + try { + await transport.start(); + expect(transport.mode).toBe('legacy'); + + // connect() performs the initialize handshake in legacy mode + await client.connect(transport); + + const result = await client.callTool({ name: TOOL_NAME, arguments: { name: 'World' } }); + expect(result.content).toEqual([{ type: 'text', text: 'Hello, World!' }]); + } finally { + await client.close(); + } + }); + + it('listTools works via Client in legacy fallback mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'test-client', version: '1.0.0' }); + try { + await transport.start(); + await client.connect(transport); + + const result = await client.listTools(); + expect(result.tools).toHaveLength(1); + expect(result.tools[0]!.name).toBe(TOOL_NAME); + } finally { + await client.close(); + } + }); + }); + + // ----------------------------------------------------------------------- + // 3. Legacy client + routing server + // ----------------------------------------------------------------------- + describe('legacy client + routing server', () => { + let server: Server; + let baseUrl: URL; + + beforeAll(async () => { + server = createRoutingServer(); + baseUrl = await listenOnRandomPort(server); + }); + + afterAll(async () => { + await closeServer(server); + }); + + it('callTool works via plain LegacyStreamableHTTPClientTransport (no probe)', async () => { + const transport = new LegacyStreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'legacy-client', version: '1.0.0' }); + try { + // Plain LegacyStreamableHTTPClientTransport does not probe -- it goes + // straight to the initialize handshake which the routing server + // routes to the legacy path (no Mcp-Method header). + await client.connect(transport); + + const result = await client.callTool({ name: TOOL_NAME, arguments: { name: 'World' } }); + expect(result.content).toEqual([{ type: 'text', text: 'Hello, World!' }]); + } finally { + await client.close(); + } + }); + + it('listTools works via plain LegacyStreamableHTTPClientTransport', async () => { + const transport = new LegacyStreamableHTTPClientTransport(baseUrl); + const client = new Client({ name: 'legacy-client', version: '1.0.0' }); + try { + await client.connect(transport); + + const result = await client.listTools(); + expect(result.tools).toHaveLength(1); + expect(result.tools[0]!.name).toBe(TOOL_NAME); + } finally { + await client.close(); + } + }); + }); + + // ----------------------------------------------------------------------- + // 4. Legacy-only methods throw in modern mode + // ----------------------------------------------------------------------- + describe('legacy-only methods in modern mode', () => { + let server: Server; + let baseUrl: URL; + + beforeAll(async () => { + server = createRoutingServer(); + baseUrl = await listenOnRandomPort(server); + }); + + afterAll(async () => { + await closeServer(server); + }); + + it('terminateSession() throws in modern mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + await transport.start(); + expect(transport.mode).toBe('modern'); + + try { + await expect(transport.terminateSession()).rejects.toThrow('terminateSession() is not available in modern protocol mode'); + } finally { + await transport.close(); + } + }); + + it('resumeStream() throws in modern mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + await transport.start(); + expect(transport.mode).toBe('modern'); + + try { + await expect(transport.resumeStream('some-event-id')).rejects.toThrow( + 'resumeStream() is not available in modern protocol mode' + ); + } finally { + await transport.close(); + } + }); + + it('sessionId returns undefined in modern mode', async () => { + const transport = new StreamableHTTPClientTransport(baseUrl); + await transport.start(); + expect(transport.mode).toBe('modern'); + + try { + expect(transport.sessionId).toBeUndefined(); + } finally { + await transport.close(); + } + }); + + it('terminateSession() works in legacy mode', async () => { + const legacyServer = createLegacyOnlyServer(); + const legacyUrl = await listenOnRandomPort(legacyServer); + + const transport = new StreamableHTTPClientTransport(legacyUrl); + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + try { + await transport.start(); + expect(transport.mode).toBe('legacy'); + await client.connect(transport); + + // Should not throw in legacy mode + await expect(transport.terminateSession()).resolves.not.toThrow(); + } finally { + await client.close(); + await closeServer(legacyServer); + } + }); + }); + + // ----------------------------------------------------------------------- + // 5. Content equivalence across all 3 combinations + // ----------------------------------------------------------------------- + describe('content equivalence', () => { + let routingServer: Server; + let legacyServer: Server; + let routingUrl: URL; + let legacyUrl: URL; + + beforeAll(async () => { + routingServer = createRoutingServer(); + legacyServer = createLegacyOnlyServer(); + [routingUrl, legacyUrl] = await Promise.all([listenOnRandomPort(routingServer), listenOnRandomPort(legacyServer)]); + }); + + afterAll(async () => { + await Promise.all([closeServer(routingServer), closeServer(legacyServer)]); + }); + + it('same tool call returns identical content across all 3 combinations', async () => { + const toolArgs = { name: TOOL_NAME, arguments: { name: 'Alice' } }; + + // -- Combination 1: Modern client + routing server -- + const modernTransport = new StreamableHTTPClientTransport(routingUrl); + const modernClient = new Client({ name: 'modern-client', version: '1.0.0' }); + await modernTransport.start(); + expect(modernTransport.mode).toBe('modern'); + await modernClient.connect(modernTransport); + const modernResult = await modernClient.callTool(toolArgs); + + // -- Combination 2: Modern client + legacy server (fallback) -- + const fallbackTransport = new StreamableHTTPClientTransport(legacyUrl); + const fallbackClient = new Client({ name: 'fallback-client', version: '1.0.0' }); + await fallbackTransport.start(); + expect(fallbackTransport.mode).toBe('legacy'); + await fallbackClient.connect(fallbackTransport); + const fallbackResult = await fallbackClient.callTool(toolArgs); + + // -- Combination 3: Legacy client + routing server -- + const legacyTransport = new LegacyStreamableHTTPClientTransport(routingUrl); + const legacyClient = new Client({ name: 'legacy-client', version: '1.0.0' }); + await legacyClient.connect(legacyTransport); + const legacyResult = await legacyClient.callTool(toolArgs); + + // All three should return identical content + const expectedContent = [{ type: 'text', text: 'Hello, Alice!' }]; + expect(modernResult.content).toEqual(expectedContent); + expect(fallbackResult.content).toEqual(expectedContent); + expect(legacyResult.content).toEqual(expectedContent); + + // Cross-check: they match each other + expect(modernResult.content).toEqual(fallbackResult.content); + expect(modernResult.content).toEqual(legacyResult.content); + + // Cleanup + await Promise.all([modernClient.close(), fallbackClient.close(), legacyClient.close()]); + }); + }); +}); diff --git a/packages/core/src/errors/sdkErrors.ts b/packages/core/src/errors/sdkErrors.ts index 8d5e34c14e..bf47bff74b 100644 --- a/packages/core/src/errors/sdkErrors.ts +++ b/packages/core/src/errors/sdkErrors.ts @@ -18,6 +18,8 @@ export enum SdkErrorCode { // Capability errors /** Required capability is not supported by the remote side */ CapabilityNotSupported = 'CAPABILITY_NOT_SUPPORTED', + /** Operation is not supported on the current protocol path */ + UnsupportedOperation = 'UNSUPPORTED_OPERATION', // Transport errors /** Request timed out waiting for response */ diff --git a/packages/core/src/experimental/index.ts b/packages/core/src/experimental/index.ts deleted file mode 100644 index ea39eb79f6..0000000000 --- a/packages/core/src/experimental/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export * from './tasks/helpers.js'; -export * from './tasks/interfaces.js'; -export * from './tasks/stores/inMemory.js'; diff --git a/packages/core/src/experimental/tasks/helpers.ts b/packages/core/src/experimental/tasks/helpers.ts deleted file mode 100644 index 7a13fffbd3..0000000000 --- a/packages/core/src/experimental/tasks/helpers.ts +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Experimental task capability assertion helpers. - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - -import { SdkError, SdkErrorCode } from '../../errors/sdkErrors.js'; - -/** - * Type representing the task requests capability structure. - * This is derived from `ClientTasksCapability.requests` and `ServerTasksCapability.requests`. - */ -interface TaskRequestsCapability { - tools?: { call?: object }; - sampling?: { createMessage?: object }; - elicitation?: { create?: object }; -} - -/** - * Asserts that task creation is supported for `tools/call`. - * Used to implement the `assertTaskCapability` or `assertTaskHandlerCapability` abstract methods on Protocol. - * - * @param requests - The task requests capability object - * @param method - The method being checked - * @param entityName - `'Server'` or `'Client'` for error messages - * @throws {@linkcode SdkError} with {@linkcode SdkErrorCode.CapabilityNotSupported} if the capability is not supported - * - * @experimental - */ -export function assertToolsCallTaskCapability( - requests: TaskRequestsCapability | undefined, - method: string, - entityName: 'Server' | 'Client' -): void { - if (!requests) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `${entityName} does not support task creation (required for ${method})`); - } - - switch (method) { - case 'tools/call': { - if (!requests.tools?.call) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `${entityName} does not support task creation for tools/call (required for ${method})` - ); - } - break; - } - - default: { - // Method doesn't support tasks, which is fine - no error - break; - } - } -} - -/** - * Asserts that task creation is supported for `sampling/createMessage` or `elicitation/create`. - * Used to implement the `assertTaskCapability` or `assertTaskHandlerCapability` abstract methods on Protocol. - * - * @param requests - The task requests capability object - * @param method - The method being checked - * @param entityName - `'Server'` or `'Client'` for error messages - * @throws {@linkcode SdkError} with {@linkcode SdkErrorCode.CapabilityNotSupported} if the capability is not supported - * - * @experimental - */ -export function assertClientRequestTaskCapability( - requests: TaskRequestsCapability | undefined, - method: string, - entityName: 'Server' | 'Client' -): void { - if (!requests) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `${entityName} does not support task creation (required for ${method})`); - } - - switch (method) { - case 'sampling/createMessage': { - if (!requests.sampling?.createMessage) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `${entityName} does not support task creation for sampling/createMessage (required for ${method})` - ); - } - break; - } - - case 'elicitation/create': { - if (!requests.elicitation?.create) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `${entityName} does not support task creation for elicitation/create (required for ${method})` - ); - } - break; - } - - default: { - // Method doesn't support tasks, which is fine - no error - break; - } - } -} diff --git a/packages/core/src/experimental/tasks/interfaces.ts b/packages/core/src/experimental/tasks/interfaces.ts deleted file mode 100644 index d980f304ca..0000000000 --- a/packages/core/src/experimental/tasks/interfaces.ts +++ /dev/null @@ -1,243 +0,0 @@ -/** - * Experimental task interfaces for MCP SDK. - * WARNING: These APIs are experimental and may change without notice. - */ - -import type { ServerContext } from '../../shared/protocol.js'; -import type { RequestTaskStore } from '../../shared/taskManager.js'; -import type { - JSONRPCErrorResponse, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResultResponse, - Request, - RequestId, - Result, - Task, - ToolExecution -} from '../../types/index.js'; - -// ============================================================================ -// Task Handler Types (for registerToolTask) -// ============================================================================ - -/** - * Server context with guaranteed task store for task creation. - * @experimental - */ -export type CreateTaskServerContext = ServerContext & { - task: { store: RequestTaskStore; requestedTtl?: number }; -}; - -/** - * Server context with guaranteed task ID and store for task operations. - * @experimental - */ -export type TaskServerContext = ServerContext & { - task: { id: string; store: RequestTaskStore; requestedTtl?: number }; -}; - -/** - * Task-specific execution configuration. - * `taskSupport` cannot be `'forbidden'` for task-based tools. - * @experimental - */ -export type TaskToolExecution = Omit & { - taskSupport: TaskSupport extends 'forbidden' | undefined ? never : TaskSupport; -}; - -/** - * Represents a message queued for side-channel delivery via tasks/result. - * - * This is a serializable data structure that can be stored in external systems. - * All fields are JSON-serializable. - */ -export type QueuedMessage = QueuedRequest | QueuedNotification | QueuedResponse | QueuedError; - -export interface BaseQueuedMessage { - /** Type of message */ - type: string; - /** When the message was queued (milliseconds since epoch) */ - timestamp: number; -} - -export interface QueuedRequest extends BaseQueuedMessage { - type: 'request'; - /** The actual JSONRPC request */ - message: JSONRPCRequest; -} - -export interface QueuedNotification extends BaseQueuedMessage { - type: 'notification'; - /** The actual JSONRPC notification */ - message: JSONRPCNotification; -} - -export interface QueuedResponse extends BaseQueuedMessage { - type: 'response'; - /** The actual JSONRPC response */ - message: JSONRPCResultResponse; -} - -export interface QueuedError extends BaseQueuedMessage { - type: 'error'; - /** The actual JSONRPC error */ - message: JSONRPCErrorResponse; -} - -/** - * Interface for managing per-task FIFO message queues. - * - * Similar to {@linkcode TaskStore}, this allows pluggable queue implementations - * (in-memory, Redis, other distributed queues, etc.). - * - * Each method accepts taskId and optional sessionId parameters to enable - * a single queue instance to manage messages for multiple tasks, with - * isolation based on task ID and session ID. - * - * All methods are async to support external storage implementations. - * All data in {@linkcode QueuedMessage} must be JSON-serializable. - * - * @see {@linkcode InMemoryTaskMessageQueue} for a reference implementation - * @experimental - */ -export interface TaskMessageQueue { - /** - * Adds a message to the end of the queue for a specific task. - * Atomically checks queue size and throws if maxSize would be exceeded. - * @param taskId The task identifier - * @param message The message to enqueue - * @param sessionId Optional session ID for binding the operation to a specific session - * @param maxSize Optional maximum queue size - if specified and queue is full, throws an error - * @throws Error if maxSize is specified and would be exceeded - */ - enqueue(taskId: string, message: QueuedMessage, sessionId?: string, maxSize?: number): Promise; - - /** - * Removes and returns the first message from the queue for a specific task. - * @param taskId The task identifier - * @param sessionId Optional session ID for binding the query to a specific session - * @returns The first message, or `undefined` if the queue is empty - */ - dequeue(taskId: string, sessionId?: string): Promise; - - /** - * Removes and returns all messages from the queue for a specific task. - * Used when tasks are cancelled or failed to clean up pending messages. - * @param taskId The task identifier - * @param sessionId Optional session ID for binding the query to a specific session - * @returns Array of all messages that were in the queue - */ - dequeueAll(taskId: string, sessionId?: string): Promise; -} - -/** - * Task creation options. - * @experimental - */ -export interface CreateTaskOptions { - /** - * Duration in milliseconds to retain task from creation. - * If `null`, the task has unlimited lifetime until manually cleaned up. - */ - ttl?: number | null; - - /** - * Time in milliseconds to wait between task status requests. - */ - pollInterval?: number; - - /** - * Additional context to pass to the task store. - */ - context?: Record; -} - -/** - * Interface for storing and retrieving task state and results. - * - * Similar to {@linkcode Transport}, this allows pluggable task storage implementations - * (in-memory, database, distributed cache, etc.). - * - * @see {@linkcode InMemoryTaskStore} for a reference implementation - * @experimental - */ -export interface TaskStore { - /** - * Creates a new task with the given creation parameters and original request. - * The implementation must generate a unique taskId and createdAt timestamp. - * - * TTL Management: - * - The implementation receives the TTL suggested by the requestor via `taskParams.ttl` - * - The implementation MAY override the requested TTL (e.g., to enforce limits) - * - The actual TTL used MUST be returned in the {@linkcode Task} object - * - `null` TTL indicates unlimited task lifetime (no automatic cleanup) - * - Cleanup SHOULD occur automatically after TTL expires, regardless of task status - * - * @param taskParams - The task creation parameters from the request (ttl, pollInterval) - * @param requestId - The JSON-RPC request ID - * @param request - The original request that triggered task creation - * @param sessionId - Optional session ID for binding the task to a specific session - * @returns The created {@linkcode Task} object - */ - createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request, sessionId?: string): Promise; - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @param sessionId - Optional session ID for binding the query to a specific session - * @returns The {@linkcode Task} object, or `null` if it does not exist - */ - getTask(taskId: string, sessionId?: string): Promise; - - /** - * Stores the result of a task and sets its final status. - * - * @param taskId - The task identifier - * @param status - The final status: `'completed'` for success, `'failed'` for errors - * @param result - The result to store - * @param sessionId - Optional session ID for binding the operation to a specific session - */ - storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, sessionId?: string): Promise; - - /** - * Retrieves the stored result of a task. - * - * @param taskId - The task identifier - * @param sessionId - Optional session ID for binding the query to a specific session - * @returns The stored result - */ - getTaskResult(taskId: string, sessionId?: string): Promise; - - /** - * Updates a task's status (e.g., to `'cancelled'`, `'failed'`, `'completed'`). - * - * @param taskId - The task identifier - * @param status - The new status - * @param statusMessage - Optional diagnostic message for failed tasks or other status information - * @param sessionId - Optional session ID for binding the operation to a specific session - */ - updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, sessionId?: string): Promise; - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @param cursor - Optional cursor for pagination - * @param sessionId - Optional session ID for binding the query to a specific session - * @returns An object containing the tasks array and an optional nextCursor - */ - listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; -} - -/** - * Checks if a task status represents a terminal state. - * Terminal states are those where the task has finished and will not change. - * - * @param status - The task status to check - * @returns `true` if the status is terminal (`completed`, `failed`, or `cancelled`) - * @experimental - */ -export function isTerminal(status: Task['status']): boolean { - return status === 'completed' || status === 'failed' || status === 'cancelled'; -} diff --git a/packages/core/src/experimental/tasks/stores/inMemory.ts b/packages/core/src/experimental/tasks/stores/inMemory.ts deleted file mode 100644 index fbd7e39f53..0000000000 --- a/packages/core/src/experimental/tasks/stores/inMemory.ts +++ /dev/null @@ -1,313 +0,0 @@ -/** - * In-memory implementations of {@linkcode TaskStore} and {@linkcode TaskMessageQueue}. - * @experimental - */ - -import type { Request, RequestId, Result, Task } from '../../../types/index.js'; -import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../interfaces.js'; -import { isTerminal } from '../interfaces.js'; - -interface StoredTask { - task: Task; - request: Request; - requestId: RequestId; - sessionId?: string; - result?: Result; -} - -/** - * In-memory {@linkcode TaskStore} implementation for development and testing. - * For production, use a database or distributed cache. - * @experimental - */ -export class InMemoryTaskStore implements TaskStore { - private tasks = new Map(); - private cleanupTimers = new Map>(); - - /** - * Generates a unique task ID using Web Crypto API. - */ - private generateTaskId(): string { - return crypto.randomUUID().replaceAll('-', ''); - } - - /** {@inheritDoc TaskStore.createTask} */ - async createTask(taskParams: CreateTaskOptions, requestId: RequestId, request: Request, sessionId?: string): Promise { - // Generate a unique task ID - const taskId = this.generateTaskId(); - - // Ensure uniqueness - if (this.tasks.has(taskId)) { - throw new Error(`Task with ID ${taskId} already exists`); - } - - const actualTtl = taskParams.ttl ?? null; - - // Create task with generated ID and timestamps - const createdAt = new Date().toISOString(); - const task: Task = { - taskId, - status: 'working', - ttl: actualTtl, - createdAt, - lastUpdatedAt: createdAt, - pollInterval: taskParams.pollInterval ?? 1000 - }; - - this.tasks.set(taskId, { - task, - request, - requestId, - sessionId - }); - - // Schedule cleanup if ttl is specified - // Cleanup occurs regardless of task status - if (actualTtl) { - const timer = setTimeout(() => { - this.tasks.delete(taskId); - this.cleanupTimers.delete(taskId); - }, actualTtl); - - this.cleanupTimers.set(taskId, timer); - } - - return task; - } - - /** - * Retrieves a stored task, enforcing session ownership when a sessionId is provided. - * Returns undefined if the task does not exist or belongs to a different session. - */ - private getStoredTask(taskId: string, sessionId?: string): StoredTask | undefined { - const stored = this.tasks.get(taskId); - if (!stored) { - return undefined; - } - // Enforce session isolation: if a sessionId is provided and the task - // was created with a sessionId, they must match. - if (sessionId !== undefined && stored.sessionId !== undefined && stored.sessionId !== sessionId) { - return undefined; - } - return stored; - } - - async getTask(taskId: string, sessionId?: string): Promise { - const stored = this.getStoredTask(taskId, sessionId); - return stored ? { ...stored.task } : null; - } - - /** {@inheritDoc TaskStore.storeTaskResult} */ - async storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result, sessionId?: string): Promise { - const stored = this.getStoredTask(taskId, sessionId); - if (!stored) { - throw new Error(`Task with ID ${taskId} not found`); - } - - // Don't allow storing results for tasks already in terminal state - if (isTerminal(stored.task.status)) { - throw new Error( - `Cannot store result for task ${taskId} in terminal status '${stored.task.status}'. Task results can only be stored once.` - ); - } - - stored.result = result; - stored.task.status = status; - stored.task.lastUpdatedAt = new Date().toISOString(); - - // Reset cleanup timer to start from now (if ttl is set) - if (stored.task.ttl) { - const existingTimer = this.cleanupTimers.get(taskId); - if (existingTimer) { - clearTimeout(existingTimer); - } - - const timer = setTimeout(() => { - this.tasks.delete(taskId); - this.cleanupTimers.delete(taskId); - }, stored.task.ttl); - - this.cleanupTimers.set(taskId, timer); - } - } - - /** {@inheritDoc TaskStore.getTaskResult} */ - async getTaskResult(taskId: string, sessionId?: string): Promise { - const stored = this.getStoredTask(taskId, sessionId); - if (!stored) { - throw new Error(`Task with ID ${taskId} not found`); - } - - if (!stored.result) { - throw new Error(`Task ${taskId} has no result stored`); - } - - return stored.result; - } - - /** {@inheritDoc TaskStore.updateTaskStatus} */ - async updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string, sessionId?: string): Promise { - const stored = this.getStoredTask(taskId, sessionId); - if (!stored) { - throw new Error(`Task with ID ${taskId} not found`); - } - - // Don't allow transitions from terminal states - if (isTerminal(stored.task.status)) { - throw new Error( - `Cannot update task ${taskId} from terminal status '${stored.task.status}' to '${status}'. Terminal states (completed, failed, cancelled) cannot transition to other states.` - ); - } - - stored.task.status = status; - if (statusMessage) { - stored.task.statusMessage = statusMessage; - } - - stored.task.lastUpdatedAt = new Date().toISOString(); - - // If task is in a terminal state and has ttl, start cleanup timer - if (isTerminal(status) && stored.task.ttl) { - const existingTimer = this.cleanupTimers.get(taskId); - if (existingTimer) { - clearTimeout(existingTimer); - } - - const timer = setTimeout(() => { - this.tasks.delete(taskId); - this.cleanupTimers.delete(taskId); - }, stored.task.ttl); - - this.cleanupTimers.set(taskId, timer); - } - } - - /** {@inheritDoc TaskStore.listTasks} */ - async listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { - const PAGE_SIZE = 10; - - // Filter tasks by session ownership before pagination - const filteredTaskIds = [...this.tasks.entries()] - .filter(([, stored]) => { - if (sessionId === undefined || stored.sessionId === undefined) { - return true; - } - return stored.sessionId === sessionId; - }) - .map(([taskId]) => taskId); - - let startIndex = 0; - if (cursor) { - const cursorIndex = filteredTaskIds.indexOf(cursor); - if (cursorIndex === -1) { - // Invalid cursor - throw error - throw new Error(`Invalid cursor: ${cursor}`); - } else { - startIndex = cursorIndex + 1; - } - } - - const pageTaskIds = filteredTaskIds.slice(startIndex, startIndex + PAGE_SIZE); - const tasks = pageTaskIds.map(taskId => { - const stored = this.tasks.get(taskId)!; - return { ...stored.task }; - }); - - const nextCursor = startIndex + PAGE_SIZE < filteredTaskIds.length ? pageTaskIds.at(-1) : undefined; - - return { tasks, nextCursor }; - } - - /** - * Cleanup all timers (useful for testing or graceful shutdown) - */ - cleanup(): void { - for (const timer of this.cleanupTimers.values()) { - clearTimeout(timer); - } - this.cleanupTimers.clear(); - this.tasks.clear(); - } - - /** - * Get all tasks (useful for debugging) - */ - getAllTasks(): Task[] { - return [...this.tasks.values()].map(stored => ({ ...stored.task })); - } -} - -/** - * In-memory {@linkcode TaskMessageQueue} implementation for development and testing. - * For production, use Redis or another distributed queue. - * @experimental - */ -export class InMemoryTaskMessageQueue implements TaskMessageQueue { - private queues = new Map(); - - /** - * Generates a queue key from taskId. - * SessionId is intentionally ignored because taskIds are globally unique - * and tasks need to be accessible across HTTP requests/sessions. - */ - private getQueueKey(taskId: string, _sessionId?: string): string { - return taskId; - } - - /** - * Gets or creates a queue for the given task and session. - */ - private getQueue(taskId: string, sessionId?: string): QueuedMessage[] { - const key = this.getQueueKey(taskId, sessionId); - let queue = this.queues.get(key); - if (!queue) { - queue = []; - this.queues.set(key, queue); - } - return queue; - } - - /** - * Adds a message to the end of the queue for a specific task. - * Atomically checks queue size and throws if maxSize would be exceeded. - * @param taskId The task identifier - * @param message The message to enqueue - * @param sessionId Optional session ID for binding the operation to a specific session - * @param maxSize Optional maximum queue size - if specified and queue is full, throws an error - * @throws Error if maxSize is specified and would be exceeded - */ - async enqueue(taskId: string, message: QueuedMessage, sessionId?: string, maxSize?: number): Promise { - const queue = this.getQueue(taskId, sessionId); - - // Atomically check size and enqueue - if (maxSize !== undefined && queue.length >= maxSize) { - throw new Error(`Task message queue overflow: queue size (${queue.length}) exceeds maximum (${maxSize})`); - } - - queue.push(message); - } - - /** - * Removes and returns the first message from the queue for a specific task. - * @param taskId The task identifier - * @param sessionId Optional session ID for binding the query to a specific session - * @returns The first message, or `undefined` if the queue is empty - */ - async dequeue(taskId: string, sessionId?: string): Promise { - const queue = this.getQueue(taskId, sessionId); - return queue.shift(); - } - - /** - * Removes and returns all messages from the queue for a specific task. - * @param taskId The task identifier - * @param sessionId Optional session ID for binding the query to a specific session - * @returns Array of all messages that were in the queue - */ - async dequeueAll(taskId: string, sessionId?: string): Promise { - const key = this.getQueueKey(taskId, sessionId); - const queue = this.queues.get(key) ?? []; - this.queues.delete(key); - return queue; - } -} diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index e305f32a44..43faa35ccc 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -45,24 +45,16 @@ export type { NotificationOptions, ProgressCallback, ProtocolOptions, - RequestHandlerSchemas, RequestOptions, ServerContext } from '../../shared/protocol.js'; export { DEFAULT_REQUEST_TIMEOUT_MSEC } from '../../shared/protocol.js'; -// Task manager types (NOT TaskManager class itself — internal) -export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from '../../shared/taskManager.js'; +// Handler registry types +export type { RequestHandlerSchemas } from '../../shared/handlerRegistry.js'; // Response message types -export type { - BaseResponseMessage, - ErrorMessage, - ResponseMessage, - ResultMessage, - TaskCreatedMessage, - TaskStatusMessage -} from '../../shared/responseMessage.js'; +export type { BaseResponseMessage, ErrorMessage, ResponseMessage, ResultMessage } from '../../shared/responseMessage.js'; export { takeResult, toArrayAsync } from '../../shared/responseMessage.js'; // stdio message framing utilities (for custom transport authors) @@ -92,7 +84,6 @@ export { LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, PARSE_ERROR, - RELATED_TASK_META_KEY, SUPPORTED_PROTOCOL_VERSIONS } from '../../types/constants.js'; @@ -114,29 +105,9 @@ export { isJSONRPCRequest, isJSONRPCResponse, isJSONRPCResultResponse, - isTaskAugmentedRequestParams, parseJSONRPCMessage } from '../../types/guards.js'; -// Experimental task types and classes -export { assertClientRequestTaskCapability, assertToolsCallTaskCapability } from '../../experimental/tasks/helpers.js'; -export type { - BaseQueuedMessage, - CreateTaskOptions, - CreateTaskServerContext, - QueuedError, - QueuedMessage, - QueuedNotification, - QueuedRequest, - QueuedResponse, - TaskMessageQueue, - TaskServerContext, - TaskStore, - TaskToolExecution -} from '../../experimental/tasks/interfaces.js'; -export { isTerminal } from '../../experimental/tasks/interfaces.js'; -export { InMemoryTaskMessageQueue, InMemoryTaskStore } from '../../experimental/tasks/stores/inMemory.js'; - // Validator types and classes export type { SpecTypeName, SpecTypes } from '../../types/specTypeSchema.js'; export { isSpecType, specTypeSchemas } from '../../types/specTypeSchema.js'; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 8bcc9c9591..30c8f799d6 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -2,12 +2,11 @@ export * from './auth/errors.js'; export * from './errors/sdkErrors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; +export * from './shared/handlerRegistry.js'; export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; -export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from './shared/taskManager.js'; -export { extractTaskManagerOptions, NullTaskManager, TaskManager } from './shared/taskManager.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; export * from './shared/uriTemplate.js'; @@ -16,9 +15,6 @@ export * from './util/inMemory.js'; export * from './util/schema.js'; export * from './util/standardSchema.js'; export * from './util/zodCompat.js'; - -// experimental exports -export * from './experimental/index.js'; export * from './validators/ajvProvider.js'; // cfWorkerProvider is intentionally NOT re-exported here: it statically imports // `@cfworker/json-schema` (an optional peer), and bundling it into the main barrel diff --git a/packages/core/src/shared/handlerRegistry.ts b/packages/core/src/shared/handlerRegistry.ts new file mode 100644 index 0000000000..1a169b00b1 --- /dev/null +++ b/packages/core/src/shared/handlerRegistry.ts @@ -0,0 +1,302 @@ +import type { + ClientCapabilities, + JSONRPCNotification, + JSONRPCRequest, + Notification, + NotificationMethod, + NotificationTypeMap, + RequestMethod, + RequestTypeMap, + Result, + ResultTypeMap, + ServerCapabilities +} from '../types/index.js'; +import { getNotificationSchema, getRequestSchema, ProtocolError, ProtocolErrorCode } from '../types/index.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import { validateStandardSchema } from '../util/standardSchema.js'; +import type { BaseContext } from './protocol.js'; + +/** + * A function that handles an incoming JSON-RPC request and returns a result. + */ +export type RequestHandler = (request: JSONRPCRequest, ctx: ContextT) => Promise; + +/** + * A function that handles an incoming JSON-RPC notification. + */ +export type NotificationHandler = (notification: JSONRPCNotification) => Promise; + +/** + * Schema bundle accepted by `setRequestHandler`'s 3-arg form. + * + * `params` is required and validates the inbound `request.params`. `result` is optional; + * when supplied it types the handler's return value (no runtime validation is performed + * on the result). + */ +export interface RequestHandlerSchemas< + P extends StandardSchemaV1 = StandardSchemaV1, + R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined +> { + params: P; + result?: R; +} + +/** + * Infers the handler return type from an optional result schema. + * When `R` is a `StandardSchemaV1`, the return type is the schema's output type. + * When `R` is `undefined`, the return type falls back to the generic `Result`. + */ +export type InferHandlerResult = R extends StandardSchemaV1 + ? StandardSchemaV1.InferOutput + : Result; + +/** + * Options for constructing a {@linkcode HandlerRegistry}. + */ +export interface HandlerRegistryOptions { + /** + * Initial capabilities. These are shallow-merged with any capabilities + * registered later via {@linkcode HandlerRegistry.registerCapabilities | registerCapabilities()}. + */ + capabilities?: Caps; + + /** + * Optional callback invoked during `setRequestHandler()` + * to assert that registering a handler for this method is valid given the + * declared capabilities. For example, a server may reject handler registration + * for `tools/call` unless `capabilities.tools` is declared. + */ + assertRequestHandlerCapability?: (method: string) => void; + + /** + * Optional callback that wraps every registered request handler with + * role-specific validation or behavior (e.g., `Server` validates `tools/call` + * results). The default behavior is identity (no wrapping). + */ + wrapHandler?: (method: string, handler: RequestHandler) => RequestHandler; +} + +/** + * Owns handler maps, schema parsing, and capability management. + * + * `HandlerRegistry` is a standalone class extracted from `Protocol` so that + * multiple protocol instances (or routers) can share or compose handler sets + * without being coupled to transport or connection lifecycle. + */ +export class HandlerRegistry { + private _requestHandlers: Map> = new Map(); + private _notificationHandlers: Map = new Map(); + private _capabilities: Caps; + assertRequestHandlerCapability?: (method: string) => void; + wrapHandler?: (method: string, handler: RequestHandler) => RequestHandler; + + /** + * A handler to invoke for any request types that do not have their own handler installed. + */ + fallbackRequestHandler?: RequestHandler; + + /** + * A handler to invoke for any notification types that do not have their own handler installed. + */ + fallbackNotificationHandler?: (notification: Notification) => Promise; + + constructor(options?: HandlerRegistryOptions) { + this._capabilities = (options?.capabilities ?? {}) as Caps; + this.assertRequestHandlerCapability = options?.assertRequestHandlerCapability; + this.wrapHandler = options?.wrapHandler; + } + + /** + * Read-only view of the registered request handlers. + */ + get requestHandlers(): ReadonlyMap> { + return this._requestHandlers; + } + + /** + * Read-only view of the registered notification handlers. + */ + get notificationHandlers(): ReadonlyMap { + return this._notificationHandlers; + } + + // ----------------------------------------------------------------------- + // Capabilities + // ----------------------------------------------------------------------- + + /** + * Merges additional capabilities into the existing capability set. + */ + registerCapabilities(caps: Partial): void { + this._capabilities = mergeCapabilities(this._capabilities, caps) as Caps; + } + + /** + * Returns the current capability set. + */ + getCapabilities(): Caps { + return this._capabilities; + } + + // ----------------------------------------------------------------------- + // Request handler registration + // ----------------------------------------------------------------------- + + /** + * Registers a handler to invoke when a request with the given method is received. + * + * Note that this will replace any previous request handler for the same method. + * + * For spec methods, pass `(method, handler)`; the request is parsed with the spec + * schema and the handler receives the typed `Request`. For custom (non-spec) + * methods, pass `(method, schemas, handler)`; `params` are validated against + * `schemas.params` and the handler receives the parsed params object directly. + * Supplying `schemas.result` types the handler's return value. + */ + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise + ): void; + setRequestHandler

( + method: string, + schemas: { params: P; result?: R }, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ContextT) => InferHandlerResult | Promise> + ): void; + setRequestHandler( + method: string, + schemasOrHandler: RequestHandlerSchemas | ((request: unknown, ctx: ContextT) => Result | Promise), + maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise + ): void { + this.assertRequestHandlerCapability?.(method); + + let stored: RequestHandler; + + if (typeof schemasOrHandler === 'function') { + const schema = getRequestSchema(method); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` + ); + } + stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx)); + } else if (maybeHandler) { + stored = async (request, ctx) => { + const userParams = { ...request.params }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); + } + return maybeHandler(parsed.data, ctx); + }; + } else { + throw new TypeError('setRequestHandler: handler is required'); + } + + const wrapped = this.wrapHandler ? this.wrapHandler(method, stored) : stored; + this._requestHandlers.set(method, wrapped); + } + + /** + * Removes the request handler for the given method. + */ + removeRequestHandler(method: RequestMethod | string): void { + this._requestHandlers.delete(method); + } + + /** + * Asserts that a request handler has not already been set for the given method, + * in preparation for a new one being automatically installed. + */ + assertCanSetRequestHandler(method: RequestMethod | string): void { + if (this._requestHandlers.has(method)) { + throw new Error(`A request handler for ${method} already exists, which would be overridden`); + } + } + + // ----------------------------------------------------------------------- + // Notification handler registration + // ----------------------------------------------------------------------- + + /** + * Registers a handler to invoke when a notification with the given method is received. + * + * Note that this will replace any previous notification handler for the same method. + * + * For spec methods, pass `(method, handler)`; the notification is parsed with the + * spec schema. For custom (non-spec) methods, pass `(method, schemas, handler)`; + * `params` are validated against `schemas.params` and the handler receives the + * parsed params object directly. The raw notification is passed as the second + * argument; `_meta` is recoverable via `notification.params?._meta`. + */ + setNotificationHandler( + method: M, + handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void; + setNotificationHandler

( + method: string, + schemas: { params: P }, + handler: (params: StandardSchemaV1.InferOutput

, notification: Notification) => void | Promise + ): void; + setNotificationHandler( + method: string, + schemasOrHandler: { params: StandardSchemaV1 } | ((notification: unknown) => void | Promise), + maybeHandler?: (params: unknown, notification: Notification) => void | Promise + ): void { + if (typeof schemasOrHandler === 'function') { + const schema = getNotificationSchema(method); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().` + ); + } + this._notificationHandlers.set(method, notification => Promise.resolve(schemasOrHandler(schema.parse(notification)))); + return; + } + + if (!maybeHandler) { + throw new TypeError('setNotificationHandler: handler is required'); + } + this._notificationHandlers.set(method, async notification => { + const userParams = { ...notification.params }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for notification ${method}: ${parsed.error}`); + } + await maybeHandler(parsed.data, notification); + }); + } + + /** + * Removes the notification handler for the given method. + */ + removeNotificationHandler(method: NotificationMethod | string): void { + this._notificationHandlers.delete(method); + } +} + +// --------------------------------------------------------------------------- +// Capability merging helpers +// --------------------------------------------------------------------------- + +function isPlainObject(value: unknown): value is Record { + return value !== null && typeof value === 'object' && !Array.isArray(value); +} + +export function mergeCapabilities(base: ServerCapabilities, additional: Partial): ServerCapabilities; +export function mergeCapabilities(base: ClientCapabilities, additional: Partial): ClientCapabilities; +export function mergeCapabilities(base: T, additional: Partial): T { + const result: T = { ...base }; + for (const key in additional) { + const k = key as keyof T; + const addValue = additional[k]; + if (addValue === undefined) continue; + const baseValue = result[k]; + result[k] = + isPlainObject(baseValue) && isPlainObject(addValue) + ? ({ ...(baseValue as Record), ...(addValue as Record) } as T[typeof k]) + : (addValue as T[typeof k]); + } + return result; +} diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 361bd6fc7c..baaa948080 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -2,7 +2,6 @@ import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; import type { AuthInfo, CancelledNotification, - ClientCapabilities, CreateMessageRequest, CreateMessageResult, CreateMessageResultWithTools, @@ -21,20 +20,15 @@ import type { NotificationTypeMap, Progress, ProgressNotification, - RelatedTaskMetadata, Request, RequestId, RequestMeta, RequestMethod, RequestTypeMap, Result, - ResultTypeMap, - ServerCapabilities, - TaskCreationParams + ResultTypeMap } from '../types/index.js'; import { - getNotificationSchema, - getRequestSchema, getResultSchema, isJSONRPCErrorResponse, isJSONRPCNotification, @@ -46,8 +40,7 @@ import { } from '../types/index.js'; import type { StandardSchemaV1 } from '../util/standardSchema.js'; import { isStandardSchema, validateStandardSchema } from '../util/standardSchema.js'; -import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; -import { NullTaskManager, TaskManager } from './taskManager.js'; +import type { HandlerRegistry, InferHandlerResult, RequestHandlerSchemas } from './handlerRegistry.js'; import type { Transport, TransportSendOptions } from './transport.js'; /** @@ -82,16 +75,6 @@ export type ProtocolOptions = { * e.g., `['notifications/tools/list_changed']` */ debouncedNotificationMethods?: string[]; - - /** - * Runtime configuration for task management. - * If provided, creates a TaskManager with the given options; otherwise a NullTaskManager is used. - * - * Capability assertions are wired automatically from the protocol's - * `assertTaskCapability()` and `assertTaskHandlerCapability()` methods, - * so they should NOT be included here. - */ - tasks?: TaskManagerOptions; }; /** @@ -105,8 +88,6 @@ export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; export type RequestOptions = { /** * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. - * - * For task-augmented requests: progress notifications continue after {@linkcode CreateTaskResult} is returned and stop automatically when the task reaches a terminal status. */ onprogress?: ProgressCallback; @@ -135,16 +116,6 @@ export type RequestOptions = { * If not specified, there is no maximum total timeout. */ maxTotalTimeout?: number; - - /** - * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. - */ - task?: TaskCreationParams; - - /** - * If provided, associates this request with a related task. - */ - relatedTask?: RelatedTaskMetadata; } & TransportSendOptions; /** @@ -155,11 +126,6 @@ export type NotificationOptions = { * May be used to indicate to the transport which incoming request to associate this outgoing notification with. */ relatedRequestId?: RequestId; - - /** - * If provided, associates this notification with a related task. - */ - relatedTask?: RelatedTaskMetadata; }; /** @@ -206,12 +172,12 @@ export type BaseContext = { send: { ( request: { method: M; params?: Record }, - options?: TaskRequestOptions + options?: RequestOptions ): Promise; ( request: Request, resultSchema: T, - options?: TaskRequestOptions + options?: RequestOptions ): Promise>; }; @@ -232,11 +198,6 @@ export type BaseContext = { */ authInfo?: AuthInfo; }; - - /** - * Task context, available when task storage is configured. - */ - task?: TaskContext; }; /** @@ -311,18 +272,18 @@ type TimeoutInfo = { export abstract class Protocol { private _transport?: Transport; private _requestMessageId = 0; - private _requestHandlers: Map Promise> = new Map(); private _requestHandlerAbortControllers: Map = new Map(); - private _notificationHandlers: Map Promise> = new Map(); private _responseHandlers: Map void> = new Map(); private _progressHandlers: Map = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); - private _taskManager: TaskManager; - protected _supportedProtocolVersions: string[]; + protected setTransport(transport: Transport | undefined): void { + this._transport = transport; + } + /** * Callback for when the connection is closed for any reason. * @@ -340,68 +301,45 @@ export abstract class Protocol { /** * A handler to invoke for any request types that do not have their own handler installed. */ - fallbackRequestHandler?: (request: JSONRPCRequest, ctx: ContextT) => Promise; + get fallbackRequestHandler() { + return this._registry.fallbackRequestHandler; + } + set fallbackRequestHandler(h) { + this._registry.fallbackRequestHandler = h; + } /** * A handler to invoke for any notification types that do not have their own handler installed. */ - fallbackNotificationHandler?: (notification: Notification) => Promise; + get fallbackNotificationHandler() { + return this._registry.fallbackNotificationHandler; + } + set fallbackNotificationHandler(h) { + this._registry.fallbackNotificationHandler = h; + } - constructor(private _options?: ProtocolOptions) { + constructor( + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Caps type varies by subclass + protected _registry: HandlerRegistry, + private _options?: ProtocolOptions + ) { this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; - // Create TaskManager from protocol options - this._taskManager = _options?.tasks ? new TaskManager(_options.tasks) : new NullTaskManager(); - this._bindTaskManager(); - - this.setNotificationHandler('notifications/cancelled', notification => { + this._registry.setNotificationHandler('notifications/cancelled', notification => { this._oncancel(notification); }); - this.setNotificationHandler('notifications/progress', notification => { + this._registry.setNotificationHandler('notifications/progress', notification => { this._onprogress(notification); }); - this.setRequestHandler( + this._registry.setRequestHandler( 'ping', // Automatic pong by default. _request => ({}) as Result ); } - /** - * Access the TaskManager for task orchestration. - * Always available; returns a NullTaskManager when no task store is configured. - */ - get taskManager(): TaskManager { - return this._taskManager; - } - - private _bindTaskManager(): void { - const taskManager = this._taskManager; - const host: TaskManagerHost = { - request: (request, resultSchema, options) => this._requestWithSchema(request, resultSchema, options), - notification: (notification, options) => this.notification(notification, options), - reportError: error => this._onerror(error), - removeProgressHandler: token => this._progressHandlers.delete(token), - registerHandler: (method, handler) => { - const schema = getRequestSchema(method as RequestMethod); - this._requestHandlers.set(method, (request, ctx) => { - // Validate request params via Zod (strips jsonrpc/id, so we pass original to handler) - schema.parse(request); - return handler(request, ctx); - }); - }, - sendOnResponseStream: async (message, relatedRequestId) => { - await this._transport?.send(message, { relatedRequestId }); - }, - enforceStrictCapabilities: this._options?.enforceStrictCapabilities === true, - assertTaskCapability: method => this.assertTaskCapability(method), - assertTaskHandlerCapability: method => this.assertTaskHandlerCapability(method) - }; - taskManager.bind(host); - } - /** * Builds the context object for request handlers. Subclasses must override * to return the appropriate context type (e.g., ServerContext adds HTTP request info). @@ -506,7 +444,6 @@ export abstract class Protocol { const responseHandlers = this._responseHandlers; this._responseHandlers = new Map(); this._progressHandlers.clear(); - this._taskManager.onClose(); this._pendingDebouncedNotifications.clear(); for (const info of this._timeoutInfo.values()) { @@ -539,7 +476,7 @@ export abstract class Protocol { } private _onnotification(notification: JSONRPCNotification): void { - const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; + const handler = this._registry.notificationHandlers.get(notification.method) ?? this._registry.fallbackNotificationHandler; // Ignore notifications not being subscribed to. if (handler === undefined) { @@ -553,29 +490,11 @@ export abstract class Protocol { } private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; + const handler = this._registry.requestHandlers.get(request.method) ?? this._registry.fallbackRequestHandler; // Capture the current transport at request time to ensure responses go to the correct client const capturedTransport = this._transport; - // Delegate context extraction to module (if registered) - const inboundCtx = { - sessionId: capturedTransport?.sessionId, - sendNotification: (notification: Notification, options?: NotificationOptions) => - this.notification(notification, { ...options, relatedRequestId: request.id }), - sendRequest: (r: Request, resultSchema: U, options?: RequestOptions) => - this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }) - }; - - // Delegate to TaskManager for task context, wrapped send/notify, and response routing - const taskResult = this._taskManager.processInboundRequest(request, inboundCtx); - const sendNotification = taskResult.sendNotification; - const sendRequest = taskResult.sendRequest; - const taskContext = taskResult.taskContext; - const routeResponse = taskResult.routeResponse; - const validators: Array<() => void> = []; - if (taskResult.validateInbound) validators.push(taskResult.validateInbound); - if (handler === undefined) { const errorResponse: JSONRPCErrorResponse = { jsonrpc: '2.0', @@ -586,22 +505,18 @@ export abstract class Protocol { } }; - // Queue or send the error response based on whether this is a task-related request - routeResponse(errorResponse) - .then(routed => { - if (!routed) { - capturedTransport - ?.send(errorResponse) - .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - } - }) - .catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); + capturedTransport?.send(errorResponse).catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); return; } const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); + const sendNotification = (notification: Notification, options?: NotificationOptions) => + this.notification(notification, { ...options, relatedRequestId: request.id }); + const sendRequest = (r: Request, resultSchema: U, options?: RequestOptions) => + this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }); + const baseCtx: BaseContext = { sessionId: capturedTransport?.sessionId, mcpReq: { @@ -609,11 +524,7 @@ export abstract class Protocol { method: request.method, _meta: request.params?._meta, signal: abortController.signal, - // BaseContext.mcpReq.send is declared with two overloads (spec-method-keyed and explicit-schema). Arrow - // literals can't carry overload signatures, so the inferred single-signature type isn't assignable to - // that overloaded property type. The cast is sound: this impl dispatches both overload paths via the - // isStandardSchema guard, and sendRequest validates the result against the resolved schema either way. - send: ((r: Request, schemaOrOptions?: StandardSchemaV1 | TaskRequestOptions, maybeOptions?: TaskRequestOptions) => { + send: ((r: Request, schemaOrOptions?: StandardSchemaV1 | RequestOptions, maybeOptions?: RequestOptions) => { if (isStandardSchema(schemaOrOptions)) { return sendRequest(r, schemaOrOptions, maybeOptions); } @@ -627,23 +538,16 @@ export abstract class Protocol { }) as BaseContext['mcpReq']['send'], notify: sendNotification }, - http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined, - task: taskContext + http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined }; const ctx = this.buildContext(baseCtx, extra); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() - .then(() => { - for (const validate of validators) { - validate(); - } - }) .then(() => handler(request, ctx)) .then( async result => { if (abortController.signal.aborted) { - // Request was cancelled return; } @@ -653,15 +557,10 @@ export abstract class Protocol { id: request.id }; - // Queue or send the response based on whether this is a task-related request - const routed = await routeResponse(response); - if (!routed) { - await capturedTransport?.send(response); - } + await capturedTransport?.send(response); }, async error => { if (abortController.signal.aborted) { - // Request was cancelled return; } @@ -675,11 +574,7 @@ export abstract class Protocol { } }; - // Queue or send the error response based on whether this is a task-related request - const routed = await routeResponse(errorResponse); - if (!routed) { - await capturedTransport?.send(errorResponse); - } + await capturedTransport?.send(errorResponse); } ) .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) @@ -722,11 +617,6 @@ export abstract class Protocol { private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { const messageId = Number(response.id); - // Delegate to TaskManager for task-related response handling - const taskResult = this._taskManager.processInboundResponse(response, messageId); - if (taskResult.consumed) return; - const preserveProgress = taskResult.preserveProgress; - const handler = this._responseHandlers.get(messageId); if (handler === undefined) { this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); @@ -735,11 +625,7 @@ export abstract class Protocol { this._responseHandlers.delete(messageId); this._cleanupTimeout(messageId); - - // Keep progress handler alive for CreateTaskResult responses - if (!preserveProgress) { - this._progressHandlers.delete(messageId); - } + this._progressHandlers.delete(messageId); if (isJSONRPCResultResponse(response)) { handler(response); @@ -774,29 +660,6 @@ export abstract class Protocol { */ protected abstract assertNotificationCapability(method: NotificationMethod | string): void; - /** - * A method to check if a request handler is supported by the local side, for the given method to be handled. - * - * This should be implemented by subclasses. - */ - protected abstract assertRequestHandlerCapability(method: string): void; - - /** - * A method to check if the remote side supports task creation for the given method. - * - * Called when sending a task-augmented outbound request (only when enforceStrictCapabilities is true). - * This should be implemented by subclasses. - */ - protected abstract assertTaskCapability(method: string): void; - - /** - * A method to check if this side supports handling task creation for the given method. - * - * Called when receiving a task-augmented inbound request. - * This should be implemented by subclasses. - */ - protected abstract assertTaskHandlerCapability(method: string): void; - /** * Sends a request and waits for a response. * @@ -938,44 +801,12 @@ export abstract class Protocol { this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - // Delegate task augmentation and routing to module (if registered) - const responseHandler = (response: JSONRPCResultResponse | Error) => { - const handler = this._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - - let outboundQueued = false; - try { - const taskResult = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => { - this._progressHandlers.delete(messageId); - reject(error); - }); - if (taskResult.queued) { - outboundQueued = true; - } - } catch (error) { + this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { this._progressHandlers.delete(messageId); reject(error); - return; - } - - if (!outboundQueued) { - // No related task or no module - send through transport normally - this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { - this._progressHandlers.delete(messageId); - reject(error); - }); - } + }); }).finally(() => { - // Per-request cleanup that must run on every exit path. Consolidated - // here so new exit paths added to the promise body can't forget it. - // _progressHandlers is NOT cleaned up here: _onresponse deletes it - // conditionally (preserveProgress for task flows), and error paths - // above delete it inline since no task exists in those cases. + // Per-request cleanup that must run on every exit path. if (onAbort) { options?.signal?.removeEventListener('abort', onAbort); } @@ -996,21 +827,13 @@ export abstract class Protocol { this.assertNotificationCapability(notification.method); - // Delegate task-related notification routing and JSONRPC building to TaskManager - const taskResult = await this._taskManager.processOutboundNotification(notification, options); - const queued = taskResult.queued; - const jsonrpcNotification = taskResult.queued ? undefined : taskResult.jsonrpcNotification; - - if (queued) { - // Don't send through transport - queued messages are delivered via tasks/result only - return; - } + const jsonrpcNotification: JSONRPCNotification = { + jsonrpc: '2.0', + ...notification + }; const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; - // A notification can only be debounced if it's in the list AND it's "simple" - // (i.e., has no parameters and no related request ID or related task that could be lost). - const canDebounce = - debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId && !options?.relatedTask; + const canDebounce = debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId; if (canDebounce) { // If a notification of this type is already scheduled, do nothing. @@ -1044,6 +867,10 @@ export abstract class Protocol { await this._transport.send(jsonrpcNotification!, options); } + // ----------------------------------------------------------------------- + // Handler registration — delegates to HandlerRegistry + // ----------------------------------------------------------------------- + /** * Registers a handler to invoke when this protocol object receives a request with the given method. * @@ -1079,64 +906,25 @@ export abstract class Protocol { schemasOrHandler: RequestHandlerSchemas | ((request: unknown, ctx: ContextT) => Result | Promise), maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise ): void { - this.assertRequestHandlerCapability(method); - - let stored: (request: JSONRPCRequest, ctx: ContextT) => Promise; - - if (typeof schemasOrHandler === 'function') { - const schema = getRequestSchema(method); - if (!schema) { - throw new TypeError( - `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` - ); - } - stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx)); - } else if (maybeHandler) { - stored = async (request, ctx) => { - const userParams = { ...request.params }; - delete userParams._meta; - const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); - if (!parsed.success) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); - } - return maybeHandler(parsed.data, ctx); - }; + if (maybeHandler) { + this._registry.setRequestHandler(method, schemasOrHandler as RequestHandlerSchemas, maybeHandler); } else { - throw new TypeError('setRequestHandler: handler is required'); + (this._registry.setRequestHandler as (...a: unknown[]) => void).call(this._registry, method, schemasOrHandler); } - - this._requestHandlers.set(method, this._wrapHandler(method, stored)); - } - - /** - * Hook for subclasses to wrap a registered request handler with role-specific - * validation or behavior (e.g. `Server` validates `tools/call` results, `Client` - * validates `elicitation/create` mode and result). Runs for both the 2-arg and - * 3-arg registration paths. The default implementation is identity. - * - * Subclasses overriding this hook avoid redeclaring `setRequestHandler`'s overload set. - */ - protected _wrapHandler( - _method: string, - handler: (request: JSONRPCRequest, ctx: ContextT) => Promise - ): (request: JSONRPCRequest, ctx: ContextT) => Promise { - return handler; } /** * Removes the request handler for the given method. */ removeRequestHandler(method: RequestMethod | string): void { - this._requestHandlers.delete(method); + this._registry.removeRequestHandler(method); } /** * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. */ assertCanSetRequestHandler(method: RequestMethod | string): void { - if (this._requestHandlers.has(method)) { - throw new Error(`A request handler for ${method} already exists, which would be overridden`); - } + this._registry.assertCanSetRequestHandler(method); } /** @@ -1164,73 +952,17 @@ export abstract class Protocol { schemasOrHandler: { params: StandardSchemaV1 } | ((notification: unknown) => void | Promise), maybeHandler?: (params: unknown, notification: Notification) => void | Promise ): void { - if (typeof schemasOrHandler === 'function') { - const schema = getNotificationSchema(method); - if (!schema) { - throw new TypeError( - `'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().` - ); - } - this._notificationHandlers.set(method, notification => Promise.resolve(schemasOrHandler(schema.parse(notification)))); - return; - } - - if (!maybeHandler) { - throw new TypeError('setNotificationHandler: handler is required'); + if (maybeHandler) { + this._registry.setNotificationHandler(method, schemasOrHandler as { params: StandardSchemaV1 }, maybeHandler); + } else { + (this._registry.setNotificationHandler as (...a: unknown[]) => void).call(this._registry, method, schemasOrHandler); } - this._notificationHandlers.set(method, async notification => { - const userParams = { ...notification.params }; - delete userParams._meta; - const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); - if (!parsed.success) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for notification ${method}: ${parsed.error}`); - } - await maybeHandler(parsed.data, notification); - }); } /** * Removes the notification handler for the given method. */ removeNotificationHandler(method: NotificationMethod | string): void { - this._notificationHandlers.delete(method); - } -} - -/** - * Schema bundle accepted by {@linkcode Protocol.setRequestHandler | setRequestHandler}'s 3-arg form. - * - * `params` is required and validates the inbound `request.params`. `result` is optional; - * when supplied it types the handler's return value (no runtime validation is performed - * on the result). - */ -export interface RequestHandlerSchemas< - P extends StandardSchemaV1 = StandardSchemaV1, - R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined -> { - params: P; - result?: R; -} - -type InferHandlerResult = R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result; - -function isPlainObject(value: unknown): value is Record { - return value !== null && typeof value === 'object' && !Array.isArray(value); -} - -export function mergeCapabilities(base: ServerCapabilities, additional: Partial): ServerCapabilities; -export function mergeCapabilities(base: ClientCapabilities, additional: Partial): ClientCapabilities; -export function mergeCapabilities(base: T, additional: Partial): T { - const result: T = { ...base }; - for (const key in additional) { - const k = key as keyof T; - const addValue = additional[k]; - if (addValue === undefined) continue; - const baseValue = result[k]; - result[k] = - isPlainObject(baseValue) && isPlainObject(addValue) - ? ({ ...(baseValue as Record), ...(addValue as Record) } as T[typeof k]) - : (addValue as T[typeof k]); + this._registry.removeNotificationHandler(method); } - return result; } diff --git a/packages/core/src/shared/responseMessage.ts b/packages/core/src/shared/responseMessage.ts index 25922a355f..6a7269282a 100644 --- a/packages/core/src/shared/responseMessage.ts +++ b/packages/core/src/shared/responseMessage.ts @@ -1,4 +1,4 @@ -import type { Result, Task } from '../types/index.js'; +import type { Result } from '../types/index.js'; /** * Base message type for the response stream. @@ -7,28 +7,6 @@ export interface BaseResponseMessage { type: string; } -/** - * Task status update message. - * - * Yielded on each poll iteration while the task is active (e.g. while - * `working`). May be emitted multiple times with the same status. - */ -export interface TaskStatusMessage extends BaseResponseMessage { - type: 'taskStatus'; - task: Task; -} - -/** - * Task created message. - * - * Yielded once when the server creates a new task for a long-running operation. - * This is always the first message for task-augmented requests. - */ -export interface TaskCreatedMessage extends BaseResponseMessage { - type: 'taskCreated'; - task: Task; -} - /** * Final result message. * @@ -51,20 +29,15 @@ export interface ErrorMessage extends BaseResponseMessage { } /** - * Union of all message types yielded by task-aware streaming APIs such as - * {@linkcode @modelcontextprotocol/client!experimental/tasks/client.ExperimentalClientTasks#callToolStream | callToolStream()}, - * {@linkcode @modelcontextprotocol/client!experimental/tasks/client.ExperimentalClientTasks#requestStream | ExperimentalClientTasks.requestStream()}, and - * {@linkcode @modelcontextprotocol/server!experimental/tasks/server.ExperimentalServerTasks#requestStream | ExperimentalServerTasks.requestStream()}. + * Union of all message types yielded by streaming response APIs. * * A typical sequence is: - * 1. `taskCreated` — task is registered (once) - * 2. `taskStatus` — zero or more progress updates - * 3. `result` **or** `error` — terminal message (once) + * 1. `result` **or** `error` — terminal message (once) * * Progress notifications are handled through the existing {@linkcode index.RequestOptions | onprogress} callback. * Side-channeled messages (server requests/notifications) are handled through registered handlers. */ -export type ResponseMessage = TaskStatusMessage | TaskCreatedMessage | ResultMessage | ErrorMessage; +export type ResponseMessage = ResultMessage | ErrorMessage; export type AsyncGeneratorValue = T extends AsyncGenerator ? U : never; @@ -81,9 +54,8 @@ export async function toArrayAsync>(it: T): Pr } /** - * Consumes a {@linkcode ResponseMessage} stream and returns the final result, - * discarding intermediate `taskCreated` and `taskStatus` messages. Throws - * if an `error` message is received or the stream ends without a result. + * Consumes a {@linkcode ResponseMessage} stream and returns the final result. + * Throws if an `error` message is received or the stream ends without a result. */ export async function takeResult>>(it: U): Promise { for await (const o of it) { diff --git a/packages/core/src/shared/taskManager.ts b/packages/core/src/shared/taskManager.ts deleted file mode 100644 index 257dbec827..0000000000 --- a/packages/core/src/shared/taskManager.ts +++ /dev/null @@ -1,915 +0,0 @@ -import type { CreateTaskOptions, QueuedMessage, TaskMessageQueue, TaskStore } from '../experimental/tasks/interfaces.js'; -import { isTerminal } from '../experimental/tasks/interfaces.js'; -import type { - GetTaskPayloadRequest, - GetTaskRequest, - GetTaskResult, - JSONRPCErrorResponse, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - JSONRPCResultResponse, - Notification, - Request, - RequestId, - Result, - Task, - TaskCreationParams, - TaskStatusNotification -} from '../types/index.js'; -import { - CancelTaskResultSchema, - CreateTaskResultSchema, - GetTaskResultSchema, - isJSONRPCErrorResponse, - isJSONRPCRequest, - isJSONRPCResultResponse, - isTaskAugmentedRequestParams, - ListTasksResultSchema, - ProtocolError, - ProtocolErrorCode, - RELATED_TASK_META_KEY, - TaskStatusNotificationSchema -} from '../types/index.js'; -import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; -import type { StandardSchemaV1 } from '../util/standardSchema.js'; -import type { BaseContext, NotificationOptions, RequestOptions } from './protocol.js'; -import type { ResponseMessage } from './responseMessage.js'; - -/** - * Host interface for TaskManager to call back into Protocol. @internal - */ -export interface TaskManagerHost { - request( - request: Request, - resultSchema: T, - options?: RequestOptions - ): Promise>; - notification(notification: Notification, options?: NotificationOptions): Promise; - reportError(error: Error): void; - removeProgressHandler(token: number): void; - registerHandler(method: string, handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise): void; - sendOnResponseStream(message: JSONRPCNotification | JSONRPCRequest, relatedRequestId: RequestId): Promise; - enforceStrictCapabilities: boolean; - assertTaskCapability(method: string): void; - assertTaskHandlerCapability(method: string): void; -} - -/** - * Context provided to TaskManager when processing an inbound request. - * @internal - */ -export interface InboundContext { - sessionId?: string; - sendNotification: (notification: Notification, options?: NotificationOptions) => Promise; - sendRequest: ( - request: Request, - resultSchema: U, - options?: RequestOptions - ) => Promise>; -} - -/** - * Result returned by TaskManager after processing an inbound request. - * @internal - */ -export interface InboundResult { - taskContext?: BaseContext['task']; - sendNotification: (notification: Notification) => Promise; - sendRequest: ( - request: Request, - resultSchema: U, - options?: Omit - ) => Promise>; - routeResponse: (message: JSONRPCResponse | JSONRPCErrorResponse) => Promise; - hasTaskCreationParams: boolean; - /** - * Optional validation to run inside the async handler chain (before the request handler). - * Throwing here produces a proper JSON-RPC error response, matching the behavior of - * capability checks on main. - */ - validateInbound?: () => void; -} - -/** - * Options that can be given per request. - */ -// relatedTask is excluded as the SDK controls if this is sent according to if the source is a task. -export type TaskRequestOptions = Omit; - -/** - * Request-scoped TaskStore interface. - */ -export interface RequestTaskStore { - /** - * Creates a new task with the given creation parameters. - * The implementation generates a unique taskId and createdAt timestamp. - * - * @param taskParams - The task creation parameters from the request - * @returns The created task object - */ - createTask(taskParams: CreateTaskOptions): Promise; - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @returns The task object - * @throws If the task does not exist - */ - getTask(taskId: string): Promise; - - /** - * Stores the result of a task and sets its final status. - * - * @param taskId - The task identifier - * @param status - The final status: 'completed' for success, 'failed' for errors - * @param result - The result to store - */ - storeTaskResult(taskId: string, status: 'completed' | 'failed', result: Result): Promise; - - /** - * Retrieves the stored result of a task. - * - * @param taskId - The task identifier - * @returns The stored result - */ - getTaskResult(taskId: string): Promise; - - /** - * Updates a task's status (e.g., to 'cancelled', 'failed', 'completed'). - * - * @param taskId - The task identifier - * @param status - The new status - * @param statusMessage - Optional diagnostic message for failed tasks or other status information - */ - updateTaskStatus(taskId: string, status: Task['status'], statusMessage?: string): Promise; - - /** - * Lists tasks, optionally starting from a pagination cursor. - * - * @param cursor - Optional cursor for pagination - * @returns An object containing the tasks array and an optional nextCursor - */ - listTasks(cursor?: string): Promise<{ tasks: Task[]; nextCursor?: string }>; -} - -/** - * Task context provided to request handlers when task storage is configured. - */ -export type TaskContext = { - id?: string; - store: RequestTaskStore; - requestedTtl?: number; -}; - -export type TaskManagerOptions = { - /** - * Task storage implementation. Required for handling incoming task requests (server-side). - * Not required for sending task requests (client-side outbound API). - */ - taskStore?: TaskStore; - /** - * Optional task message queue implementation for managing server-initiated messages - * that will be delivered through the tasks/result response stream. - */ - taskMessageQueue?: TaskMessageQueue; - /** - * Default polling interval (in milliseconds) for task status checks when no pollInterval - * is provided by the server. Defaults to 1000ms if not specified. - */ - defaultTaskPollInterval?: number; - /** - * Maximum number of messages that can be queued per task for side-channel delivery. - * If undefined, the queue size is unbounded. - */ - maxTaskQueueSize?: number; -}; - -/** - * Extracts {@linkcode TaskManagerOptions} from a capability object that mixes in runtime fields. - * Returns `undefined` when no task capability is configured. - */ -export function extractTaskManagerOptions(tasksCapability: TaskManagerOptions | undefined): TaskManagerOptions | undefined { - if (!tasksCapability) return undefined; - const { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize } = tasksCapability; - return { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize }; -} - -/** - * Manages task orchestration: state, message queuing, and polling. - * Capability checking is delegated to the Protocol host. - * @internal - */ -export class TaskManager { - private _taskStore?: TaskStore; - private _taskMessageQueue?: TaskMessageQueue; - private _taskProgressTokens: Map = new Map(); - private _requestResolvers: Map void> = new Map(); - private _options: TaskManagerOptions; - private _host?: TaskManagerHost; - - constructor(options: TaskManagerOptions) { - this._options = options; - this._taskStore = options.taskStore; - this._taskMessageQueue = options.taskMessageQueue; - } - - bind(host: TaskManagerHost): void { - this._host = host; - - if (this._taskStore) { - host.registerHandler('tasks/get', async (request, ctx) => { - const params = request.params as { taskId: string }; - const task = await this.handleGetTask(params.taskId, ctx.sessionId); - // Per spec: tasks/get responses SHALL NOT include related-task metadata - // as the taskId parameter is the source of truth - return { - ...task - } as Result; - }); - - host.registerHandler('tasks/result', async (request, ctx) => { - const params = request.params as { taskId: string }; - return await this.handleGetTaskPayload(params.taskId, ctx.sessionId, ctx.mcpReq.signal, async message => { - // Send the message on the response stream by passing the relatedRequestId - // This tells the transport to write the message to the tasks/result response stream - await host.sendOnResponseStream(message, ctx.mcpReq.id); - }); - }); - - host.registerHandler('tasks/list', async (request, ctx) => { - const params = request.params as { cursor?: string } | undefined; - return (await this.handleListTasks(params?.cursor, ctx.sessionId)) as Result; - }); - - host.registerHandler('tasks/cancel', async (request, ctx) => { - const params = request.params as { taskId: string }; - return await this.handleCancelTask(params.taskId, ctx.sessionId); - }); - } - } - - protected get _requireHost(): TaskManagerHost { - if (!this._host) { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskManager is not bound to a Protocol host — call bind() first'); - } - return this._host; - } - - get taskStore(): TaskStore | undefined { - return this._taskStore; - } - - private get _requireTaskStore(): TaskStore { - if (!this._taskStore) { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskStore is not configured'); - } - return this._taskStore; - } - - get taskMessageQueue(): TaskMessageQueue | undefined { - return this._taskMessageQueue; - } - - // -- Public API (client-facing) -- - async *requestStream( - request: Request, - resultSchema: T, - options?: RequestOptions - ): AsyncGenerator>, void, void> { - const host = this._requireHost; - const { task } = options ?? {}; - - if (!task) { - try { - // TODO: SchemaOutput (Zod) and StandardSchemaV1.InferOutput (host.request's return) - // resolve to the same type for Zod schemas, but TS can't unify them generically. - // Removing this cast requires aligning ResponseMessage with StandardSchema. - const result = (await host.request(request, resultSchema, options)) as SchemaOutput; - yield { type: 'result', result }; - } catch (error) { - yield { - type: 'error', - error: error instanceof Error ? error : new Error(String(error)) - }; - } - return; - } - - let taskId: string | undefined; - try { - const createResult = await host.request(request, CreateTaskResultSchema, options); - - if (createResult.task) { - taskId = createResult.task.taskId; - yield { type: 'taskCreated', task: createResult.task }; - } else { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'Task creation did not return a task'); - } - - while (true) { - const task = await this.getTask({ taskId }, options); - yield { type: 'taskStatus', task }; - - if (isTerminal(task.status)) { - switch (task.status) { - case 'completed': - case 'failed': { - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - break; - } - case 'cancelled': { - yield { - type: 'error', - error: new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} was cancelled`) - }; - break; - } - } - return; - } - - if (task.status === 'input_required') { - const result = await this.getTaskResult({ taskId }, resultSchema, options); - yield { type: 'result', result }; - return; - } - - const pollInterval = task.pollInterval ?? this._options.defaultTaskPollInterval ?? 1000; - await new Promise(resolve => setTimeout(resolve, pollInterval)); - options?.signal?.throwIfAborted(); - } - } catch (error) { - yield { - type: 'error', - error: error instanceof Error ? error : new Error(String(error)) - }; - } - } - - async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { - return this._requireHost.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); - } - - async getTaskResult( - params: GetTaskPayloadRequest['params'], - resultSchema: T, - options?: RequestOptions - ): Promise> { - // TODO: same SchemaOutput vs StandardSchemaV1.InferOutput mismatch as requestStream above. - return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options) as Promise>; - } - - async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { - return this._requireHost.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); - } - - async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { - return this._requireHost.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); - } - - // -- Handler bodies (delegated from Protocol's registered handlers) -- - - private async handleGetTask(taskId: string, sessionId?: string): Promise { - const task = await this._requireTaskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - } - return task; - } - - private async handleGetTaskPayload( - taskId: string, - sessionId: string | undefined, - signal: AbortSignal, - sendOnResponseStream: (message: JSONRPCNotification | JSONRPCRequest) => Promise - ): Promise { - const handleTaskResult = async (): Promise => { - if (this._taskMessageQueue) { - let queuedMessage: QueuedMessage | undefined; - while ((queuedMessage = await this._taskMessageQueue.dequeue(taskId, sessionId))) { - if (queuedMessage.type === 'response' || queuedMessage.type === 'error') { - const message = queuedMessage.message; - const requestId = message.id; - const resolver = this._requestResolvers.get(requestId as RequestId); - - if (resolver) { - this._requestResolvers.delete(requestId as RequestId); - if (queuedMessage.type === 'response') { - resolver(message as JSONRPCResultResponse); - } else { - const errorMessage = message as JSONRPCErrorResponse; - resolver(new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data)); - } - } else { - const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; - this._host?.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); - } - continue; - } - - await sendOnResponseStream(queuedMessage.message as JSONRPCNotification | JSONRPCRequest); - } - } - - const task = await this._requireTaskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); - } - - if (!isTerminal(task.status)) { - await this._waitForTaskUpdate(task.pollInterval, signal); - return await handleTaskResult(); - } - - const result = await this._requireTaskStore.getTaskResult(taskId, sessionId); - await this._clearTaskQueue(taskId); - - return { - ...result, - _meta: { - ...result._meta, - [RELATED_TASK_META_KEY]: { taskId } - } - }; - }; - - return await handleTaskResult(); - } - - private async handleListTasks( - cursor: string | undefined, - sessionId?: string - ): Promise<{ tasks: Task[]; nextCursor?: string; _meta: Record }> { - try { - const { tasks, nextCursor } = await this._requireTaskStore.listTasks(cursor, sessionId); - return { tasks, nextCursor, _meta: {} }; - } catch (error) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Failed to list tasks: ${error instanceof Error ? error.message : String(error)}` - ); - } - } - - private async handleCancelTask(taskId: string, sessionId?: string): Promise { - try { - const task = await this._requireTaskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found: ${taskId}`); - } - - if (isTerminal(task.status)) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Cannot cancel task in terminal status: ${task.status}`); - } - - await this._requireTaskStore.updateTaskStatus(taskId, 'cancelled', 'Client cancelled task execution.', sessionId); - await this._clearTaskQueue(taskId); - - const cancelledTask = await this._requireTaskStore.getTask(taskId, sessionId); - if (!cancelledTask) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task not found after cancellation: ${taskId}`); - } - - return { _meta: {}, ...cancelledTask }; - } catch (error) { - if (error instanceof ProtocolError) throw error; - throw new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Failed to cancel task: ${error instanceof Error ? error.message : String(error)}` - ); - } - } - - // -- Internal delegation methods -- - - private prepareOutboundRequest( - jsonrpcRequest: JSONRPCRequest, - options: RequestOptions | undefined, - messageId: number, - responseHandler: (response: JSONRPCResultResponse | Error) => void, - onError: (error: unknown) => void - ): boolean { - const { task, relatedTask } = options ?? {}; - - if (task) { - jsonrpcRequest.params = { - ...jsonrpcRequest.params, - task: task - }; - } - - if (relatedTask) { - jsonrpcRequest.params = { - ...jsonrpcRequest.params, - _meta: { - ...jsonrpcRequest.params?._meta, - [RELATED_TASK_META_KEY]: relatedTask - } - }; - } - - const relatedTaskId = relatedTask?.taskId; - if (relatedTaskId) { - this._requestResolvers.set(messageId, responseHandler); - - this._enqueueTaskMessage(relatedTaskId, { - type: 'request', - message: jsonrpcRequest, - timestamp: Date.now() - }).catch(error => { - onError(error); - }); - - return true; - } - - return false; - } - - private extractInboundTaskContext( - request: JSONRPCRequest, - sessionId?: string - ): { - relatedTaskId?: string; - taskCreationParams?: TaskCreationParams; - taskContext?: TaskContext; - } { - const relatedTaskId = (request.params?._meta as Record | undefined)?.[RELATED_TASK_META_KEY]?.taskId; - const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined; - - // Provide task context whenever a task store is configured, - // not just for task-related requests — tools need ctx.task.store - let taskContext: TaskContext | undefined; - if (this._taskStore) { - const store = this.createRequestTaskStore(request, sessionId); - taskContext = { - id: relatedTaskId, - store, - requestedTtl: taskCreationParams?.ttl - }; - } - - if (!relatedTaskId && !taskCreationParams && !taskContext) { - return {}; - } - - return { - relatedTaskId, - taskCreationParams, - taskContext - }; - } - - private wrapSendNotification( - relatedTaskId: string, - originalSendNotification: (notification: Notification, options?: NotificationOptions) => Promise - ): (notification: Notification) => Promise { - return async (notification: Notification) => { - const notificationOptions: NotificationOptions = { relatedTask: { taskId: relatedTaskId } }; - await originalSendNotification(notification, notificationOptions); - }; - } - - private wrapSendRequest( - relatedTaskId: string, - taskStore: RequestTaskStore | undefined, - originalSendRequest: ( - request: Request, - resultSchema: V, - options?: RequestOptions - ) => Promise> - ): ( - request: Request, - resultSchema: V, - options?: TaskRequestOptions - ) => Promise> { - return async (request: Request, resultSchema: V, options?: TaskRequestOptions) => { - const requestOptions: RequestOptions = { ...options }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - - return await originalSendRequest(request, resultSchema, requestOptions); - }; - } - - private handleResponse(response: JSONRPCResponse | JSONRPCErrorResponse): boolean { - const messageId = Number(response.id); - const resolver = this._requestResolvers.get(messageId); - if (resolver) { - this._requestResolvers.delete(messageId); - if (isJSONRPCResultResponse(response)) { - resolver(response); - } else { - resolver(new ProtocolError(response.error.code, response.error.message, response.error.data)); - } - return true; - } - return false; - } - - private shouldPreserveProgressHandler(response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number): boolean { - if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') { - const result = response.result as Record; - if (result.task && typeof result.task === 'object') { - const task = result.task as Record; - if (typeof task.taskId === 'string') { - this._taskProgressTokens.set(task.taskId, messageId); - return true; - } - } - } - return false; - } - - private async routeNotification(notification: Notification, options?: NotificationOptions): Promise { - const relatedTaskId = options?.relatedTask?.taskId; - if (!relatedTaskId) return false; - - const jsonrpcNotification: JSONRPCNotification = { - ...notification, - jsonrpc: '2.0', - params: { - ...notification.params, - _meta: { - ...notification.params?._meta, - [RELATED_TASK_META_KEY]: options!.relatedTask - } - } - }; - - await this._enqueueTaskMessage(relatedTaskId, { - type: 'notification', - message: jsonrpcNotification, - timestamp: Date.now() - }); - - return true; - } - - private async routeResponse( - relatedTaskId: string | undefined, - message: JSONRPCResponse | JSONRPCErrorResponse, - sessionId?: string - ): Promise { - if (!relatedTaskId || !this._taskMessageQueue) return false; - - await (isJSONRPCErrorResponse(message) - ? this._enqueueTaskMessage(relatedTaskId, { type: 'error', message, timestamp: Date.now() }, sessionId) - : this._enqueueTaskMessage( - relatedTaskId, - { type: 'response', message: message as JSONRPCResultResponse, timestamp: Date.now() }, - sessionId - )); - return true; - } - - private createRequestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { - const taskStore = this._requireTaskStore; - const host = this._host; - - return { - createTask: async taskParams => { - if (!request) throw new Error('No request provided'); - return await taskStore.createTask(taskParams, request.id, { method: request.method, params: request.params }, sessionId); - }, - getTask: async taskId => { - const task = await taskStore.getTask(taskId, sessionId); - if (!task) throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Failed to retrieve task: Task not found'); - return task; - }, - storeTaskResult: async (taskId, status, result) => { - await taskStore.storeTaskResult(taskId, status, result, sessionId); - const task = await taskStore.getTask(taskId, sessionId); - if (task) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: task - }); - await host?.notification(notification as Notification); - if (isTerminal(task.status)) { - this._cleanupTaskProgressHandler(taskId); - } - } - }, - getTaskResult: taskId => taskStore.getTaskResult(taskId, sessionId), - updateTaskStatus: async (taskId, status, statusMessage) => { - const task = await taskStore.getTask(taskId, sessionId); - if (!task) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Task "${taskId}" not found - it may have been cleaned up`); - } - if (isTerminal(task.status)) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Cannot update task "${taskId}" from terminal status "${task.status}" to "${status}". Terminal states (completed, failed, cancelled) cannot transition to other states.` - ); - } - await taskStore.updateTaskStatus(taskId, status, statusMessage, sessionId); - const updatedTask = await taskStore.getTask(taskId, sessionId); - if (updatedTask) { - const notification: TaskStatusNotification = TaskStatusNotificationSchema.parse({ - method: 'notifications/tasks/status', - params: updatedTask - }); - await host?.notification(notification as Notification); - if (isTerminal(updatedTask.status)) { - this._cleanupTaskProgressHandler(taskId); - } - } - }, - listTasks: cursor => taskStore.listTasks(cursor, sessionId) - }; - } - - // -- Lifecycle methods (called by Protocol directly) -- - - processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult { - const taskInfo = this.extractInboundTaskContext(request, ctx.sessionId); - const relatedTaskId = taskInfo?.relatedTaskId; - - const sendNotification = relatedTaskId - ? this.wrapSendNotification(relatedTaskId, ctx.sendNotification) - : (notification: Notification) => ctx.sendNotification(notification); - - const sendRequest = relatedTaskId - ? this.wrapSendRequest(relatedTaskId, taskInfo?.taskContext?.store, ctx.sendRequest) - : taskInfo?.taskContext - ? this.wrapSendRequest('', taskInfo.taskContext.store, ctx.sendRequest) - : ctx.sendRequest; - - const hasTaskCreationParams = !!taskInfo?.taskCreationParams; - - return { - taskContext: taskInfo?.taskContext, - sendNotification, - sendRequest, - routeResponse: async (message: JSONRPCResponse | JSONRPCErrorResponse) => { - if (relatedTaskId) { - return this.routeResponse(relatedTaskId, message, ctx.sessionId); - } - return false; - }, - hasTaskCreationParams, - // Deferred validation: runs inside the async handler chain so errors - // produce proper JSON-RPC error responses (matching main's behavior). - validateInbound: hasTaskCreationParams ? () => this._requireHost.assertTaskHandlerCapability(request.method) : undefined - }; - } - - processOutboundRequest( - jsonrpcRequest: JSONRPCRequest, - options: RequestOptions | undefined, - messageId: number, - responseHandler: (response: JSONRPCResultResponse | Error) => void, - onError: (error: unknown) => void - ): { queued: boolean } { - // Check task capability when sending a task-augmented request (matches main's enforceStrictCapabilities gate) - if (this._requireHost.enforceStrictCapabilities && options?.task) { - this._requireHost.assertTaskCapability(jsonrpcRequest.method); - } - - const queued = this.prepareOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, onError); - return { queued }; - } - - processInboundResponse( - response: JSONRPCResponse | JSONRPCErrorResponse, - messageId: number - ): { consumed: boolean; preserveProgress: boolean } { - const consumed = this.handleResponse(response); - if (consumed) { - return { consumed: true, preserveProgress: false }; - } - const preserveProgress = this.shouldPreserveProgressHandler(response, messageId); - return { consumed: false, preserveProgress }; - } - - async processOutboundNotification( - notification: Notification, - options?: NotificationOptions - ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> { - // Try queuing first - const queued = await this.routeNotification(notification, options); - if (queued) return { queued: true }; - - // Build JSONRPC notification with optional relatedTask metadata - let jsonrpcNotification: JSONRPCNotification = { ...notification, jsonrpc: '2.0' }; - if (options?.relatedTask) { - jsonrpcNotification = { - ...jsonrpcNotification, - params: { - ...jsonrpcNotification.params, - _meta: { - ...jsonrpcNotification.params?._meta, - [RELATED_TASK_META_KEY]: options.relatedTask - } - } - }; - } - return { queued: false, jsonrpcNotification }; - } - - onClose(): void { - this._taskProgressTokens.clear(); - this._requestResolvers.clear(); - } - - // -- Private helpers -- - - private async _enqueueTaskMessage(taskId: string, message: QueuedMessage, sessionId?: string): Promise { - if (!this._taskStore || !this._taskMessageQueue) { - throw new Error('Cannot enqueue task message: taskStore and taskMessageQueue are not configured'); - } - await this._taskMessageQueue.enqueue(taskId, message, sessionId, this._options.maxTaskQueueSize); - } - - private async _clearTaskQueue(taskId: string, sessionId?: string): Promise { - if (this._taskMessageQueue) { - const messages = await this._taskMessageQueue.dequeueAll(taskId, sessionId); - for (const message of messages) { - if (message.type === 'request' && isJSONRPCRequest(message.message)) { - const requestId = message.message.id as RequestId; - const resolver = this._requestResolvers.get(requestId); - if (resolver) { - resolver(new ProtocolError(ProtocolErrorCode.InternalError, 'Task cancelled or completed')); - this._requestResolvers.delete(requestId); - } else { - this._host?.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); - } - } - } - } - } - - private async _waitForTaskUpdate(pollInterval: number | undefined, signal: AbortSignal): Promise { - const interval = pollInterval ?? this._options.defaultTaskPollInterval ?? 1000; - - return new Promise((resolve, reject) => { - if (signal.aborted) { - reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); - return; - } - const timeoutId = setTimeout(resolve, interval); - signal.addEventListener( - 'abort', - () => { - clearTimeout(timeoutId); - reject(new ProtocolError(ProtocolErrorCode.InvalidRequest, 'Request cancelled')); - }, - { once: true } - ); - }); - } - - private _cleanupTaskProgressHandler(taskId: string): void { - const progressToken = this._taskProgressTokens.get(taskId); - if (progressToken !== undefined) { - this._host?.removeProgressHandler(progressToken); - this._taskProgressTokens.delete(taskId); - } - } -} - -/** - * No-op TaskManager used when tasks capability is not configured. - * Provides passthrough implementations for the hot paths, avoiding - * unnecessary task extraction logic on every request. - */ -export class NullTaskManager extends TaskManager { - constructor() { - super({}); - } - - override processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult { - const hasTaskCreationParams = isTaskAugmentedRequestParams(request.params) && !!request.params.task; - return { - taskContext: undefined, - sendNotification: (notification: Notification) => ctx.sendNotification(notification), - sendRequest: ctx.sendRequest, - routeResponse: async () => false, - hasTaskCreationParams, - validateInbound: hasTaskCreationParams ? () => this._requireHost.assertTaskHandlerCapability(request.method) : undefined - }; - } - - // processOutboundRequest is inherited - it handles task/relatedTask augmentation - // and only queues if relatedTask is set (which won't happen without a task store) - - // processInboundResponse is inherited - it checks _requestResolvers (empty for NullTaskManager) - // and _taskProgressTokens (empty for NullTaskManager) - - override async processOutboundNotification( - notification: Notification, - _options?: NotificationOptions - ): Promise<{ queued: boolean; jsonrpcNotification?: JSONRPCNotification }> { - return { queued: false, jsonrpcNotification: { ...notification, jsonrpc: '2.0' } }; - } -} diff --git a/packages/core/src/shared/transport.ts b/packages/core/src/shared/transport.ts index c606e2e3b5..2e24ce2c12 100644 --- a/packages/core/src/shared/transport.ts +++ b/packages/core/src/shared/transport.ts @@ -1,4 +1,12 @@ -import type { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/index.js'; +import type { + Implementation, + JSONRPCMessage, + JSONRPCRequest, + MessageExtraInfo, + RequestId, + Result, + ServerCapabilities +} from '../types/index.js'; export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; @@ -68,6 +76,23 @@ export type TransportSendOptions = { */ onresumptiontoken?: ((token: string) => void) | undefined; }; +/** + * Configuration passed from Protocol to routing transports during connect(). + * Provides access to the handler registry and server metadata so that + * routing transports can dispatch requests without going through Protocol's message loop. + * + * The `requestHandlers` map is a live reference — handlers registered after connect() + * (e.g., via McpServer.registerTool) are visible immediately. + */ +export interface ProtocolConfig { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + requestHandlers: ReadonlyMap Promise>; + serverInfo?: Implementation; + capabilities?: ServerCapabilities; + instructions?: string; + createServer?: () => unknown; +} + /** * Describes the minimal contract for an MCP transport that a client or server can communicate over. */ @@ -131,4 +156,12 @@ export interface Transport { * This allows the server to pass its supported versions to the transport. */ setSupportedProtocolVersions?: ((versions: string[]) => void) | undefined; + + /** + * Configures a routing transport with protocol-level metadata and handler registry. + * + * When present, `Server.connect()` treats the transport as a routing transport + * and calls this method instead of creating a `LegacyServer` internally. + */ + setProtocolConfig?: ((config: ProtocolConfig) => void) | undefined; } diff --git a/packages/core/src/types/constants.ts b/packages/core/src/types/constants.ts index 878d5111cf..1766f0c8e5 100644 --- a/packages/core/src/types/constants.ts +++ b/packages/core/src/types/constants.ts @@ -2,8 +2,6 @@ export const LATEST_PROTOCOL_VERSION = '2025-11-25'; export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = '2025-03-26'; export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, '2025-06-18', '2025-03-26', '2024-11-05', '2024-10-07']; -export const RELATED_TASK_META_KEY = 'io.modelcontextprotocol/related-task'; - /* JSON-RPC types */ export const JSONRPC_VERSION = '2.0'; diff --git a/packages/core/src/types/guards.ts b/packages/core/src/types/guards.ts index f385b91b42..c8185320a9 100644 --- a/packages/core/src/types/guards.ts +++ b/packages/core/src/types/guards.ts @@ -7,8 +7,7 @@ import { JSONRPCNotificationSchema, JSONRPCRequestSchema, JSONRPCResponseSchema, - JSONRPCResultResponseSchema, - TaskAugmentedRequestParamsSchema + JSONRPCResultResponseSchema } from './schemas.js'; import type { CallToolResult, @@ -22,8 +21,7 @@ import type { JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - JSONRPCResultResponse, - TaskAugmentedRequestParams + JSONRPCResultResponse } from './types.js'; /** @@ -81,15 +79,6 @@ export const isCallToolResult = (value: unknown): value is CallToolResult => { return CallToolResultSchema.safeParse(value).success; }; -/** - * Checks if a value is a valid {@linkcode TaskAugmentedRequestParams}. - * @param value - The value to check. - * - * @returns True if the value is a valid {@linkcode TaskAugmentedRequestParams}, false otherwise. - */ -export const isTaskAugmentedRequestParams = (value: unknown): value is TaskAugmentedRequestParams => - TaskAugmentedRequestParamsSchema.safeParse(value).success; - export const isInitializeRequest = (value: unknown): value is InitializeRequest => InitializeRequestSchema.safeParse(value).success; export const isInitializedNotification = (value: unknown): value is InitializedNotification => diff --git a/packages/core/src/types/schemas.ts b/packages/core/src/types/schemas.ts index a243c1b829..19bf4f81e7 100644 --- a/packages/core/src/types/schemas.ts +++ b/packages/core/src/types/schemas.ts @@ -1,6 +1,6 @@ import * as z from 'zod/v4'; -import { JSONRPC_VERSION, RELATED_TASK_META_KEY } from './constants.js'; +import { JSONRPC_VERSION } from './constants.js'; import type { JSONArray, JSONObject, @@ -27,42 +27,11 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); -/** - * Task creation parameters, used to ask that the server create a task to represent a request. - */ -export const TaskCreationParamsSchema = z.looseObject({ - /** - * Requested duration in milliseconds to retain task from creation. - */ - ttl: z.number().optional(), - - /** - * Time in milliseconds to wait between task status requests. - */ - pollInterval: z.number().optional() -}); - -export const TaskMetadataSchema = z.object({ - ttl: z.number().optional() -}); - -/** - * Metadata for associating messages with a task. - * Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. - */ -export const RelatedTaskMetadataSchema = z.object({ - taskId: z.string() -}); - export const RequestMetaSchema = z.looseObject({ /** * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. */ - progressToken: ProgressTokenSchema.optional(), - /** - * If specified, this request is related to the provided task. - */ - [RELATED_TASK_META_KEY]: RelatedTaskMetadataSchema.optional() + progressToken: ProgressTokenSchema.optional() }); /** @@ -75,21 +44,6 @@ export const BaseRequestParamsSchema = z.object({ _meta: RequestMetaSchema.optional() }); -/** - * Common params for any task-augmented request. - */ -export const TaskAugmentedRequestParamsSchema = BaseRequestParamsSchema.extend({ - /** - * If specified, the caller is requesting task-augmented execution for this request. - * The request will return a `CreateTaskResult` immediately, and the actual result can be - * retrieved later via `tasks/result`. - * - * Task augmentation is subject to capability negotiation - receivers MUST declare support - * for task augmentation of specific request types in their capabilities. - */ - task: TaskMetadataSchema.optional() -}); - export const RequestSchema = z.object({ method: z.string(), params: BaseRequestParamsSchema.loose().optional() @@ -331,72 +285,6 @@ const ElicitationCapabilitySchema = z.preprocess( ) ); -/** - * Task capabilities for clients, indicating which request types support task creation. - */ -export const ClientTasksCapabilitySchema = z.looseObject({ - /** - * Present if the client supports listing tasks. - */ - list: JSONObjectSchema.optional(), - /** - * Present if the client supports cancelling tasks. - */ - cancel: JSONObjectSchema.optional(), - /** - * Capabilities for task creation on specific request types. - */ - requests: z - .looseObject({ - /** - * Task support for sampling requests. - */ - sampling: z - .looseObject({ - createMessage: JSONObjectSchema.optional() - }) - .optional(), - /** - * Task support for elicitation requests. - */ - elicitation: z - .looseObject({ - create: JSONObjectSchema.optional() - }) - .optional() - }) - .optional() -}); - -/** - * Task capabilities for servers, indicating which request types support task creation. - */ -export const ServerTasksCapabilitySchema = z.looseObject({ - /** - * Present if the server supports listing tasks. - */ - list: JSONObjectSchema.optional(), - /** - * Present if the server supports cancelling tasks. - */ - cancel: JSONObjectSchema.optional(), - /** - * Capabilities for task creation on specific request types. - */ - requests: z - .looseObject({ - /** - * Task support for tool requests. - */ - tools: z - .looseObject({ - call: JSONObjectSchema.optional() - }) - .optional() - }) - .optional() -}); - /** * Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities. */ @@ -436,10 +324,6 @@ export const ClientCapabilitiesSchema = z.object({ listChanged: z.boolean().optional() }) .optional(), - /** - * Present if the client supports task creation. - */ - tasks: ClientTasksCapabilitySchema.optional(), /** * Extensions that the client supports. Keys are extension identifiers (vendor-prefix/extension-name). */ @@ -516,10 +400,6 @@ export const ServerCapabilitiesSchema = z.object({ listChanged: z.boolean().optional() }) .optional(), - /** - * Present if the server supports task creation. - */ - tasks: ServerTasksCapabilitySchema.optional(), /** * Extensions that the server supports. Keys are extension identifiers (vendor-prefix/extension-name). */ @@ -616,120 +496,6 @@ export const PaginatedResultSchema = ResultSchema.extend({ nextCursor: CursorSchema.optional() }); -/** - * The status of a task. - * */ -export const TaskStatusSchema = z.enum(['working', 'input_required', 'completed', 'failed', 'cancelled']); - -/* Tasks */ -/** - * A pollable state object associated with a request. - */ -export const TaskSchema = z.object({ - taskId: z.string(), - status: TaskStatusSchema, - /** - * Time in milliseconds to keep task results available after completion. - * If `null`, the task has unlimited lifetime until manually cleaned up. - */ - ttl: z.union([z.number(), z.null()]), - /** - * ISO 8601 timestamp when the task was created. - */ - createdAt: z.string(), - /** - * ISO 8601 timestamp when the task was last updated. - */ - lastUpdatedAt: z.string(), - pollInterval: z.optional(z.number()), - /** - * Optional diagnostic message for failed tasks or other status information. - */ - statusMessage: z.optional(z.string()) -}); - -/** - * Result returned when a task is created, containing the task data wrapped in a `task` field. - */ -export const CreateTaskResultSchema = ResultSchema.extend({ - task: TaskSchema -}); - -/** - * Parameters for task status notification. - */ -export const TaskStatusNotificationParamsSchema = NotificationsParamsSchema.merge(TaskSchema); - -/** - * A notification sent when a task's status changes. - */ -export const TaskStatusNotificationSchema = NotificationSchema.extend({ - method: z.literal('notifications/tasks/status'), - params: TaskStatusNotificationParamsSchema -}); - -/** - * A request to get the state of a specific task. - */ -export const GetTaskRequestSchema = RequestSchema.extend({ - method: z.literal('tasks/get'), - params: BaseRequestParamsSchema.extend({ - taskId: z.string() - }) -}); - -/** - * The response to a {@linkcode GetTaskRequest | tasks/get} request. - */ -export const GetTaskResultSchema = ResultSchema.merge(TaskSchema); - -/** - * A request to get the result of a specific task. - */ -export const GetTaskPayloadRequestSchema = RequestSchema.extend({ - method: z.literal('tasks/result'), - params: BaseRequestParamsSchema.extend({ - taskId: z.string() - }) -}); - -/** - * The response to a `tasks/result` request. - * The structure matches the result type of the original request. - * For example, a {@linkcode CallToolRequest | tools/call} task would return the `CallToolResult` structure. - * - */ -export const GetTaskPayloadResultSchema = ResultSchema.loose(); - -/** - * A request to list tasks. - */ -export const ListTasksRequestSchema = PaginatedRequestSchema.extend({ - method: z.literal('tasks/list') -}); - -/** - * The response to a {@linkcode ListTasksRequest | tasks/list} request. - */ -export const ListTasksResultSchema = PaginatedResultSchema.extend({ - tasks: z.array(TaskSchema) -}); - -/** - * A request to cancel a specific task. - */ -export const CancelTaskRequestSchema = RequestSchema.extend({ - method: z.literal('tasks/cancel'), - params: BaseRequestParamsSchema.extend({ - taskId: z.string() - }) -}); - -/** - * The response to a {@linkcode CancelTaskRequest | tasks/cancel} request. - */ -export const CancelTaskResultSchema = ResultSchema.merge(TaskSchema); - /* Resources */ /** * The contents of a specific resource or sub-resource. @@ -1409,7 +1175,7 @@ export const CompatibilityCallToolResultSchema = CallToolResultSchema.or( /** * Parameters for a `tools/call` request. */ -export const CallToolRequestParamsSchema = TaskAugmentedRequestParamsSchema.extend({ +export const CallToolRequestParamsSchema = BaseRequestParamsSchema.extend({ /** * The name of the tool to call. */ @@ -1607,7 +1373,7 @@ export const SamplingMessageSchema = z.object({ /** * Parameters for a `sampling/createMessage` request. */ -export const CreateMessageRequestParamsSchema = TaskAugmentedRequestParamsSchema.extend({ +export const CreateMessageRequestParamsSchema = BaseRequestParamsSchema.extend({ messages: z.array(SamplingMessageSchema), /** * The server's preferences for which model to select. The client MAY modify or omit this request. @@ -1846,7 +1612,7 @@ export const PrimitiveSchemaDefinitionSchema = z.union([EnumSchemaSchema, Boolea /** * Parameters for an `elicitation/create` request for form-based elicitation. */ -export const ElicitRequestFormParamsSchema = TaskAugmentedRequestParamsSchema.extend({ +export const ElicitRequestFormParamsSchema = BaseRequestParamsSchema.extend({ /** * The elicitation mode. * @@ -1873,7 +1639,7 @@ export const ElicitRequestFormParamsSchema = TaskAugmentedRequestParamsSchema.ex /** * Parameters for an {@linkcode ElicitRequest | elicitation/create} request for URL-based elicitation. */ -export const ElicitRequestURLParamsSchema = TaskAugmentedRequestParamsSchema.extend({ +export const ElicitRequestURLParamsSchema = BaseRequestParamsSchema.extend({ /** * The elicitation mode. */ @@ -2089,19 +1855,14 @@ export const ClientRequestSchema = z.union([ SubscribeRequestSchema, UnsubscribeRequestSchema, CallToolRequestSchema, - ListToolsRequestSchema, - GetTaskRequestSchema, - GetTaskPayloadRequestSchema, - ListTasksRequestSchema, - CancelTaskRequestSchema + ListToolsRequestSchema ]); export const ClientNotificationSchema = z.union([ CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, - RootsListChangedNotificationSchema, - TaskStatusNotificationSchema + RootsListChangedNotificationSchema ]); export const ClientResultSchema = z.union([ @@ -2109,23 +1870,11 @@ export const ClientResultSchema = z.union([ CreateMessageResultSchema, CreateMessageResultWithToolsSchema, ElicitResultSchema, - ListRootsResultSchema, - GetTaskResultSchema, - ListTasksResultSchema, - CreateTaskResultSchema + ListRootsResultSchema ]); /* Server messages */ -export const ServerRequestSchema = z.union([ - PingRequestSchema, - CreateMessageRequestSchema, - ElicitRequestSchema, - ListRootsRequestSchema, - GetTaskRequestSchema, - GetTaskPayloadRequestSchema, - ListTasksRequestSchema, - CancelTaskRequestSchema -]); +export const ServerRequestSchema = z.union([PingRequestSchema, CreateMessageRequestSchema, ElicitRequestSchema, ListRootsRequestSchema]); export const ServerNotificationSchema = z.union([ CancelledNotificationSchema, @@ -2135,7 +1884,6 @@ export const ServerNotificationSchema = z.union([ ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, PromptListChangedNotificationSchema, - TaskStatusNotificationSchema, ElicitationCompleteNotificationSchema ]); @@ -2149,10 +1897,7 @@ export const ServerResultSchema = z.union([ ListResourceTemplatesResultSchema, ReadResourceResultSchema, CallToolResultSchema, - ListToolsResultSchema, - GetTaskResultSchema, - ListTasksResultSchema, - CreateTaskResultSchema + ListToolsResultSchema ]); /* Runtime schema lookup — result schemas by method */ @@ -2168,15 +1913,11 @@ const resultSchemas: Record = { 'resources/read': ReadResourceResultSchema, 'resources/subscribe': EmptyResultSchema, 'resources/unsubscribe': EmptyResultSchema, - 'tools/call': z.union([CallToolResultSchema, CreateTaskResultSchema]), + 'tools/call': CallToolResultSchema, 'tools/list': ListToolsResultSchema, - 'sampling/createMessage': z.union([CreateMessageResultWithToolsSchema, CreateTaskResultSchema]), - 'elicitation/create': z.union([ElicitResultSchema, CreateTaskResultSchema]), - 'roots/list': ListRootsResultSchema, - 'tasks/get': GetTaskResultSchema, - 'tasks/result': ResultSchema, - 'tasks/list': ListTasksResultSchema, - 'tasks/cancel': CancelTaskResultSchema + 'sampling/createMessage': CreateMessageResultWithToolsSchema, + 'elicitation/create': ElicitResultSchema, + 'roots/list': ListRootsResultSchema }; /** diff --git a/packages/core/src/types/specTypeSchema.ts b/packages/core/src/types/specTypeSchema.ts index 477d61a55a..8906b288de 100644 --- a/packages/core/src/types/specTypeSchema.ts +++ b/packages/core/src/types/specTypeSchema.ts @@ -41,8 +41,6 @@ const SPEC_SCHEMA_KEYS = [ 'CallToolResultSchema', 'CancelledNotificationSchema', 'CancelledNotificationParamsSchema', - 'CancelTaskRequestSchema', - 'CancelTaskResultSchema', 'ClientCapabilitiesSchema', 'ClientNotificationSchema', 'ClientRequestSchema', @@ -56,7 +54,6 @@ const SPEC_SCHEMA_KEYS = [ 'CreateMessageRequestParamsSchema', 'CreateMessageResultSchema', 'CreateMessageResultWithToolsSchema', - 'CreateTaskResultSchema', 'CursorSchema', 'ElicitationCompleteNotificationSchema', 'ElicitationCompleteNotificationParamsSchema', @@ -71,10 +68,6 @@ const SPEC_SCHEMA_KEYS = [ 'GetPromptRequestSchema', 'GetPromptRequestParamsSchema', 'GetPromptResultSchema', - 'GetTaskPayloadRequestSchema', - 'GetTaskPayloadResultSchema', - 'GetTaskRequestSchema', - 'GetTaskResultSchema', 'IconSchema', 'IconsSchema', 'ImageContentSchema', @@ -101,8 +94,6 @@ const SPEC_SCHEMA_KEYS = [ 'ListResourceTemplatesResultSchema', 'ListRootsRequestSchema', 'ListRootsResultSchema', - 'ListTasksRequestSchema', - 'ListTasksResultSchema', 'ListToolsRequestSchema', 'ListToolsResultSchema', 'LoggingLevelSchema', @@ -130,7 +121,6 @@ const SPEC_SCHEMA_KEYS = [ 'ReadResourceRequestSchema', 'ReadResourceRequestParamsSchema', 'ReadResourceResultSchema', - 'RelatedTaskMetadataSchema', 'RequestSchema', 'RequestIdSchema', 'RequestMetaSchema', @@ -160,13 +150,6 @@ const SPEC_SCHEMA_KEYS = [ 'StringSchemaSchema', 'SubscribeRequestSchema', 'SubscribeRequestParamsSchema', - 'TaskSchema', - 'TaskAugmentedRequestParamsSchema', - 'TaskCreationParamsSchema', - 'TaskMetadataSchema', - 'TaskStatusSchema', - 'TaskStatusNotificationSchema', - 'TaskStatusNotificationParamsSchema', 'TextContentSchema', 'TextResourceContentsSchema', 'TitledMultiSelectEnumSchemaSchema', diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index a92deec8e1..9a12c95c7f 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -17,8 +17,6 @@ import type { CallToolResultSchema, CancelledNotificationParamsSchema, CancelledNotificationSchema, - CancelTaskRequestSchema, - CancelTaskResultSchema, ClientCapabilitiesSchema, ClientNotificationSchema, ClientRequestSchema, @@ -32,7 +30,6 @@ import type { CreateMessageRequestSchema, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, - CreateTaskResultSchema, CursorSchema, ElicitationCompleteNotificationParamsSchema, ElicitationCompleteNotificationSchema, @@ -47,10 +44,6 @@ import type { GetPromptRequestParamsSchema, GetPromptRequestSchema, GetPromptResultSchema, - GetTaskPayloadRequestSchema, - GetTaskPayloadResultSchema, - GetTaskRequestSchema, - GetTaskResultSchema, IconSchema, IconsSchema, ImageContentSchema, @@ -74,8 +67,6 @@ import type { ListResourceTemplatesResultSchema, ListRootsRequestSchema, ListRootsResultSchema, - ListTasksRequestSchema, - ListTasksResultSchema, ListToolsRequestSchema, ListToolsResultSchema, LoggingLevelSchema, @@ -104,7 +95,6 @@ import type { ReadResourceRequestParamsSchema, ReadResourceRequestSchema, ReadResourceResultSchema, - RelatedTaskMetadataSchema, RequestIdSchema, RequestMetaSchema, RequestSchema, @@ -134,13 +124,6 @@ import type { StringSchemaSchema, SubscribeRequestParamsSchema, SubscribeRequestSchema, - TaskAugmentedRequestParamsSchema, - TaskCreationParamsSchema, - TaskMetadataSchema, - TaskSchema, - TaskStatusNotificationParamsSchema, - TaskStatusNotificationSchema, - TaskStatusSchema, TextContentSchema, TextResourceContentsSchema, TitledMultiSelectEnumSchemaSchema, @@ -187,7 +170,6 @@ type Infer = Flatten>; export type ProgressToken = Infer; export type Cursor = Infer; export type Request = Infer; -export type TaskAugmentedRequestParams = Infer; export type RequestMeta = Infer; export type Notification = Infer; export type Result = Infer; @@ -232,24 +214,6 @@ export type Progress = Infer; export type ProgressNotificationParams = Infer; export type ProgressNotification = Infer; -/* Tasks */ -export type Task = Infer; -export type TaskStatus = Infer; -export type TaskCreationParams = Infer; -export type TaskMetadata = Infer; -export type RelatedTaskMetadata = Infer; -export type CreateTaskResult = Infer; -export type TaskStatusNotificationParams = Infer; -export type TaskStatusNotification = Infer; -export type GetTaskRequest = Infer; -export type GetTaskResult = Infer; -export type GetTaskPayloadRequest = Infer; -export type ListTasksRequest = Infer; -export type ListTasksResult = Infer; -export type CancelTaskRequest = Infer; -export type CancelTaskResult = Infer; -export type GetTaskPayloadResult = Infer; - /* Pagination */ export type PaginatedRequestParams = Infer; export type PaginatedRequest = Infer; @@ -392,15 +356,11 @@ export type ResultTypeMap = { 'resources/read': ReadResourceResult; 'resources/subscribe': EmptyResult; 'resources/unsubscribe': EmptyResult; - 'tools/call': CallToolResult | CreateTaskResult; + 'tools/call': CallToolResult; 'tools/list': ListToolsResult; - 'sampling/createMessage': CreateMessageResult | CreateMessageResultWithTools | CreateTaskResult; - 'elicitation/create': ElicitResult | CreateTaskResult; + 'sampling/createMessage': CreateMessageResult | CreateMessageResultWithTools; + 'elicitation/create': ElicitResult; 'roots/list': ListRootsResult; - 'tasks/get': GetTaskResult; - 'tasks/result': Result; - 'tasks/list': ListTasksResult; - 'tasks/cancel': CancelTaskResult; }; /** diff --git a/packages/core/test/experimental/inMemory.test.ts b/packages/core/test/experimental/inMemory.test.ts deleted file mode 100644 index 7639cad9f4..0000000000 --- a/packages/core/test/experimental/inMemory.test.ts +++ /dev/null @@ -1,1035 +0,0 @@ -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; - -import type { QueuedMessage } from '../../src/experimental/tasks/interfaces.js'; -import { InMemoryTaskMessageQueue, InMemoryTaskStore } from '../../src/experimental/tasks/stores/inMemory.js'; -import type { Request, TaskCreationParams } from '../../src/types/index.js'; - -describe('InMemoryTaskStore', () => { - let store: InMemoryTaskStore; - - beforeEach(() => { - store = new InMemoryTaskStore(); - }); - - afterEach(() => { - store.cleanup(); - }); - - describe('createTask', () => { - it('should create a new task with working status', async () => { - const taskParams: TaskCreationParams = { - ttl: 60_000 - }; - const request: Request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const task = await store.createTask(taskParams, 123, request); - - expect(task).toBeDefined(); - expect(task.taskId).toBeDefined(); - expect(typeof task.taskId).toBe('string'); - expect(task.taskId.length).toBeGreaterThan(0); - expect(task.status).toBe('working'); - expect(task.ttl).toBe(60_000); - expect(task.pollInterval).toBeDefined(); - expect(task.createdAt).toBeDefined(); - expect(new Date(task.createdAt).getTime()).toBeGreaterThan(0); - }); - - it('should create task without ttl', async () => { - const taskParams: TaskCreationParams = {}; - const request: Request = { - method: 'tools/call', - params: {} - }; - - const task = await store.createTask(taskParams, 456, request); - - expect(task).toBeDefined(); - expect(task.ttl).toBeNull(); - }); - - it('should generate unique taskIds', async () => { - const taskParams: TaskCreationParams = {}; - const request: Request = { - method: 'tools/call', - params: {} - }; - - const task1 = await store.createTask(taskParams, 789, request); - const task2 = await store.createTask(taskParams, 790, request); - - expect(task1.taskId).not.toBe(task2.taskId); - }); - }); - - describe('getTask', () => { - it('should return null for non-existent task', async () => { - const task = await store.getTask('non-existent'); - expect(task).toBeNull(); - }); - - it('should return task state', async () => { - const taskParams: TaskCreationParams = {}; - const request: Request = { - method: 'tools/call', - params: {} - }; - - const createdTask = await store.createTask(taskParams, 111, request); - await store.updateTaskStatus(createdTask.taskId, 'working'); - - const task = await store.getTask(createdTask.taskId); - expect(task).toBeDefined(); - expect(task?.status).toBe('working'); - }); - }); - - describe('updateTaskStatus', () => { - let taskId: string; - - beforeEach(async () => { - const taskParams: TaskCreationParams = {}; - const createdTask = await store.createTask(taskParams, 222, { - method: 'tools/call', - params: {} - }); - taskId = createdTask.taskId; - }); - - it('should keep task status as working', async () => { - const task = await store.getTask(taskId); - expect(task?.status).toBe('working'); - }); - - it('should update task status to input_required', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('input_required'); - }); - - it('should update task status to completed', async () => { - await store.updateTaskStatus(taskId, 'completed'); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('completed'); - }); - - it('should update task status to failed with error', async () => { - await store.updateTaskStatus(taskId, 'failed', 'Something went wrong'); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('failed'); - expect(task?.statusMessage).toBe('Something went wrong'); - }); - - it('should update task status to cancelled', async () => { - await store.updateTaskStatus(taskId, 'cancelled'); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('cancelled'); - }); - - it('should throw if task not found', async () => { - await expect(store.updateTaskStatus('non-existent', 'working')).rejects.toThrow('Task with ID non-existent not found'); - }); - - describe('status lifecycle validation', () => { - it('should allow transition from working to input_required', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('input_required'); - }); - - it('should allow transition from working to completed', async () => { - await store.updateTaskStatus(taskId, 'completed'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('completed'); - }); - - it('should allow transition from working to failed', async () => { - await store.updateTaskStatus(taskId, 'failed'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('failed'); - }); - - it('should allow transition from working to cancelled', async () => { - await store.updateTaskStatus(taskId, 'cancelled'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('cancelled'); - }); - - it('should allow transition from input_required to working', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - await store.updateTaskStatus(taskId, 'working'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('working'); - }); - - it('should allow transition from input_required to completed', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - await store.updateTaskStatus(taskId, 'completed'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('completed'); - }); - - it('should allow transition from input_required to failed', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - await store.updateTaskStatus(taskId, 'failed'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('failed'); - }); - - it('should allow transition from input_required to cancelled', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - await store.updateTaskStatus(taskId, 'cancelled'); - const task = await store.getTask(taskId); - expect(task?.status).toBe('cancelled'); - }); - - it('should reject transition from completed to any other status', async () => { - await store.updateTaskStatus(taskId, 'completed'); - await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'failed')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow('Cannot update task'); - }); - - it('should reject transition from failed to any other status', async () => { - await store.updateTaskStatus(taskId, 'failed'); - await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'completed')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'cancelled')).rejects.toThrow('Cannot update task'); - }); - - it('should reject transition from cancelled to any other status', async () => { - await store.updateTaskStatus(taskId, 'cancelled'); - await expect(store.updateTaskStatus(taskId, 'working')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'input_required')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'completed')).rejects.toThrow('Cannot update task'); - await expect(store.updateTaskStatus(taskId, 'failed')).rejects.toThrow('Cannot update task'); - }); - }); - }); - - describe('storeTaskResult', () => { - let taskId: string; - - beforeEach(async () => { - const taskParams: TaskCreationParams = { - ttl: 60_000 - }; - const createdTask = await store.createTask(taskParams, 333, { - method: 'tools/call', - params: {} - }); - taskId = createdTask.taskId; - }); - - it('should store task result and set status to completed', async () => { - const result = { - content: [{ type: 'text' as const, text: 'Success!' }] - }; - - await store.storeTaskResult(taskId, 'completed', result); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('completed'); - - const storedResult = await store.getTaskResult(taskId); - expect(storedResult).toStrictEqual(result); - }); - - it('should throw if task not found', async () => { - await expect(store.storeTaskResult('non-existent', 'completed', {})).rejects.toThrow('Task with ID non-existent not found'); - }); - - it('should reject storing result for task already in completed status', async () => { - // First complete the task - const firstResult = { - content: [{ type: 'text' as const, text: 'First result' }] - }; - await store.storeTaskResult(taskId, 'completed', firstResult); - - // Try to store result again (should fail) - const secondResult = { - content: [{ type: 'text' as const, text: 'Second result' }] - }; - - await expect(store.storeTaskResult(taskId, 'completed', secondResult)).rejects.toThrow('Cannot store result for task'); - }); - - it('should store result with failed status', async () => { - const result = { - content: [{ type: 'text' as const, text: 'Error details' }], - isError: true - }; - - await store.storeTaskResult(taskId, 'failed', result); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('failed'); - - const storedResult = await store.getTaskResult(taskId); - expect(storedResult).toStrictEqual(result); - }); - - it('should reject storing result for task already in failed status', async () => { - // First fail the task - const firstResult = { - content: [{ type: 'text' as const, text: 'First error' }], - isError: true - }; - await store.storeTaskResult(taskId, 'failed', firstResult); - - // Try to store result again (should fail) - const secondResult = { - content: [{ type: 'text' as const, text: 'Second error' }], - isError: true - }; - - await expect(store.storeTaskResult(taskId, 'failed', secondResult)).rejects.toThrow('Cannot store result for task'); - }); - - it('should reject storing result for cancelled task', async () => { - // Mark task as cancelled - await store.updateTaskStatus(taskId, 'cancelled'); - - // Try to store result (should fail) - const result = { - content: [{ type: 'text' as const, text: 'Cancellation result' }] - }; - - await expect(store.storeTaskResult(taskId, 'completed', result)).rejects.toThrow('Cannot store result for task'); - }); - - it('should allow storing result from input_required status', async () => { - await store.updateTaskStatus(taskId, 'input_required'); - - const result = { - content: [{ type: 'text' as const, text: 'Success!' }] - }; - - await store.storeTaskResult(taskId, 'completed', result); - - const task = await store.getTask(taskId); - expect(task?.status).toBe('completed'); - }); - }); - - describe('getTaskResult', () => { - it('should throw if task not found', async () => { - await expect(store.getTaskResult('non-existent')).rejects.toThrow('Task with ID non-existent not found'); - }); - - it('should throw if task has no result stored', async () => { - const taskParams: TaskCreationParams = {}; - const createdTask = await store.createTask(taskParams, 444, { - method: 'tools/call', - params: {} - }); - - await expect(store.getTaskResult(createdTask.taskId)).rejects.toThrow(`Task ${createdTask.taskId} has no result stored`); - }); - - it('should return stored result', async () => { - const taskParams: TaskCreationParams = {}; - const createdTask = await store.createTask(taskParams, 555, { - method: 'tools/call', - params: {} - }); - - const result = { - content: [{ type: 'text' as const, text: 'Result data' }] - }; - await store.storeTaskResult(createdTask.taskId, 'completed', result); - - const retrieved = await store.getTaskResult(createdTask.taskId); - expect(retrieved).toStrictEqual(result); - }); - }); - - describe('ttl cleanup', () => { - beforeEach(() => { - vi.useFakeTimers(); - }); - - afterEach(() => { - vi.useRealTimers(); - }); - - it('should cleanup task after ttl duration', async () => { - const taskParams: TaskCreationParams = { - ttl: 1000 - }; - const createdTask = await store.createTask(taskParams, 666, { - method: 'tools/call', - params: {} - }); - - // Task should exist initially - let task = await store.getTask(createdTask.taskId); - expect(task).toBeDefined(); - - // Fast-forward past ttl - vi.advanceTimersByTime(1001); - - // Task should be cleaned up - task = await store.getTask(createdTask.taskId); - expect(task).toBeNull(); - }); - - it('should reset cleanup timer when result is stored', async () => { - const taskParams: TaskCreationParams = { - ttl: 1000 - }; - const createdTask = await store.createTask(taskParams, 777, { - method: 'tools/call', - params: {} - }); - - // Fast-forward 500ms - vi.advanceTimersByTime(500); - - // Store result (should reset timer) - await store.storeTaskResult(createdTask.taskId, 'completed', { - content: [{ type: 'text' as const, text: 'Done' }] - }); - - // Fast-forward another 500ms (total 1000ms since creation, but timer was reset) - vi.advanceTimersByTime(500); - - // Task should still exist - const task = await store.getTask(createdTask.taskId); - expect(task).toBeDefined(); - - // Fast-forward remaining time - vi.advanceTimersByTime(501); - - // Now task should be cleaned up - const cleanedTask = await store.getTask(createdTask.taskId); - expect(cleanedTask).toBeNull(); - }); - - it('should not cleanup tasks without ttl', async () => { - const taskParams: TaskCreationParams = {}; - const createdTask = await store.createTask(taskParams, 888, { - method: 'tools/call', - params: {} - }); - - // Fast-forward a long time - vi.advanceTimersByTime(100_000); - - // Task should still exist - const task = await store.getTask(createdTask.taskId); - expect(task).toBeDefined(); - }); - - it('should start cleanup timer when task reaches terminal state', async () => { - const taskParams: TaskCreationParams = { - ttl: 1000 - }; - const createdTask = await store.createTask(taskParams, 999, { - method: 'tools/call', - params: {} - }); - - // Task in non-terminal state, fast-forward - vi.advanceTimersByTime(1001); - - // Task should be cleaned up - let task = await store.getTask(createdTask.taskId); - expect(task).toBeNull(); - - // Create another task - const taskParams2: TaskCreationParams = { - ttl: 2000 - }; - const createdTask2 = await store.createTask(taskParams2, 1000, { - method: 'tools/call', - params: {} - }); - - // Update to terminal state - await store.updateTaskStatus(createdTask2.taskId, 'completed'); - - // Fast-forward past original ttl - vi.advanceTimersByTime(2001); - - // Task should be cleaned up - task = await store.getTask(createdTask2.taskId); - expect(task).toBeNull(); - }); - - it('should return actual TTL in task response', async () => { - // Test that the TaskStore returns the actual TTL it will use - // This implementation uses the requested TTL as-is, but implementations - // MAY override it (e.g., enforce maximum TTL limits) - const requestedTtl = 5000; - const taskParams: TaskCreationParams = { - ttl: requestedTtl - }; - const createdTask = await store.createTask(taskParams, 1111, { - method: 'tools/call', - params: {} - }); - - // The returned task should include the actual TTL that will be used - expect(createdTask.ttl).toBe(requestedTtl); - - // Verify the task is cleaned up after the actual TTL - vi.advanceTimersByTime(requestedTtl + 1); - const task = await store.getTask(createdTask.taskId); - expect(task).toBeNull(); - }); - - it('should support omitted TTL for unlimited lifetime', async () => { - // Test that omitting TTL means unlimited lifetime (server returns null) - // Per spec: clients omit ttl to let server decide, server returns null for unlimited - const taskParams: TaskCreationParams = {}; - const createdTask = await store.createTask(taskParams, 2222, { - method: 'tools/call', - params: {} - }); - - // The returned task should have null TTL (unlimited) - expect(createdTask.ttl).toBeNull(); - - // Task should not be cleaned up even after a long time - vi.advanceTimersByTime(100_000); - const task = await store.getTask(createdTask.taskId); - expect(task).toBeDefined(); - expect(task?.taskId).toBe(createdTask.taskId); - }); - - it('should cleanup tasks regardless of status', async () => { - // Test that TTL cleanup happens regardless of task status - const taskParams: TaskCreationParams = { - ttl: 1000 - }; - - // Create tasks in different statuses - const workingTask = await store.createTask(taskParams, 3333, { - method: 'tools/call', - params: {} - }); - - const completedTask = await store.createTask(taskParams, 4444, { - method: 'tools/call', - params: {} - }); - await store.storeTaskResult(completedTask.taskId, 'completed', { - content: [{ type: 'text' as const, text: 'Done' }] - }); - - const failedTask = await store.createTask(taskParams, 5555, { - method: 'tools/call', - params: {} - }); - await store.storeTaskResult(failedTask.taskId, 'failed', { - content: [{ type: 'text' as const, text: 'Error' }] - }); - - // Fast-forward past TTL - vi.advanceTimersByTime(1001); - - // All tasks should be cleaned up regardless of status - expect(await store.getTask(workingTask.taskId)).toBeNull(); - expect(await store.getTask(completedTask.taskId)).toBeNull(); - expect(await store.getTask(failedTask.taskId)).toBeNull(); - }); - }); - - describe('getAllTasks', () => { - it('should return all tasks', async () => { - await store.createTask({}, 1, { - method: 'tools/call', - params: {} - }); - await store.createTask({}, 2, { - method: 'tools/call', - params: {} - }); - await store.createTask({}, 3, { - method: 'tools/call', - params: {} - }); - - const tasks = store.getAllTasks(); - expect(tasks).toHaveLength(3); - // Verify all tasks have unique IDs - const taskIds = tasks.map(t => t.taskId); - expect(new Set(taskIds).size).toBe(3); - }); - - it('should return empty array when no tasks', () => { - const tasks = store.getAllTasks(); - expect(tasks).toStrictEqual([]); - }); - }); - - describe('listTasks', () => { - it('should return empty list when no tasks', async () => { - const result = await store.listTasks(); - expect(result.tasks).toStrictEqual([]); - expect(result.nextCursor).toBeUndefined(); - }); - - it('should return all tasks when less than page size', async () => { - await store.createTask({}, 1, { - method: 'tools/call', - params: {} - }); - await store.createTask({}, 2, { - method: 'tools/call', - params: {} - }); - await store.createTask({}, 3, { - method: 'tools/call', - params: {} - }); - - const result = await store.listTasks(); - expect(result.tasks).toHaveLength(3); - expect(result.nextCursor).toBeUndefined(); - }); - - it('should paginate when more than page size', async () => { - // Create 15 tasks (page size is 10) - for (let i = 1; i <= 15; i++) { - await store.createTask({}, i, { - method: 'tools/call', - params: {} - }); - } - - // Get first page - const page1 = await store.listTasks(); - expect(page1.tasks).toHaveLength(10); - expect(page1.nextCursor).toBeDefined(); - - // Get second page using cursor - const page2 = await store.listTasks(page1.nextCursor); - expect(page2.tasks).toHaveLength(5); - expect(page2.nextCursor).toBeUndefined(); - }); - - it('should throw error for invalid cursor', async () => { - await store.createTask({}, 1, { - method: 'tools/call', - params: {} - }); - - await expect(store.listTasks('non-existent-cursor')).rejects.toThrow('Invalid cursor: non-existent-cursor'); - }); - - it('should continue from cursor correctly', async () => { - // Create 5 tasks - for (let i = 1; i <= 5; i++) { - await store.createTask({}, i, { - method: 'tools/call', - params: {} - }); - } - - // Get first 3 tasks - const allTaskIds = store.getAllTasks().map(t => t.taskId); - const result = await store.listTasks(allTaskIds[2]); - - // Should get tasks after the third task - expect(result.tasks).toHaveLength(2); - }); - }); - - describe('session isolation', () => { - const baseRequest: Request = { method: 'tools/call', params: { name: 'demo' } }; - - it('should not allow session-b to list tasks created by session-a', async () => { - await store.createTask({}, 1, baseRequest, 'session-a'); - await store.createTask({}, 2, baseRequest, 'session-a'); - - const result = await store.listTasks(undefined, 'session-b'); - expect(result.tasks).toHaveLength(0); - }); - - it('should not allow session-b to read a task created by session-a', async () => { - const task = await store.createTask({}, 1, baseRequest, 'session-a'); - - const result = await store.getTask(task.taskId, 'session-b'); - expect(result).toBeNull(); - }); - - it('should not allow session-b to update a task created by session-a', async () => { - const task = await store.createTask({}, 1, baseRequest, 'session-a'); - - await expect(store.updateTaskStatus(task.taskId, 'cancelled', undefined, 'session-b')).rejects.toThrow('not found'); - }); - - it('should not allow session-b to store a result on session-a task', async () => { - const task = await store.createTask({}, 1, baseRequest, 'session-a'); - - await expect(store.storeTaskResult(task.taskId, 'completed', { content: [] }, 'session-b')).rejects.toThrow('not found'); - }); - - it('should not allow session-b to get the result of session-a task', async () => { - const task = await store.createTask({}, 1, baseRequest, 'session-a'); - await store.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'secret' }] }, 'session-a'); - - await expect(store.getTaskResult(task.taskId, 'session-b')).rejects.toThrow('not found'); - }); - - it('should allow the owning session to access its own tasks', async () => { - const task = await store.createTask({}, 1, baseRequest, 'session-a'); - - const retrieved = await store.getTask(task.taskId, 'session-a'); - expect(retrieved).toBeDefined(); - expect(retrieved?.taskId).toBe(task.taskId); - }); - - it('should list only tasks belonging to the requesting session', async () => { - await store.createTask({}, 1, baseRequest, 'session-a'); - await store.createTask({}, 2, baseRequest, 'session-b'); - await store.createTask({}, 3, baseRequest, 'session-a'); - - const resultA = await store.listTasks(undefined, 'session-a'); - expect(resultA.tasks).toHaveLength(2); - - const resultB = await store.listTasks(undefined, 'session-b'); - expect(resultB.tasks).toHaveLength(1); - }); - - it('should allow access when no sessionId is provided (backward compatibility)', async () => { - const task = await store.createTask({}, 1, baseRequest, 'session-a'); - - // No sessionId on read = no filtering - const retrieved = await store.getTask(task.taskId); - expect(retrieved).toBeDefined(); - }); - - it('should allow access when task was created without sessionId', async () => { - const task = await store.createTask({}, 1, baseRequest); - - // Any sessionId on read should still see the task - const retrieved = await store.getTask(task.taskId, 'session-b'); - expect(retrieved).toBeDefined(); - }); - - it('should paginate correctly within a session', async () => { - // Create 15 tasks for session-a, 5 for session-b - for (let i = 1; i <= 15; i++) { - await store.createTask({}, i, baseRequest, 'session-a'); - } - for (let i = 16; i <= 20; i++) { - await store.createTask({}, i, baseRequest, 'session-b'); - } - - // First page for session-a should have 10 - const page1 = await store.listTasks(undefined, 'session-a'); - expect(page1.tasks).toHaveLength(10); - expect(page1.nextCursor).toBeDefined(); - - // Second page for session-a should have 5 - const page2 = await store.listTasks(page1.nextCursor, 'session-a'); - expect(page2.tasks).toHaveLength(5); - expect(page2.nextCursor).toBeUndefined(); - - // session-b should only see its 5 - const resultB = await store.listTasks(undefined, 'session-b'); - expect(resultB.tasks).toHaveLength(5); - expect(resultB.nextCursor).toBeUndefined(); - }); - }); - - describe('cleanup', () => { - it('should clear all timers and tasks', async () => { - await store.createTask({ ttl: 1000 }, 1, { - method: 'tools/call', - params: {} - }); - await store.createTask({ ttl: 2000 }, 2, { - method: 'tools/call', - params: {} - }); - - expect(store.getAllTasks()).toHaveLength(2); - - store.cleanup(); - - expect(store.getAllTasks()).toHaveLength(0); - }); - }); -}); - -describe('InMemoryTaskMessageQueue', () => { - let queue: InMemoryTaskMessageQueue; - - beforeEach(() => { - queue = new InMemoryTaskMessageQueue(); - }); - - describe('enqueue and dequeue', () => { - it('should enqueue and dequeue request messages', async () => { - const requestMessage: QueuedMessage = { - type: 'request', - message: { - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }, - timestamp: Date.now() - }; - - await queue.enqueue('task-1', requestMessage); - const dequeued = await queue.dequeue('task-1'); - - expect(dequeued).toStrictEqual(requestMessage); - }); - - it('should enqueue and dequeue notification messages', async () => { - const notificationMessage: QueuedMessage = { - type: 'notification', - message: { - jsonrpc: '2.0', - method: 'notifications/progress', - params: { progress: 50, total: 100 } - }, - timestamp: Date.now() - }; - - await queue.enqueue('task-2', notificationMessage); - const dequeued = await queue.dequeue('task-2'); - - expect(dequeued).toStrictEqual(notificationMessage); - }); - - it('should enqueue and dequeue response messages', async () => { - const responseMessage: QueuedMessage = { - type: 'response', - message: { - jsonrpc: '2.0', - id: 42, - result: { content: [{ type: 'text', text: 'Success' }] } - }, - timestamp: Date.now() - }; - - await queue.enqueue('task-3', responseMessage); - const dequeued = await queue.dequeue('task-3'); - - expect(dequeued).toStrictEqual(responseMessage); - }); - - it('should return undefined when dequeuing from empty queue', async () => { - const dequeued = await queue.dequeue('task-empty'); - expect(dequeued).toBeUndefined(); - }); - - it('should maintain FIFO order for mixed message types', async () => { - const request: QueuedMessage = { - type: 'request', - message: { - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: {} - }, - timestamp: 1000 - }; - - const notification: QueuedMessage = { - type: 'notification', - message: { - jsonrpc: '2.0', - method: 'notifications/progress', - params: {} - }, - timestamp: 2000 - }; - - const response: QueuedMessage = { - type: 'response', - message: { - jsonrpc: '2.0', - id: 1, - result: {} - }, - timestamp: 3000 - }; - - await queue.enqueue('task-fifo', request); - await queue.enqueue('task-fifo', notification); - await queue.enqueue('task-fifo', response); - - expect(await queue.dequeue('task-fifo')).toStrictEqual(request); - expect(await queue.dequeue('task-fifo')).toStrictEqual(notification); - expect(await queue.dequeue('task-fifo')).toStrictEqual(response); - expect(await queue.dequeue('task-fifo')).toBeUndefined(); - }); - }); - - describe('dequeueAll', () => { - it('should dequeue all messages including responses', async () => { - const request: QueuedMessage = { - type: 'request', - message: { - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: {} - }, - timestamp: 1000 - }; - - const response: QueuedMessage = { - type: 'response', - message: { - jsonrpc: '2.0', - id: 1, - result: {} - }, - timestamp: 2000 - }; - - const notification: QueuedMessage = { - type: 'notification', - message: { - jsonrpc: '2.0', - method: 'notifications/progress', - params: {} - }, - timestamp: 3000 - }; - - await queue.enqueue('task-all', request); - await queue.enqueue('task-all', response); - await queue.enqueue('task-all', notification); - - const all = await queue.dequeueAll('task-all'); - - expect(all).toHaveLength(3); - expect(all[0]).toStrictEqual(request); - expect(all[1]).toStrictEqual(response); - expect(all[2]).toStrictEqual(notification); - }); - - it('should return empty array for non-existent task', async () => { - const all = await queue.dequeueAll('non-existent'); - expect(all).toStrictEqual([]); - }); - - it('should clear the queue after dequeueAll', async () => { - const message: QueuedMessage = { - type: 'request', - message: { - jsonrpc: '2.0', - id: 1, - method: 'test', - params: {} - }, - timestamp: Date.now() - }; - - await queue.enqueue('task-clear', message); - await queue.dequeueAll('task-clear'); - - const dequeued = await queue.dequeue('task-clear'); - expect(dequeued).toBeUndefined(); - }); - }); - - describe('queue size limits', () => { - it('should throw when maxSize is exceeded', async () => { - const message: QueuedMessage = { - type: 'request', - message: { - jsonrpc: '2.0', - id: 1, - method: 'test', - params: {} - }, - timestamp: Date.now() - }; - - await queue.enqueue('task-limit', message, undefined, 2); - await queue.enqueue('task-limit', message, undefined, 2); - - await expect(queue.enqueue('task-limit', message, undefined, 2)).rejects.toThrow('Task message queue overflow'); - }); - - it('should allow enqueue when under maxSize', async () => { - const message: QueuedMessage = { - type: 'response', - message: { - jsonrpc: '2.0', - id: 1, - result: {} - }, - timestamp: Date.now() - }; - - await expect(queue.enqueue('task-ok', message, undefined, 5)).resolves.toBeUndefined(); - }); - }); - - describe('task isolation', () => { - it('should isolate messages between different tasks', async () => { - const message1: QueuedMessage = { - type: 'request', - message: { - jsonrpc: '2.0', - id: 1, - method: 'test1', - params: {} - }, - timestamp: 1000 - }; - - const message2: QueuedMessage = { - type: 'response', - message: { - jsonrpc: '2.0', - id: 2, - result: {} - }, - timestamp: 2000 - }; - - await queue.enqueue('task-a', message1); - await queue.enqueue('task-b', message2); - - expect(await queue.dequeue('task-a')).toStrictEqual(message1); - expect(await queue.dequeue('task-b')).toStrictEqual(message2); - expect(await queue.dequeue('task-a')).toBeUndefined(); - expect(await queue.dequeue('task-b')).toBeUndefined(); - }); - }); - - describe('response message error handling', () => { - it('should handle response messages with errors', async () => { - const errorResponse: QueuedMessage = { - type: 'error', - message: { - jsonrpc: '2.0', - id: 1, - error: { - code: -32_600, - message: 'Invalid Request' - } - }, - timestamp: Date.now() - }; - - await queue.enqueue('task-error', errorResponse); - const dequeued = await queue.dequeue('task-error'); - - expect(dequeued).toStrictEqual(errorResponse); - expect(dequeued?.type).toBe('error'); - }); - }); -}); diff --git a/packages/core/test/shared/customMethods.test.ts b/packages/core/test/shared/customMethods.test.ts index 47e02c9bca..df49dd1dbc 100644 --- a/packages/core/test/shared/customMethods.test.ts +++ b/packages/core/test/shared/customMethods.test.ts @@ -2,20 +2,33 @@ import { describe, expect, it } from 'vitest'; import { z } from 'zod/v4'; import { Protocol } from '../../src/shared/protocol.js'; +import { HandlerRegistry } from '../../src/shared/handlerRegistry.js'; import type { BaseContext, JSONRPCRequest, Result, StandardSchemaV1 } from '../../src/exports/public/index.js'; import { ProtocolError } from '../../src/types/index.js'; import { SdkErrorCode } from '../../src/errors/sdkErrors.js'; import { InMemoryTransport } from '../../src/util/inMemory.js'; +function createTestRegistry( + wrapHandler?: ( + method: string, + handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise + ) => (request: JSONRPCRequest, ctx: BaseContext) => Promise +): HandlerRegistry { + return new HandlerRegistry({ + wrapHandler + }); +} + class TestProtocol extends Protocol { protected buildContext(ctx: BaseContext): BaseContext { return ctx; } protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} + + constructor(registry?: HandlerRegistry) { + super(registry ?? createTestRegistry()); + } } async function pair(): Promise<[TestProtocol, TestProtocol]> { @@ -81,18 +94,15 @@ describe('Protocol custom-method support', () => { expect(() => p.setRequestHandler('acme/unknown' as never, () => ({}) as never)).toThrow(TypeError); }); - it('routes both 2-arg and 3-arg registration through _wrapHandler', () => { + it('routes both 2-arg and 3-arg registration through wrapHandler callback', () => { const seen: string[] = []; - class SpyProtocol extends TestProtocol { - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise - ): (request: JSONRPCRequest, ctx: BaseContext) => Promise { - seen.push(method); - return handler; - } - } - const p = new SpyProtocol(); + const registry = createTestRegistry((method, handler) => { + seen.push(method); + return handler; + }); + const p = new TestProtocol(registry); + // Clear entries added by Protocol constructor (ping) + seen.length = 0; p.setRequestHandler('tools/list', () => ({ tools: [] })); p.setRequestHandler('acme/custom', { params: z.object({}) }, () => ({})); expect(seen).toContain('tools/list'); diff --git a/packages/core/test/shared/handlerRegistry.test.ts b/packages/core/test/shared/handlerRegistry.test.ts new file mode 100644 index 0000000000..773bb791f8 --- /dev/null +++ b/packages/core/test/shared/handlerRegistry.test.ts @@ -0,0 +1,93 @@ +import { describe, expect, it, vi } from 'vitest'; +import type { BaseContext } from '../../src/shared/protocol.js'; +import type { RequestHandler } from '../../src/shared/handlerRegistry.js'; +import { HandlerRegistry } from '../../src/shared/handlerRegistry.js'; +import type { JSONRPCRequest, ServerCapabilities } from '../../src/types/index.js'; + +function createRegistry(options?: ConstructorParameters>[0]) { + return new HandlerRegistry(options); +} + +const noopHandler = async () => ({}); + +describe('HandlerRegistry', () => { + it('should register and retrieve a spec request handler', () => { + const registry = createRegistry(); + registry.setRequestHandler('ping', noopHandler); + expect(registry.requestHandlers.has('ping')).toBe(true); + }); + + it('should call assertRequestHandlerCapability callback during registration', () => { + const assertCb = vi.fn(); + const registry = createRegistry({ assertRequestHandlerCapability: assertCb }); + registry.setRequestHandler('ping', noopHandler); + expect(assertCb).toHaveBeenCalledWith('ping'); + }); + + it('should apply wrapHandler callback during registration', () => { + const wrappedHandler: RequestHandler = async () => ({ wrapped: true }); + const wrapCb = vi.fn((_method: string, _handler: RequestHandler) => wrappedHandler); + + const registry = createRegistry({ wrapHandler: wrapCb }); + registry.setRequestHandler('ping', noopHandler); + + expect(wrapCb).toHaveBeenCalledWith('ping', expect.any(Function)); + expect(registry.requestHandlers.get('ping')).toBe(wrappedHandler); + }); + + it('should throw from assertCanSetRequestHandler on duplicate handler', () => { + const registry = createRegistry(); + registry.setRequestHandler('ping', noopHandler); + + expect(() => registry.assertCanSetRequestHandler('ping')).toThrow('A request handler for ping already exists'); + }); + + it('should remove a request handler', () => { + const registry = createRegistry(); + registry.setRequestHandler('ping', noopHandler); + expect(registry.requestHandlers.has('ping')).toBe(true); + + registry.removeRequestHandler('ping'); + expect(registry.requestHandlers.has('ping')).toBe(false); + }); + + it('should merge capabilities via registerCapabilities', () => { + const registry = createRegistry({ capabilities: { tools: {} } }); + registry.registerCapabilities({ logging: {} }); + + const caps = registry.getCapabilities(); + expect(caps.tools).toEqual({}); + expect(caps.logging).toEqual({}); + }); + + it('should register and retrieve a notification handler', () => { + const registry = createRegistry(); + const handler = async () => {}; + registry.setNotificationHandler('notifications/cancelled', handler); + + expect(registry.notificationHandlers.has('notifications/cancelled')).toBe(true); + }); + + it('should remove a notification handler', () => { + const registry = createRegistry(); + registry.setNotificationHandler('notifications/cancelled', async () => {}); + expect(registry.notificationHandlers.has('notifications/cancelled')).toBe(true); + + registry.removeNotificationHandler('notifications/cancelled'); + expect(registry.notificationHandlers.has('notifications/cancelled')).toBe(false); + }); + + it('should store and retrieve fallbackRequestHandler', () => { + const registry = createRegistry(); + const fallback: RequestHandler = async (_req: JSONRPCRequest) => ({ fallback: true }); + + registry.fallbackRequestHandler = fallback; + expect(registry.fallbackRequestHandler).toBe(fallback); + }); + + it('should return initial capabilities via getCapabilities', () => { + const registry = createRegistry({ capabilities: { prompts: {} } }); + const caps = registry.getCapabilities(); + expect(caps.prompts).toEqual({}); + }); +}); diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 619e09376a..552980ea26 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -3,68 +3,41 @@ import { vi } from 'vitest'; import * as z from 'zod/v4'; import type { ZodType } from 'zod/v4'; -import type { - QueuedMessage, - QueuedNotification, - QueuedRequest, - TaskMessageQueue, - TaskStore -} from '../../src/experimental/tasks/interfaces.js'; -import { InMemoryTaskMessageQueue } from '../../src/experimental/tasks/stores/inMemory.js'; import type { BaseContext } from '../../src/shared/protocol.js'; -import { mergeCapabilities, Protocol } from '../../src/shared/protocol.js'; -import type { ErrorMessage, ResponseMessage } from '../../src/shared/responseMessage.js'; -import { toArrayAsync } from '../../src/shared/responseMessage.js'; -import type { TaskManagerOptions } from '../../src/shared/taskManager.js'; -import { NullTaskManager, TaskManager } from '../../src/shared/taskManager.js'; +import type { ProtocolOptions } from '../../src/shared/protocol.js'; +import { Protocol } from '../../src/shared/protocol.js'; +import { HandlerRegistry, mergeCapabilities } from '../../src/shared/handlerRegistry.js'; import type { Transport, TransportSendOptions } from '../../src/shared/transport.js'; import type { ClientCapabilities, - JSONRPCErrorResponse, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, - JSONRPCResponse, JSONRPCResultResponse, Notification, Request, RequestId, Result, - ServerCapabilities, - Task, - TaskCreationParams + ServerCapabilities } from '../../src/types/index.js'; -import { ProtocolError, ProtocolErrorCode, RELATED_TASK_META_KEY } from '../../src/types/index.js'; +import { ProtocolError, ProtocolErrorCode } from '../../src/types/index.js'; import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; // Test Protocol subclass for testing class TestProtocolImpl extends Protocol { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } -} -function createTestProtocol(taskOptions?: TaskManagerOptions): TestProtocolImpl { - return new TestProtocolImpl(taskOptions ? { tasks: taskOptions } : undefined); + constructor(options?: ProtocolOptions) { + super(new HandlerRegistry(), options); + } } -// Type helper for accessing private/protected Protocol properties in tests -interface TestProtocolInternals { - _responseHandlers: Map void>; - _taskManager: { - _taskMessageQueue?: TaskMessageQueue; - _requestResolvers: Map void>; - _taskProgressTokens: Map; - _clearTaskQueue: (taskId: string, sessionId?: string) => Promise; - listTasks: (params?: { cursor?: string }) => Promise<{ tasks: Task[]; nextCursor?: string }>; - cancelTask: (params: { taskId: string }) => Promise; - requestStream: (request: Request, schema: ZodType, options?: unknown) => AsyncGenerator>; - }; +function createTestProtocol(): TestProtocolImpl { + return new TestProtocolImpl(); } // Mock Transport class @@ -80,95 +53,6 @@ class MockTransport implements Transport { async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise {} } -function createMockTaskStore(options?: { - onStatus?: (status: Task['status']) => void; - onList?: () => void; -}): TaskStore & { [K in keyof TaskStore]: MockInstance } { - const tasks: Record = {}; - return { - createTask: vi.fn((taskParams: TaskCreationParams, _1: RequestId, _2: Request) => { - // Generate a unique task ID - const taskId = `test-task-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; - const createdAt = new Date().toISOString(); - const task = (tasks[taskId] = { - taskId, - status: 'working', - ttl: taskParams.ttl ?? null, - createdAt, - lastUpdatedAt: createdAt, - pollInterval: taskParams.pollInterval ?? 1000 - }); - options?.onStatus?.('working'); - return Promise.resolve(task); - }), - getTask: vi.fn((taskId: string) => { - return Promise.resolve(tasks[taskId] ?? null); - }), - updateTaskStatus: vi.fn((taskId, status, statusMessage) => { - const task = tasks[taskId]; - if (task) { - task.status = status; - task.statusMessage = statusMessage; - options?.onStatus?.(task.status); - } - return Promise.resolve(); - }), - storeTaskResult: vi.fn((taskId: string, status: 'completed' | 'failed', result: Result) => { - const task = tasks[taskId]; - if (task) { - task.status = status; - task.result = result; - options?.onStatus?.(status); - } - return Promise.resolve(); - }), - getTaskResult: vi.fn((taskId: string) => { - const task = tasks[taskId]; - if (task?.result) { - return Promise.resolve(task.result); - } - throw new Error('Task result not found'); - }), - listTasks: vi.fn(() => { - const result = { - tasks: Object.values(tasks) - }; - options?.onList?.(); - return Promise.resolve(result); - }) - }; -} - -function createLatch() { - let latch = false; - const waitForLatch = async () => { - while (!latch) { - await new Promise(resolve => setTimeout(resolve, 0)); - } - }; - - return { - releaseLatch: () => { - latch = true; - }, - waitForLatch - }; -} - -function assertErrorResponse(o: ResponseMessage): asserts o is ErrorMessage { - expect(o.type).toBe('error'); -} - -function assertQueuedNotification(o?: QueuedMessage): asserts o is QueuedNotification { - expect(o).toBeDefined(); - expect(o?.type).toBe('notification'); -} - -function assertQueuedRequest(o?: QueuedMessage): asserts o is QueuedRequest { - expect(o).toBeDefined(); - expect(o?.type).toBe('request'); -} - /** * Helper to call the protected _requestWithSchema method from tests that * use custom method names not present in RequestMethod. @@ -887,97 +771,6 @@ describe('protocol tests', () => { }); }); -describe('InMemoryTaskMessageQueue', () => { - let queue: TaskMessageQueue; - const taskId = 'test-task-id'; - - beforeEach(() => { - queue = new InMemoryTaskMessageQueue(); - }); - - describe('enqueue/dequeue maintains FIFO order', () => { - it('should maintain FIFO order for multiple messages', async () => { - const msg1 = { - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test1' }, - timestamp: 1 - }; - const msg2 = { - type: 'request' as const, - message: { jsonrpc: '2.0' as const, id: 1, method: 'test2' }, - timestamp: 2 - }; - const msg3 = { - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test3' }, - timestamp: 3 - }; - - await queue.enqueue(taskId, msg1); - await queue.enqueue(taskId, msg2); - await queue.enqueue(taskId, msg3); - - expect(await queue.dequeue(taskId)).toEqual(msg1); - expect(await queue.dequeue(taskId)).toEqual(msg2); - expect(await queue.dequeue(taskId)).toEqual(msg3); - }); - - it('should return undefined when dequeuing from empty queue', async () => { - expect(await queue.dequeue(taskId)).toBeUndefined(); - }); - }); - - describe('dequeueAll operation', () => { - it('should return all messages in FIFO order', async () => { - const msg1 = { - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test1' }, - timestamp: 1 - }; - const msg2 = { - type: 'request' as const, - message: { jsonrpc: '2.0' as const, id: 1, method: 'test2' }, - timestamp: 2 - }; - const msg3 = { - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test3' }, - timestamp: 3 - }; - - await queue.enqueue(taskId, msg1); - await queue.enqueue(taskId, msg2); - await queue.enqueue(taskId, msg3); - - const allMessages = await queue.dequeueAll(taskId); - - expect(allMessages).toEqual([msg1, msg2, msg3]); - }); - - it('should return empty array for empty queue', async () => { - const allMessages = await queue.dequeueAll(taskId); - expect(allMessages).toEqual([]); - }); - - it('should clear queue after dequeueAll', async () => { - await queue.enqueue(taskId, { - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test1' }, - timestamp: 1 - }); - await queue.enqueue(taskId, { - type: 'notification' as const, - message: { jsonrpc: '2.0' as const, method: 'test2' }, - timestamp: 2 - }); - - await queue.dequeueAll(taskId); - - expect(await queue.dequeue(taskId)).toBeUndefined(); - }); - }); -}); - describe('mergeCapabilities', () => { it('should merge client capabilities', () => { const base: ClientCapabilities = { @@ -1067,4614 +860,3 @@ describe('mergeCapabilities', () => { expect(merged).toEqual({}); }); }); - -describe('Task-based execution', () => { - let protocol: Protocol; - let transport: MockTransport; - let sendSpy: MockInstance; - - beforeEach(() => { - transport = new MockTransport(); - sendSpy = vi.spyOn(transport, 'send'); - protocol = createTestProtocol({ taskStore: createMockTaskStore(), taskMessageQueue: new InMemoryTaskMessageQueue() }); - }); - - describe('request with task metadata', () => { - it('should include task parameters at top level', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - content: z.array(z.object({ type: z.literal('text'), text: z.string() })) - }); - - void testRequest(protocol, request, resultSchema, { - task: { - ttl: 30000, - pollInterval: 1000 - } - }).catch(() => { - // May not complete, ignore error - }); - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'tools/call', - params: { - name: 'test-tool', - task: { - ttl: 30000, - pollInterval: 1000 - } - } - }), - expect.any(Object) - ); - }); - - it('should preserve existing _meta and add task parameters at top level', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { - name: 'test-tool', - _meta: { - customField: 'customValue' - } - } - }; - - const resultSchema = z.object({ - content: z.array(z.object({ type: z.literal('text'), text: z.string() })) - }); - - void testRequest(protocol, request, resultSchema, { - task: { - ttl: 60000 - } - }).catch(() => { - // May not complete, ignore error - }); - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - params: { - name: 'test-tool', - _meta: { - customField: 'customValue' - }, - task: { - ttl: 60000 - } - } - }), - expect.any(Object) - ); - }); - - it('should return Promise for task-augmented request', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - content: z.array(z.object({ type: z.literal('text'), text: z.string() })) - }); - - const resultPromise = testRequest(protocol, request, resultSchema, { - task: { - ttl: 30000 - } - }); - - expect(resultPromise).toBeDefined(); - expect(resultPromise).toBeInstanceOf(Promise); - }); - }); - - describe('relatedTask metadata', () => { - it('should inject relatedTask metadata into _meta field', async () => { - await protocol.connect(transport); - - const request = { - method: 'notifications/message', - params: { data: 'test' } - }; - - const resultSchema = z.object({}); - - // Start the request (don't await completion, just let it send) - void testRequest(protocol, request, resultSchema, { - relatedTask: { - taskId: 'parent-task-123' - } - }).catch(() => { - // May not complete, ignore error - }); - - // Wait a bit for the request to be queued - await new Promise(resolve => setTimeout(resolve, 10)); - - // Requests with relatedTask should be queued, not sent via transport - // This prevents duplicate delivery for bidirectional transports - expect(sendSpy).not.toHaveBeenCalled(); - - // Verify the message was queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - }); - - it('should work with notification method', async () => { - await protocol.connect(transport); - - await protocol.notification( - { - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }, - { - relatedTask: { - taskId: 'parent-task-456' - } - } - ); - - // Notifications with relatedTask should be queued, not sent via transport - // This prevents duplicate delivery for bidirectional transports - expect(sendSpy).not.toHaveBeenCalled(); - - // Verify the message was queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue('parent-task-456'); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.method).toBe('notifications/message'); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: 'parent-task-456' }); - }); - }); - - describe('task metadata combination', () => { - it('should combine task, relatedTask, and progress metadata', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - content: z.array(z.object({ type: z.literal('text'), text: z.string() })) - }); - - // Start the request (don't await completion, just let it send) - void testRequest(protocol, request, resultSchema, { - task: { - ttl: 60000, - pollInterval: 1000 - }, - relatedTask: { - taskId: 'parent-task' - }, - onprogress: vi.fn() - }).catch(() => { - // May not complete, ignore error - }); - - // Wait a bit for the request to be queued - await new Promise(resolve => setTimeout(resolve, 10)); - - // Requests with relatedTask should be queued, not sent via transport - // This prevents duplicate delivery for bidirectional transports - expect(sendSpy).not.toHaveBeenCalled(); - - // Verify the message was queued with all metadata combined - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue('parent-task'); - assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.params).toMatchObject({ - name: 'test-tool', - task: { - ttl: 60000, - pollInterval: 1000 - }, - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: 'parent-task' - }, - progressToken: expect.any(Number) - } - }); - }); - }); - - describe('task status transitions', () => { - it('should not auto-update task status when a task-augmented request completes', async () => { - const mockTaskStore = createMockTaskStore(); - const localProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const localTransport = new MockTransport(); - await localProtocol.connect(localTransport); - - localProtocol.setRequestHandler('tools/call', async () => { - return { content: [{ type: 'text', text: 'done' }] }; - }); - - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 42, - method: 'tools/call', - params: { - name: 'test-tool', - arguments: {}, - task: { ttl: 60000, pollInterval: 1000 } - } - }); - - // Allow the request to be processed - await new Promise(resolve => setTimeout(resolve, 20)); - - // The protocol layer must not call updateTaskStatus — that is solely the tool implementor's responsibility - expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); - }); - - it('should handle requests with task creation parameters in top-level task field', async () => { - // This test documents that task creation parameters are now in the top-level task field - // rather than in _meta, and that task management is handled by tool implementors - const mockTaskStore = createMockTaskStore(); - - protocol = createTestProtocol({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - protocol.setRequestHandler('tools/call', async request => { - // Tool implementor can access task creation parameters from request.params.task - expect(request.params.task).toEqual({ - ttl: 60000, - pollInterval: 1000 - }); - return { content: [{ type: 'text', text: 'success' }] }; - }); - - transport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: { - name: 'test', - arguments: {}, - task: { - ttl: 60000, - pollInterval: 1000 - } - } - }); - - // Wait for the request to be processed - await new Promise(resolve => setTimeout(resolve, 10)); - }); - }); - - describe('assertTaskHandlerCapability', () => { - it('should invoke assertTaskHandlerCapability when an inbound task-augmented request arrives', async () => { - const localProtocol = createTestProtocol({ taskStore: createMockTaskStore() }); - const spy = vi.spyOn(localProtocol, 'assertTaskHandlerCapability' as never); - const localTransport = new MockTransport(); - await localProtocol.connect(localTransport); - - localProtocol.setRequestHandler('tools/call', async () => { - return { content: [{ type: 'text', text: 'ok' }] }; - }); - - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: { - name: 'my-tool', - arguments: {}, - task: { ttl: 30000, pollInterval: 500 } - } - }); - - await new Promise(resolve => setTimeout(resolve, 20)); - - expect(spy).toHaveBeenCalledOnce(); - expect(spy).toHaveBeenCalledWith('tools/call'); - }); - - it('should not invoke assertTaskHandlerCapability for non-task-augmented requests', async () => { - const localProtocol = createTestProtocol({ taskStore: createMockTaskStore() }); - const spy = vi.spyOn(localProtocol, 'assertTaskHandlerCapability' as never); - const localTransport = new MockTransport(); - await localProtocol.connect(localTransport); - - localProtocol.setRequestHandler('tools/call', async () => { - return { content: [{ type: 'text', text: 'ok' }] }; - }); - - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tools/call', - params: { name: 'my-tool', arguments: {} } - }); - - await new Promise(resolve => setTimeout(resolve, 20)); - - expect(spy).not.toHaveBeenCalled(); - }); - - it('should succeed with default no-op assertTaskHandlerCapability', async () => { - const localProtocol = createTestProtocol({ taskStore: createMockTaskStore() }); - const localTransport = new MockTransport(); - const localSendSpy = vi.spyOn(localTransport, 'send'); - await localProtocol.connect(localTransport); - - localProtocol.setRequestHandler('tools/call', async () => { - return { content: [{ type: 'text', text: 'ok' }] }; - }); - - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 3, - method: 'tools/call', - params: { - name: 'my-tool', - arguments: {}, - task: { ttl: 30000, pollInterval: 500 } - } - }); - - await new Promise(resolve => setTimeout(resolve, 20)); - - // The response should be a success, not an error - expect(localSendSpy).toHaveBeenCalledOnce(); - const response = localSendSpy.mock.calls[0]![0] as { error?: unknown }; - expect(response.error).toBeUndefined(); - }); - - it('should send a JSON-RPC error response when assertTaskHandlerCapability throws', async () => { - const localProtocol = createTestProtocol({ taskStore: createMockTaskStore() }); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - vi.spyOn(localProtocol as any, 'assertTaskHandlerCapability').mockImplementation(() => { - throw new Error('Task handler capability not declared'); - }); - const localTransport = new MockTransport(); - const sendSpy = vi.spyOn(localTransport, 'send'); - await localProtocol.connect(localTransport); - - localProtocol.setRequestHandler('tools/call', async () => { - return { content: [{ type: 'text', text: 'ok' }] }; - }); - - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 4, - method: 'tools/call', - params: { - name: 'my-tool', - arguments: {}, - task: { ttl: 30000, pollInterval: 500 } - } - }); - - await new Promise(resolve => setTimeout(resolve, 20)); - - // Verify the error was sent back as a JSON-RPC error response (matching main's behavior) - expect(sendSpy).toHaveBeenCalledOnce(); - const response = sendSpy.mock.calls[0]![0] as { error?: { message?: string } }; - expect(response.error).toBeDefined(); - expect(response.error!.message).toBe('Task handler capability not declared'); - }); - }); - - describe('pollInterval fallback in _waitForTaskUpdate', () => { - it('should fall back to defaultTaskPollInterval when task has no pollInterval', async () => { - const mockTaskStore = createMockTaskStore(); - - const task = await mockTaskStore.createTask({ pollInterval: undefined as unknown as number }, 1, { - method: 'test/method', - params: {} - }); - // Override pollInterval to be undefined on the stored task - const storedTask = await mockTaskStore.getTask(task.taskId); - if (storedTask) { - storedTask.pollInterval = undefined as unknown as number; - } - - const localProtocol = createTestProtocol({ - taskStore: mockTaskStore, - defaultTaskPollInterval: 100 - }); - const localTransport = new MockTransport(); - const sendSpy = vi.spyOn(localTransport, 'send'); - await localProtocol.connect(localTransport); - - // Send tasks/result request — task is non-terminal so it will poll - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'tasks/result', - params: { taskId: task.taskId } - }); - - // Use a macrotask to complete the task AFTER the handler has entered polling - setTimeout(() => { - mockTaskStore.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'done' }] }); - }, 10); - - // At 50ms the 100ms poll hasn't fired yet - await new Promise(resolve => setTimeout(resolve, 50)); - expect(sendSpy).not.toHaveBeenCalled(); - - // At 200ms the poll should have fired and found the completed task - await new Promise(resolve => setTimeout(resolve, 150)); - expect(sendSpy).toHaveBeenCalled(); - }); - - it('should fall back to 1000ms when both pollInterval and defaultTaskPollInterval are absent', async () => { - const mockTaskStore = createMockTaskStore(); - - const task = await mockTaskStore.createTask({ pollInterval: undefined as unknown as number }, 1, { - method: 'test/method', - params: {} - }); - const storedTask = await mockTaskStore.getTask(task.taskId); - if (storedTask) { - storedTask.pollInterval = undefined as unknown as number; - } - - // No defaultTaskPollInterval — should fall back to 1000ms - const localProtocol = createTestProtocol({ - taskStore: mockTaskStore - }); - const localTransport = new MockTransport(); - const sendSpy = vi.spyOn(localTransport, 'send'); - await localProtocol.connect(localTransport); - - localTransport.onmessage?.({ - jsonrpc: '2.0', - id: 1, - method: 'tasks/result', - params: { taskId: task.taskId } - }); - - // Complete the task via macrotask so the handler enters polling first - setTimeout(() => { - mockTaskStore.storeTaskResult(task.taskId, 'completed', { content: [{ type: 'text', text: 'done' }] }); - }, 10); - - // At 500ms the 1000ms poll hasn't fired yet - await new Promise(resolve => setTimeout(resolve, 500)); - expect(sendSpy).not.toHaveBeenCalled(); - - // At 1100ms the poll should have fired - await new Promise(resolve => setTimeout(resolve, 600)); - expect(sendSpy).toHaveBeenCalled(); - }); - }); - - describe('listTasks', () => { - it('should handle tasks/list requests and return tasks from TaskStore', async () => { - const listedTasks = createLatch(); - const mockTaskStore = createMockTaskStore({ - onList: () => listedTasks.releaseLatch() - }); - const task1 = await mockTaskStore.createTask( - { - pollInterval: 500 - }, - 1, - { - method: 'test/method', - params: {} - } - ); - // Manually set status to completed for this test - await mockTaskStore.updateTaskStatus(task1.taskId, 'completed'); - - const task2 = await mockTaskStore.createTask( - { - ttl: 60000, - pollInterval: 1000 - }, - 2, - { - method: 'test/method', - params: {} - } - ); - - protocol = createTestProtocol({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - // Simulate receiving a tasks/list request - transport.onmessage?.({ - jsonrpc: '2.0', - id: 3, - method: 'tasks/list', - params: {} - }); - - await listedTasks.waitForLatch(); - - expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined, undefined); - const sentMessage = sendSpy.mock.calls[0]![0]; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(3); - expect(sentMessage.result.tasks).toEqual([ - { - taskId: task1.taskId, - status: 'completed', - ttl: null, - createdAt: expect.any(String), - lastUpdatedAt: expect.any(String), - pollInterval: 500 - }, - { - taskId: task2.taskId, - status: 'working', - ttl: 60000, - createdAt: expect.any(String), - lastUpdatedAt: expect.any(String), - pollInterval: 1000 - } - ]); - expect(sentMessage.result._meta).toEqual({}); - }); - - it('should handle tasks/list requests with cursor for pagination', async () => { - const listedTasks = createLatch(); - const mockTaskStore = createMockTaskStore({ - onList: () => listedTasks.releaseLatch() - }); - const task3 = await mockTaskStore.createTask( - { - pollInterval: 500 - }, - 1, - { - method: 'test/method', - params: {} - } - ); - - protocol = createTestProtocol({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - // Simulate receiving a tasks/list request with cursor - transport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tasks/list', - params: { - cursor: 'task-2' - } - }); - - await listedTasks.waitForLatch(); - - expect(mockTaskStore.listTasks).toHaveBeenCalledWith('task-2', undefined); - const sentMessage = sendSpy.mock.calls[0]![0]; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(2); - expect(sentMessage.result.tasks).toEqual([ - { - taskId: task3.taskId, - status: 'working', - ttl: null, - createdAt: expect.any(String), - lastUpdatedAt: expect.any(String), - pollInterval: 500 - } - ]); - expect(sentMessage.result.nextCursor).toBeUndefined(); - expect(sentMessage.result._meta).toEqual({}); - }); - - it('should handle tasks/list requests with empty results', async () => { - const listedTasks = createLatch(); - const mockTaskStore = createMockTaskStore({ - onList: () => listedTasks.releaseLatch() - }); - - protocol = createTestProtocol({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - // Simulate receiving a tasks/list request - transport.onmessage?.({ - jsonrpc: '2.0', - id: 3, - method: 'tasks/list', - params: {} - }); - - await listedTasks.waitForLatch(); - - expect(mockTaskStore.listTasks).toHaveBeenCalledWith(undefined, undefined); - const sentMessage = sendSpy.mock.calls[0]![0]; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(3); - expect(sentMessage.result.tasks).toEqual([]); - expect(sentMessage.result.nextCursor).toBeUndefined(); - expect(sentMessage.result._meta).toEqual({}); - }); - - it('should return error for invalid cursor', async () => { - const mockTaskStore = createMockTaskStore(); - mockTaskStore.listTasks.mockRejectedValue(new Error('Invalid cursor: bad-cursor')); - - protocol = createTestProtocol({ taskStore: mockTaskStore }); - - await protocol.connect(transport); - - // Simulate receiving a tasks/list request with invalid cursor - transport.onmessage?.({ - jsonrpc: '2.0', - id: 4, - method: 'tasks/list', - params: { - cursor: 'bad-cursor' - } - }); - - await new Promise(resolve => setTimeout(resolve, 10)); - - expect(mockTaskStore.listTasks).toHaveBeenCalledWith('bad-cursor', undefined); - const sentMessage = sendSpy.mock.calls[0]![0]; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(4); - expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code - expect(sentMessage.error.message).toContain('Failed to list tasks'); - expect(sentMessage.error.message).toContain('Invalid cursor'); - }); - - it('should call listTasks method from client side', async () => { - await protocol.connect(transport); - - const listTasksPromise = (protocol as unknown as TestProtocolInternals)._taskManager.listTasks(); - - // Simulate server response - setTimeout(() => { - transport.onmessage?.({ - jsonrpc: '2.0', - id: sendSpy.mock.calls[0]![0].id, - result: { - tasks: [ - { - taskId: 'task-1', - status: 'completed', - ttl: null, - createdAt: '2024-01-01T00:00:00Z', - lastUpdatedAt: '2024-01-01T00:00:00Z', - pollInterval: 500 - } - ], - nextCursor: undefined, - _meta: {} - } - }); - }, 10); - - const result = await listTasksPromise; - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'tasks/list', - params: undefined - }), - expect.any(Object) - ); - expect(result.tasks).toHaveLength(1); - expect(result.tasks[0]?.taskId).toBe('task-1'); - }); - - it('should call listTasks with cursor from client side', async () => { - await protocol.connect(transport); - - const listTasksPromise = (protocol as unknown as TestProtocolInternals)._taskManager.listTasks({ cursor: 'task-10' }); - - // Simulate server response - setTimeout(() => { - transport.onmessage?.({ - jsonrpc: '2.0', - id: sendSpy.mock.calls[0]![0].id, - result: { - tasks: [ - { - taskId: 'task-11', - status: 'working', - ttl: 30000, - createdAt: '2024-01-01T00:00:00Z', - lastUpdatedAt: '2024-01-01T00:00:00Z', - pollInterval: 1000 - } - ], - nextCursor: 'task-11', - _meta: {} - } - }); - }, 10); - - const result = await listTasksPromise; - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'tasks/list', - params: { - cursor: 'task-10' - } - }), - expect.any(Object) - ); - expect(result.tasks).toHaveLength(1); - expect(result.tasks[0]?.taskId).toBe('task-11'); - expect(result.nextCursor).toBe('task-11'); - }); - }); - - describe('cancelTask', () => { - it('should handle tasks/cancel requests and update task status to cancelled', async () => { - const taskDeleted = createLatch(); - const mockTaskStore = createMockTaskStore(); - const task = await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - - mockTaskStore.getTask.mockResolvedValue(task); - mockTaskStore.updateTaskStatus.mockImplementation(async (taskId: string, status: string) => { - if (taskId === task.taskId && status === 'cancelled') { - taskDeleted.releaseLatch(); - return; - } - throw new Error('Task not found'); - }); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 5, - method: 'tasks/cancel', - params: { - taskId: task.taskId - } - }); - - await taskDeleted.waitForLatch(); - - expect(mockTaskStore.getTask).toHaveBeenCalledWith(task.taskId, undefined); - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( - task.taskId, - 'cancelled', - 'Client cancelled task execution.', - undefined - ); - const sentMessage = sendSpy.mock.calls[0]![0] as unknown as JSONRPCResultResponse; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(5); - expect(sentMessage.result._meta).toBeDefined(); - }); - - it('should return error with code -32602 when task does not exist', async () => { - const taskDeleted = createLatch(); - const mockTaskStore = createMockTaskStore(); - - mockTaskStore.getTask.mockResolvedValue(null); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 6, - method: 'tasks/cancel', - params: { - taskId: 'non-existent' - } - }); - - // Wait a bit for the async handler to complete - await new Promise(resolve => setTimeout(resolve, 10)); - taskDeleted.releaseLatch(); - - expect(mockTaskStore.getTask).toHaveBeenCalledWith('non-existent', undefined); - const sentMessage = sendSpy.mock.calls[0]![0] as unknown as JSONRPCErrorResponse; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(6); - expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code - expect(sentMessage.error.message).toContain('Task not found'); - }); - - it('should return error with code -32602 when trying to cancel a task in terminal status', async () => { - const mockTaskStore = createMockTaskStore(); - const completedTask = await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - // Set task to completed status - await mockTaskStore.updateTaskStatus(completedTask.taskId, 'completed'); - completedTask.status = 'completed'; - - // Reset the mock so we can check it's not called during cancellation - mockTaskStore.updateTaskStatus.mockClear(); - mockTaskStore.getTask.mockResolvedValue(completedTask); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 7, - method: 'tasks/cancel', - params: { - taskId: completedTask.taskId - } - }); - - // Wait a bit for the async handler to complete - await new Promise(resolve => setTimeout(resolve, 10)); - - expect(mockTaskStore.getTask).toHaveBeenCalledWith(completedTask.taskId, undefined); - expect(mockTaskStore.updateTaskStatus).not.toHaveBeenCalled(); - const sentMessage = sendSpy.mock.calls[0]![0] as unknown as JSONRPCErrorResponse; - expect(sentMessage.jsonrpc).toBe('2.0'); - expect(sentMessage.id).toBe(7); - expect(sentMessage.error).toBeDefined(); - expect(sentMessage.error.code).toBe(-32602); // InvalidParams error code - expect(sentMessage.error.message).toContain('Cannot cancel task in terminal status'); - }); - - it('should call cancelTask method from client side', async () => { - await protocol.connect(transport); - - const deleteTaskPromise = (protocol as unknown as TestProtocolInternals)._taskManager.cancelTask({ taskId: 'task-to-delete' }); - - // Simulate server response - per MCP spec, CancelTaskResult is Result & Task - setTimeout(() => { - transport.onmessage?.({ - jsonrpc: '2.0', - id: sendSpy.mock.calls[0]![0].id, - result: { - _meta: {}, - taskId: 'task-to-delete', - status: 'cancelled', - ttl: 60000, - createdAt: new Date().toISOString(), - lastUpdatedAt: new Date().toISOString() - } - }); - }, 0); - - const result = await deleteTaskPromise; - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'tasks/cancel', - params: { - taskId: 'task-to-delete' - } - }), - expect.any(Object) - ); - expect(result._meta).toBeDefined(); - expect(result.taskId).toBe('task-to-delete'); - expect(result.status).toBe('cancelled'); - }); - }); - - describe('task status notifications', () => { - it('should call getTask after updateTaskStatus to enable notification sending', async () => { - const mockTaskStore = createMockTaskStore(); - - // Create a task first - const task = await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - - await serverProtocol.connect(serverTransport); - - // Simulate cancelling the task - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tasks/cancel', - params: { - taskId: task.taskId - } - }); - - // Wait for async processing - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify that updateTaskStatus was called - expect(mockTaskStore.updateTaskStatus).toHaveBeenCalledWith( - task.taskId, - 'cancelled', - 'Client cancelled task execution.', - undefined - ); - - // Verify that getTask was called after updateTaskStatus - // This is done by the RequestTaskStore wrapper to get the updated task for the notification - const getTaskCalls = mockTaskStore.getTask.mock.calls; - const lastGetTaskCall = getTaskCalls[getTaskCalls.length - 1]; - expect(lastGetTaskCall?.[0]).toBe(task.taskId); - }); - }); - - describe('task metadata handling', () => { - it('should NOT include related-task metadata in tasks/get response', async () => { - const mockTaskStore = createMockTaskStore(); - - // Create a task first - const task = await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - // Request task status - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tasks/get', - params: { - taskId: task.taskId - } - }); - - // Wait for async processing - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify response does NOT include related-task metadata - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - result: expect.objectContaining({ - taskId: task.taskId, - status: 'working' - }) - }) - ); - - // Verify _meta is not present or doesn't contain RELATED_TASK_META_KEY - const response = sendSpy.mock.calls[0]![0] as { result?: { _meta?: Record } }; - expect(response.result?._meta?.[RELATED_TASK_META_KEY]).toBeUndefined(); - }); - - it('should NOT include related-task metadata in tasks/list response', async () => { - const mockTaskStore = createMockTaskStore(); - - // Create a task first - await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - // Request task list - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tasks/list', - params: {} - }); - - // Wait for async processing - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify response does NOT include related-task metadata - const response = sendSpy.mock.calls[0]![0] as { result?: { _meta?: Record } }; - expect(response.result?._meta).toEqual({}); - }); - - it('should NOT include related-task metadata in tasks/cancel response', async () => { - const mockTaskStore = createMockTaskStore(); - - // Create a task first - const task = await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - // Cancel the task - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tasks/cancel', - params: { - taskId: task.taskId - } - }); - - // Wait for async processing - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify response does NOT include related-task metadata - const response = sendSpy.mock.calls[0]![0] as { result?: { _meta?: Record } }; - expect(response.result?._meta).toEqual({}); - }); - - it('should include related-task metadata in tasks/result response', async () => { - const mockTaskStore = createMockTaskStore(); - - // Create a task and complete it - const task = await mockTaskStore.createTask({}, 1, { - method: 'test/method', - params: {} - }); - - const testResult = { - content: [{ type: 'text', text: 'test result' }] - }; - - await mockTaskStore.storeTaskResult(task.taskId, 'completed', testResult); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore }); - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - // Request task result - serverTransport.onmessage?.({ - jsonrpc: '2.0', - id: 2, - method: 'tasks/result', - params: { - taskId: task.taskId - } - }); - - // Wait for async processing - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify response DOES include related-task metadata - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - result: expect.objectContaining({ - content: testResult.content, - _meta: expect.objectContaining({ - [RELATED_TASK_META_KEY]: { - taskId: task.taskId - } - }) - }) - }) - ); - }); - - it('should propagate related-task metadata to handler sendRequest and sendNotification', async () => { - const mockTaskStore = createMockTaskStore(); - - const serverProtocol = createTestProtocol({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - const serverTransport = new MockTransport(); - const sendSpy = vi.spyOn(serverTransport, 'send'); - - await serverProtocol.connect(serverTransport); - - // Set up a handler that uses sendRequest and sendNotification - serverProtocol.setRequestHandler('tools/call', async (_request, ctx) => { - // Send a notification using the ctx.mcpReq.notify - await ctx.mcpReq.notify({ - method: 'notifications/message', - params: { level: 'info', data: 'test' } - }); - - return { - content: [{ type: 'text', text: 'done' }] - }; - }); - - // Send a request with related-task metadata - let handlerPromise: Promise | undefined; - const originalOnMessage = serverTransport.onmessage; - - serverTransport.onmessage = message => { - handlerPromise = Promise.resolve(originalOnMessage?.(message)); - return handlerPromise; - }; - - serverTransport.onmessage({ - jsonrpc: '2.0', - id: 1, - method: 'tools/call', - params: { - name: 'test-tool', - _meta: { - [RELATED_TASK_META_KEY]: { - taskId: 'parent-task-123' - } - } - } - }); - - // Wait for handler to complete - if (handlerPromise) { - await handlerPromise; - } - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify the notification was QUEUED (not sent via transport) - // Messages with relatedTask metadata should be queued for delivery via tasks/result - // to prevent duplicate delivery for bidirectional transports - const queue = (serverProtocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue('parent-task-123'); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.method).toBe('notifications/message'); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ - taskId: 'parent-task-123' - }); - - // Verify the notification was NOT sent via transport (should be queued instead) - const notificationCalls = sendSpy.mock.calls.filter(call => 'method' in call[0] && call[0].method === 'notifications/message'); - expect(notificationCalls).toHaveLength(0); - }); - }); -}); - -describe('Request Cancellation vs Task Cancellation', () => { - let protocol: Protocol; - let transport: MockTransport; - let taskStore: TaskStore; - - beforeEach(() => { - transport = new MockTransport(); - taskStore = createMockTaskStore(); - protocol = createTestProtocol({ taskStore }); - }); - - describe('notifications/cancelled behavior', () => { - test('should abort request handler when notifications/cancelled is received', async () => { - await protocol.connect(transport); - - // Set up a request handler that checks if it was aborted - let wasAborted = false; - protocol.setRequestHandler('ping', async (_request, ctx) => { - // Simulate a long-running operation - await new Promise(resolve => setTimeout(resolve, 100)); - wasAborted = ctx.mcpReq.signal.aborted; - return {}; - }); - - // Simulate an incoming request - const requestId = 123; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: requestId, - method: 'ping', - params: {} - }); - } - - // Wait a bit for the handler to start - await new Promise(resolve => setTimeout(resolve, 10)); - - // Send cancellation notification - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { - requestId: requestId, - reason: 'User cancelled' - } - }); - } - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 150)); - - // Verify the request was aborted - expect(wasAborted).toBe(true); - }); - - test('should NOT automatically cancel associated tasks when notifications/cancelled is received', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { - method: 'test/method', - params: {} - }); - - // Send cancellation notification for the request - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { - requestId: 'req-1', - reason: 'User cancelled' - } - }); - } - - // Wait a bit - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify the task status was NOT changed to cancelled - const updatedTask = await taskStore.getTask(task.taskId); - expect(updatedTask?.status).toBe('working'); - expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'cancelled', expect.any(String)); - }); - }); - - describe('tasks/cancel behavior', () => { - test('should cancel task independently of request cancellation', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { - method: 'test/method', - params: {} - }); - - // Cancel the task using tasks/cancel - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: 999, - method: 'tasks/cancel', - params: { - taskId: task.taskId - } - }); - } - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify the task was cancelled - expect(taskStore.updateTaskStatus).toHaveBeenCalledWith( - task.taskId, - 'cancelled', - 'Client cancelled task execution.', - undefined - ); - }); - - test('should reject cancellation of terminal tasks', async () => { - await protocol.connect(transport); - const sendSpy = vi.spyOn(transport, 'send'); - - // Create a task and mark it as completed - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { - method: 'test/method', - params: {} - }); - await taskStore.updateTaskStatus(task.taskId, 'completed'); - - // Try to cancel the completed task - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: 999, - method: 'tasks/cancel', - params: { - taskId: task.taskId - } - }); - } - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify an error was sent - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - jsonrpc: '2.0', - id: 999, - error: expect.objectContaining({ - code: ProtocolErrorCode.InvalidParams, - message: expect.stringContaining('Cannot cancel task in terminal status') - }) - }) - ); - }); - - test('should return error when task not found', async () => { - await protocol.connect(transport); - const sendSpy = vi.spyOn(transport, 'send'); - - // Try to cancel a non-existent task - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: 999, - method: 'tasks/cancel', - params: { - taskId: 'non-existent-task' - } - }); - } - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify an error was sent - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - jsonrpc: '2.0', - id: 999, - error: expect.objectContaining({ - code: ProtocolErrorCode.InvalidParams, - message: expect.stringContaining('Task not found') - }) - }) - ); - }); - }); - - describe('separation of concerns', () => { - test('should allow request cancellation without affecting task', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { - method: 'test/method', - params: {} - }); - - // Cancel the request (not the task) - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { - requestId: 'req-1', - reason: 'User cancelled request' - } - }); - } - - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify task is still working - const updatedTask = await taskStore.getTask(task.taskId); - expect(updatedTask?.status).toBe('working'); - }); - - test('should allow task cancellation without affecting request', async () => { - await protocol.connect(transport); - - // Set up a request handler - let requestCompleted = false; - protocol.setRequestHandler('ping', async () => { - await new Promise(resolve => setTimeout(resolve, 50)); - requestCompleted = true; - return {}; - }); - - // Create a task (simulating a long-running tools/call) - const task = await taskStore.createTask({ ttl: 60000 }, 'req-1', { - method: 'tools/call', - params: { name: 'long-running-tool', arguments: {} } - }); - - // Start an unrelated ping request - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: 123, - method: 'ping', - params: {} - }); - } - - // Cancel the task (not the request) - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: 999, - method: 'tasks/cancel', - params: { - taskId: task.taskId - } - }); - } - - // Wait for request to complete - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify request completed normally - expect(requestCompleted).toBe(true); - - // Verify task was cancelled - expect(taskStore.updateTaskStatus).toHaveBeenCalledWith( - task.taskId, - 'cancelled', - 'Client cancelled task execution.', - undefined - ); - }); - }); -}); - -describe('Progress notification support for tasks', () => { - let protocol: Protocol; - let transport: MockTransport; - let sendSpy: MockInstance; - - beforeEach(() => { - transport = new MockTransport(); - sendSpy = vi.spyOn(transport, 'send'); - protocol = createTestProtocol({ taskStore: createMockTaskStore() }); - }); - - it('should maintain progress token association after CreateTaskResult is returned', async () => { - const taskStore = createMockTaskStore(); - const protocol = createTestProtocol({ taskStore }); - - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); - - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); - - // Start a task-augmented request with progress callback - void testRequest(protocol, request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }).catch(() => { - // May not complete, ignore error - }); - - // Wait a bit for the request to be sent - await new Promise(resolve => setTimeout(resolve, 10)); - - // Get the message ID from the sent request - const sentRequest = sendSpy.mock.calls[0]![0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; - - expect(progressToken).toBe(messageId); - - // Simulate CreateTaskResult response - const taskId = 'test-task-123'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } - }); - } - - // Wait for response to be processed - await Promise.resolve(); - await Promise.resolve(); - - // Send a progress notification - should still work after CreateTaskResult - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 50, - total: 100 - } - }); - } - - // Wait for notification to be processed - await Promise.resolve(); - - // Verify progress callback was invoked - expect(progressCallback).toHaveBeenCalledWith({ - progress: 50, - total: 100 - }); - }); - - it('should stop progress notifications when task reaches terminal status (completed)', async () => { - const taskStore = createMockTaskStore(); - const protocol = createTestProtocol({ taskStore }); - - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); - - // Set up a request handler that will complete the task - protocol.setRequestHandler('tools/call', async (_request, ctx) => { - if (ctx.task?.store) { - const task = await ctx.task.store.createTask({ ttl: 60000 }); - - // Simulate async work then complete the task - const taskStore = ctx.task.store; - setTimeout(async () => { - await taskStore.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: 'Done' }] - }); - }, 50); - - return { task }; - } - return { content: [] }; - }); - - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); - - // Start a task-augmented request with progress callback - void testRequest(protocol, request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }).catch(() => { - // May not complete, ignore error - }); - - // Wait a bit for the request to be sent - await new Promise(resolve => setTimeout(resolve, 10)); - - const sentRequest = sendSpy.mock.calls[0]![0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; - - // Create a task in the mock store first so it exists when we try to get it later - const createdTask = await taskStore.createTask({ ttl: 60000 }, messageId, request); - const taskId = createdTask.taskId; - - // Simulate CreateTaskResult response - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: createdTask - } - }); - } - - await Promise.resolve(); - await Promise.resolve(); - - // Progress notification should work while task is working - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 50, - total: 100 - } - }); - } - - await Promise.resolve(); - - expect(progressCallback).toHaveBeenCalledTimes(1); - - // Verify the task-progress association was created - const taskProgressTokens = (protocol as unknown as TestProtocolInternals)._taskManager._taskProgressTokens as Map; - expect(taskProgressTokens.has(taskId)).toBe(true); - expect(taskProgressTokens.get(taskId)).toBe(progressToken); - - // Simulate task completion by triggering an inbound request whose handler - // calls storeTaskResult through the task context (the public RequestTaskStore API). - // This is equivalent to how a real server handler would complete a task. - protocol.setRequestHandler('ping', async (_request, ctx) => { - if (ctx.task?.store) { - await ctx.task.store.storeTaskResult(taskId, 'completed', { content: [] }); - } - return {}; - }); - if (transport.onmessage) { - transport.onmessage({ jsonrpc: '2.0', id: 999, method: 'ping', params: {} }); - } - - // Wait for all async operations including notification sending to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the association was cleaned up - expect(taskProgressTokens.has(taskId)).toBe(false); - - // Try to send progress notification after task completion - should be ignored - progressCallback.mockClear(); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 100, - total: 100 - } - }); - } - - await Promise.resolve(); - - // Progress callback should NOT be invoked after task completion - expect(progressCallback).not.toHaveBeenCalled(); - }); - - it('should stop progress notifications when task reaches terminal status (failed)', async () => { - const taskStore = createMockTaskStore(); - const protocol = createTestProtocol({ taskStore }); - - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); - - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); - - void testRequest(protocol, request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); - - const sentRequest = sendSpy.mock.calls[0]![0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; - - // Simulate CreateTaskResult response - const taskId = 'test-task-456'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } - }); - } - - await new Promise(resolve => setTimeout(resolve, 10)); - - // Simulate task failure via storeTaskResult - await taskStore.storeTaskResult(taskId, 'failed', { - content: [], - isError: true - }); - - // Manually trigger the status notification - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/tasks/status', - params: { - taskId, - status: 'failed', - ttl: 60000, - createdAt: new Date().toISOString(), - lastUpdatedAt: new Date().toISOString(), - statusMessage: 'Task failed' - } - }); - } - - await new Promise(resolve => setTimeout(resolve, 10)); - - // Try to send progress notification after task failure - should be ignored - progressCallback.mockClear(); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 75, - total: 100 - } - }); - } - - expect(progressCallback).not.toHaveBeenCalled(); - }); - - it('should stop progress notifications when task is cancelled', async () => { - const taskStore = createMockTaskStore(); - const protocol = createTestProtocol({ taskStore }); - - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); - - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); - - void testRequest(protocol, request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); - - const sentRequest = sendSpy.mock.calls[0]![0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; - - // Simulate CreateTaskResult response - const taskId = 'test-task-789'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } - }); - } - - await new Promise(resolve => setTimeout(resolve, 10)); - - // Simulate task cancellation via updateTaskStatus - await taskStore.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); - - // Manually trigger the status notification - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/tasks/status', - params: { - taskId, - status: 'cancelled', - ttl: 60000, - createdAt: new Date().toISOString(), - lastUpdatedAt: new Date().toISOString(), - statusMessage: 'User cancelled' - } - }); - } - - await new Promise(resolve => setTimeout(resolve, 10)); - - // Try to send progress notification after cancellation - should be ignored - progressCallback.mockClear(); - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 25, - total: 100 - } - }); - } - - expect(progressCallback).not.toHaveBeenCalled(); - }); - - it('should use the same progressToken throughout task lifetime', async () => { - const taskStore = createMockTaskStore(); - const protocol = createTestProtocol({ taskStore }); - - const transport = new MockTransport(); - const sendSpy = vi.spyOn(transport, 'send'); - await protocol.connect(transport); - - const progressCallback = vi.fn(); - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); - - void testRequest(protocol, request, resultSchema, { - task: { ttl: 60000 }, - onprogress: progressCallback - }); - - const sentRequest = sendSpy.mock.calls[0]![0] as { id: number; params: { _meta: { progressToken: number } } }; - const messageId = sentRequest.id; - const progressToken = sentRequest.params._meta.progressToken; - - // Simulate CreateTaskResult response - const taskId = 'test-task-consistency'; - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } - }); - } - - await Promise.resolve(); - await Promise.resolve(); - - // Send multiple progress notifications with the same token - const progressUpdates = [ - { progress: 25, total: 100 }, - { progress: 50, total: 100 }, - { progress: 75, total: 100 } - ]; - - for (const update of progressUpdates) { - if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, // Same token for all notifications - ...update - } - }); - } - await Promise.resolve(); - } - - // Verify all progress notifications were received with the same token - expect(progressCallback).toHaveBeenCalledTimes(3); - expect(progressCallback).toHaveBeenNthCalledWith(1, { progress: 25, total: 100 }); - expect(progressCallback).toHaveBeenNthCalledWith(2, { progress: 50, total: 100 }); - expect(progressCallback).toHaveBeenNthCalledWith(3, { progress: 75, total: 100 }); - }); - - it('should maintain progressToken throughout task lifetime', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { name: 'long-running-tool' } - }; - - const resultSchema = z.object({ - content: z.array(z.object({ type: z.literal('text'), text: z.string() })) - }); - - const onProgressMock = vi.fn(); - - void testRequest(protocol, request, resultSchema, { - task: { - ttl: 60000 - }, - onprogress: onProgressMock - }); - - const sentMessage = sendSpy.mock.calls[0]![0]; - expect(sentMessage.params._meta.progressToken).toBeDefined(); - }); - - it('should support progress notifications with task-augmented requests', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - content: z.array(z.object({ type: z.literal('text'), text: z.string() })) - }); - - const onProgressMock = vi.fn(); - - void testRequest(protocol, request, resultSchema, { - task: { - ttl: 30000 - }, - onprogress: onProgressMock - }); - - const sentMessage = sendSpy.mock.calls[0]![0]; - const progressToken = sentMessage.params._meta.progressToken; - - // Simulate progress notification - transport.onmessage?.({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 50, - total: 100, - message: 'Processing...' - } - }); - - await new Promise(resolve => setTimeout(resolve, 10)); - - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 50, - total: 100, - message: 'Processing...' - }); - }); - - it('should continue progress notifications after CreateTaskResult', async () => { - await protocol.connect(transport); - - const request = { - method: 'tools/call', - params: { name: 'test-tool' } - }; - - const resultSchema = z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.number().nullable(), - createdAt: z.string() - }) - }); - - const onProgressMock = vi.fn(); - - void testRequest(protocol, request, resultSchema, { - task: { - ttl: 30000 - }, - onprogress: onProgressMock - }); - - const sentMessage = sendSpy.mock.calls[0]![0]; - const progressToken = sentMessage.params._meta.progressToken; - - // Simulate CreateTaskResult response - setTimeout(() => { - transport.onmessage?.({ - jsonrpc: '2.0', - id: sentMessage.id, - result: { - task: { - taskId: 'task-123', - status: 'working', - ttl: 30000, - createdAt: new Date().toISOString() - } - } - }); - }, 5); - - // Progress notifications should still work - setTimeout(() => { - transport.onmessage?.({ - jsonrpc: '2.0', - method: 'notifications/progress', - params: { - progressToken, - progress: 75, - total: 100 - } - }); - }, 10); - - await new Promise(resolve => setTimeout(resolve, 20)); - - expect(onProgressMock).toHaveBeenCalledWith({ - progress: 75, - total: 100 - }); - }); -}); - -describe('Capability negotiation for tasks', () => { - it('should use empty objects for capability fields', () => { - const serverCapabilities = { - tasks: { - list: {}, - cancel: {}, - requests: { - tools: { - call: {} - } - } - } - }; - - expect(serverCapabilities.tasks.list).toEqual({}); - expect(serverCapabilities.tasks.cancel).toEqual({}); - expect(serverCapabilities.tasks.requests.tools.call).toEqual({}); - }); - - it('should include list and cancel in server capabilities', () => { - const serverCapabilities = { - tasks: { - list: {}, - cancel: {} - } - }; - - expect('list' in serverCapabilities.tasks).toBe(true); - expect('cancel' in serverCapabilities.tasks).toBe(true); - }); - - it('should include list and cancel in client capabilities', () => { - const clientCapabilities = { - tasks: { - list: {}, - cancel: {} - } - }; - - expect('list' in clientCapabilities.tasks).toBe(true); - expect('cancel' in clientCapabilities.tasks).toBe(true); - }); -}); - -describe('Message interception for task-related notifications', () => { - it('should queue notifications with io.modelcontextprotocol/related-task metadata', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Create a task first - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Send a notification with related task metadata - await server.notification( - { - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }, - { - relatedTask: { taskId: task.taskId } - } - ); - - // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue(task.taskId); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.method).toBe('notifications/message'); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); - }); - - it('should not queue notifications without related-task metadata', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Send a notification without related task metadata - await server.notification({ - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }); - - // Verify message was not queued (notification without metadata goes through transport) - // We can't directly check the queue, but we know it wasn't queued because - // notifications without relatedTask metadata are sent via transport, not queued - }); - - // Test removed: _taskResultWaiters was removed in favor of polling-based task updates - // The functionality is still tested through integration tests that verify message queuing works - - it('should propagate queue overflow errors without failing the task', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Fill the queue to max capacity (100 messages) - for (let i = 0; i < 100; i++) { - await server.notification( - { - method: 'notifications/message', - params: { level: 'info', data: `message ${i}` } - }, - { - relatedTask: { taskId: task.taskId } - } - ); - } - - // Try to add one more message - should throw an error - await expect( - server.notification( - { - method: 'notifications/message', - params: { level: 'info', data: 'overflow message' } - }, - { - relatedTask: { taskId: task.taskId } - } - ) - ).rejects.toThrow('overflow'); - - // Verify the task was NOT automatically failed by the Protocol - // (implementations can choose to fail tasks on overflow if they want) - expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'failed', expect.anything(), expect.anything()); - }); - - it('should extract task ID correctly from metadata', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - const taskId = 'custom-task-id-123'; - - // Send a notification with custom task ID - await server.notification( - { - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }, - { - relatedTask: { taskId } - } - ); - - // Verify the message was queued under the correct task ID - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - const queuedMessage = await queue!.dequeue(taskId); - expect(queuedMessage).toBeDefined(); - }); - - it('should preserve message order when queuing multiple notifications', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Send multiple notifications - for (let i = 0; i < 5; i++) { - await server.notification( - { - method: 'notifications/message', - params: { level: 'info', data: `message ${i}` } - }, - { - relatedTask: { taskId: task.taskId } - } - ); - } - - // Verify messages are in FIFO order - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - for (let i = 0; i < 5; i++) { - const queuedMessage = await queue!.dequeue(task.taskId); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.params!.data).toBe(`message ${i}`); - } - }); -}); - -describe('Message interception for task-related requests', () => { - it('should queue requests with io.modelcontextprotocol/related-task metadata', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Create a task first - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Send a request with related task metadata (don't await - we're testing queuing) - const requestPromise = testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({}), - { - relatedTask: { taskId: task.taskId } - } - ); - - // Access the private queue to verify the message was queued - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue(task.taskId); - assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.method).toBe('ping'); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ taskId: task.taskId }); - - // Verify resolver is stored in _requestResolvers map (not in the message) - const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - const resolvers = (server as unknown as TestProtocolInternals)._taskManager._requestResolvers; - expect(resolvers.has(requestId)).toBe(true); - - // Clean up - send a response to prevent hanging promise - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - result: {} - }); - - await requestPromise; - }); - - it('should not queue requests without related-task metadata', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Send a request without related task metadata - const requestPromise = testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({}) - ); - - // Verify queue exists (but we don't track size in the new API) - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Clean up - send a response - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - result: {} - }); - - await requestPromise; - }); - - // Test removed: _taskResultWaiters was removed in favor of polling-based task updates - // The functionality is still tested through integration tests that verify message queuing works - - it('should store request resolver for response routing', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Send a request with related task metadata - const requestPromise = testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({}), - { - relatedTask: { taskId: task.taskId } - } - ); - - // Verify the resolver was stored - const resolvers = (server as unknown as TestProtocolInternals)._taskManager._requestResolvers; - expect(resolvers.size).toBe(1); - - // Get the request ID from the queue - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const queuedMessage = await queue!.dequeue(task.taskId); - const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - - expect(resolvers.has(requestId)).toBe(true); - - // Send a response to trigger resolver - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - result: {} - }); - - await requestPromise; - - // Verify resolver was cleaned up after response - expect(resolvers.has(requestId)).toBe(false); - }); - - it('should route responses to side-channeled requests', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const queue = new InMemoryTaskMessageQueue(); - const server = createTestProtocol({ taskStore, taskMessageQueue: queue }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Send a request with related task metadata - const requestPromise = testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({ message: z.string() }), - { - relatedTask: { taskId: task.taskId } - } - ); - - // Get the request ID from the queue - const queuedMessage = await queue.dequeue(task.taskId); - const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - - // Enqueue a response message to the queue (simulating client sending response back) - await queue.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: requestId, - result: { message: 'pong' } - }, - timestamp: Date.now() - }); - - // Simulate a client calling tasks/result which will process the response - // This is done by creating a mock request handler that will trigger the GetTaskPayloadRequest handler - const mockRequestId = 999; - transport.onmessage?.({ - jsonrpc: '2.0', - id: mockRequestId, - method: 'tasks/result', - params: { taskId: task.taskId } - }); - - // Wait for the response to be processed - await new Promise(resolve => setTimeout(resolve, 50)); - - // Mark task as completed - await taskStore.updateTaskStatus(task.taskId, 'completed'); - await taskStore.storeTaskResult(task.taskId, 'completed', { _meta: {} }); - - // Verify the response was routed correctly - const result = await requestPromise; - expect(result).toEqual({ message: 'pong' }); - }); - - it('should log error when resolver is missing for side-channeled request', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - - const errors: Error[] = []; - server.onerror = (error: Error) => { - errors.push(error); - }; - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Send a request with related task metadata - void testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({ message: z.string() }), - { - relatedTask: { taskId: task.taskId } - } - ); - - // Get the request ID from the queue - const queue = (server as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const queuedMessage = await queue!.dequeue(task.taskId); - const requestId = (queuedMessage!.message as JSONRPCRequest).id as RequestId; - - // Manually delete the resolver to simulate missing resolver - (server as unknown as TestProtocolInternals)._taskManager._requestResolvers.delete(requestId); - - // Enqueue a response message - this should trigger the error logging when processed - await queue!.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: requestId, - result: { message: 'pong' } - }, - timestamp: Date.now() - }); - - // Simulate a client calling tasks/result which will process the response - const mockRequestId = 888; - transport.onmessage?.({ - jsonrpc: '2.0', - id: mockRequestId, - method: 'tasks/result', - params: { taskId: task.taskId } - }); - - // Wait for the response to be processed - await new Promise(resolve => setTimeout(resolve, 50)); - - // Mark task as completed - await taskStore.updateTaskStatus(task.taskId, 'completed'); - await taskStore.storeTaskResult(task.taskId, 'completed', { _meta: {} }); - - // Wait a bit more for error to be logged - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify error was logged - expect(errors.length).toBeGreaterThanOrEqual(1); - expect(errors.some(e => e.message.includes('Response handler missing for request'))).toBe(true); - }); - - it('should propagate queue overflow errors for requests without failing the task', async () => { - const taskStore = createMockTaskStore(); - const transport = new MockTransport(); - const server = createTestProtocol({ taskStore, taskMessageQueue: new InMemoryTaskMessageQueue(), maxTaskQueueSize: 100 }); - - await server.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 'test-request-1', { method: 'tools/call', params: {} }); - - // Fill the queue to max capacity (100 messages) - const promises: Promise[] = []; - for (let i = 0; i < 100; i++) { - const promise = testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({}), - { - relatedTask: { taskId: task.taskId } - } - ).catch(() => { - // Requests will remain pending until task completes or fails - }); - promises.push(promise); - } - - // Try to add one more request - should throw an error - await expect( - testRequest( - server, - { - method: 'ping', - params: {} - }, - z.object({}), - { - relatedTask: { taskId: task.taskId } - } - ) - ).rejects.toThrow('overflow'); - - // Verify the task was NOT automatically failed by the Protocol - // (implementations can choose to fail tasks on overflow if they want) - expect(taskStore.updateTaskStatus).not.toHaveBeenCalledWith(task.taskId, 'failed', expect.anything(), expect.anything()); - }); -}); - -describe('Message Interception', () => { - let protocol: Protocol; - let transport: MockTransport; - let mockTaskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; - - beforeEach(() => { - transport = new MockTransport(); - mockTaskStore = createMockTaskStore(); - protocol = createTestProtocol({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - }); - - describe('messages with relatedTask metadata are queued', () => { - it('should queue notifications with relatedTask metadata', async () => { - await protocol.connect(transport); - - // Send a notification with relatedTask metadata - await protocol.notification( - { - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }, - { - relatedTask: { - taskId: 'task-123' - } - } - ); - - // Access the private _taskMessageQueue to verify the message was queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue('task-123'); - assertQueuedNotification(queuedMessage); - expect(queuedMessage!.message.method).toBe('notifications/message'); - }); - - it('should queue requests with relatedTask metadata', async () => { - await protocol.connect(transport); - - const mockSchema = z.object({ result: z.string() }); - - // Send a request with relatedTask metadata - const requestPromise = testRequest( - protocol, - { - method: 'test/request', - params: { data: 'test' } - }, - mockSchema, - { - relatedTask: { - taskId: 'task-456' - } - } - ); - - // Access the private _taskMessageQueue to verify the message was queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue('task-456'); - assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.method).toBe('test/request'); - - // Verify resolver is stored in _requestResolvers map (not in the message) - const requestId = queuedMessage.message.id as RequestId; - const resolvers = (protocol as unknown as TestProtocolInternals)._taskManager._requestResolvers; - expect(resolvers.has(requestId)).toBe(true); - - // Clean up the pending request - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - result: { result: 'success' } - }); - await requestPromise; - }); - }); - - describe('server queues responses/errors for task-related requests', () => { - it('should queue response when handling a request with relatedTask metadata', async () => { - await protocol.connect(transport); - - // Set up a request handler that returns a result - protocol.setRequestHandler('ping', async () => { - return {}; - }); - - // Simulate an incoming request with relatedTask metadata - const requestId = 456; - const taskId = 'task-response-test'; - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - method: 'ping', - params: { - _meta: { - 'io.modelcontextprotocol/related-task': { taskId } - } - } - }); - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the response was queued instead of sent directly - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue(taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.type).toBe('response'); - if (queuedMessage!.type === 'response') { - expect(queuedMessage!.message.id).toBe(requestId); - expect(queuedMessage!.message.result).toEqual({}); - } - }); - - it('should queue error when handling a request with relatedTask metadata that throws', async () => { - await protocol.connect(transport); - - // Set up a request handler that throws an error - protocol.setRequestHandler('ping', async () => { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'Test error message'); - }); - - // Simulate an incoming request with relatedTask metadata - const requestId = 789; - const taskId = 'task-error-test'; - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - method: 'ping', - params: { - _meta: { - 'io.modelcontextprotocol/related-task': { taskId } - } - } - }); - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the error was queued instead of sent directly - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue(taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.type).toBe('error'); - if (queuedMessage!.type === 'error') { - expect(queuedMessage!.message.id).toBe(requestId); - expect(queuedMessage!.message.error.code).toBe(ProtocolErrorCode.InternalError); - expect(queuedMessage!.message.error.message).toContain('Test error message'); - } - }); - - it('should queue MethodNotFound error for unknown method with relatedTask metadata', async () => { - await protocol.connect(transport); - - // Simulate an incoming request for unknown method with relatedTask metadata - const requestId = 101; - const taskId = 'task-not-found-test'; - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - method: 'unknown/method', - params: { - _meta: { - 'io.modelcontextprotocol/related-task': { taskId } - } - } - }); - - // Wait for processing - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the error was queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - const queuedMessage = await queue!.dequeue(taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage!.type).toBe('error'); - if (queuedMessage!.type === 'error') { - expect(queuedMessage!.message.id).toBe(requestId); - expect(queuedMessage!.message.error.code).toBe(ProtocolErrorCode.MethodNotFound); - } - }); - - it('should send response normally when request has no relatedTask metadata', async () => { - await protocol.connect(transport); - const sendSpy = vi.spyOn(transport, 'send'); - - // Set up a request handler - protocol.setRequestHandler('tools/call', async () => { - return { content: [{ type: 'text', text: 'done' }] }; - }); - - // Simulate an incoming request WITHOUT relatedTask metadata - const requestId = 202; - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - method: 'tools/call', - params: { name: 'test-tool' } - }); - - // Wait for the handler to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the response was sent through transport, not queued - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - jsonrpc: '2.0', - id: requestId, - result: { content: [{ type: 'text', text: 'done' }] } - }) - ); - }); - }); - - describe('messages without metadata bypass the queue', () => { - it('should not queue notifications without relatedTask metadata', async () => { - await protocol.connect(transport); - - // Send a notification without relatedTask metadata - await protocol.notification({ - method: 'notifications/message', - params: { level: 'info', data: 'test message' } - }); - - // Access the private _taskMessageQueue to verify no messages were queued - // Since we can't check if queues exist without messages, we verify that - // attempting to dequeue returns undefined (no messages queued) - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - }); - - it('should not queue requests without relatedTask metadata', async () => { - await protocol.connect(transport); - - const mockSchema = z.object({ result: z.string() }); - const sendSpy = vi.spyOn(transport, 'send'); - - // Send a request without relatedTask metadata - const requestPromise = testRequest( - protocol, - { - method: 'test/request', - params: { data: 'test' } - }, - mockSchema - ); - - // Access the private _taskMessageQueue to verify no messages were queued - // Since we can't check if queues exist without messages, we verify that - // attempting to dequeue returns undefined (no messages queued) - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Clean up the pending request - const requestId = (sendSpy.mock.calls[0]![0] as JSONRPCResultResponse).id; - transport.onmessage?.({ - jsonrpc: '2.0', - id: requestId, - result: { result: 'success' } - }); - await requestPromise; - }); - }); - - describe('task ID extraction from metadata', () => { - it('should extract correct task ID from relatedTask metadata for notifications', async () => { - await protocol.connect(transport); - - const taskId = 'extracted-task-789'; - - // Send a notification with relatedTask metadata - await protocol.notification( - { - method: 'notifications/message', - params: { data: 'test' } - }, - { - relatedTask: { - taskId: taskId - } - } - ); - - // Verify the message was queued under the correct task ID - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Verify a message was queued for this task - const queuedMessage = await queue!.dequeue(taskId); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.method).toBe('notifications/message'); - }); - - it('should extract correct task ID from relatedTask metadata for requests', async () => { - await protocol.connect(transport); - - const taskId = 'extracted-task-999'; - const mockSchema = z.object({ result: z.string() }); - - // Send a request with relatedTask metadata - const requestPromise = testRequest( - protocol, - { - method: 'test/request', - params: { data: 'test' } - }, - mockSchema, - { - relatedTask: { - taskId: taskId - } - } - ); - - // Verify the message was queued under the correct task ID - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Clean up the pending request - const queuedMessage = await queue!.dequeue(taskId); - assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.method).toBe('test/request'); - transport.onmessage?.({ - jsonrpc: '2.0', - id: queuedMessage.message.id, - result: { result: 'success' } - }); - await requestPromise; - }); - - it('should handle multiple messages for different task IDs', async () => { - await protocol.connect(transport); - - // Send messages for different tasks - await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-A' } }); - await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-B' } }); - await protocol.notification({ method: 'test3', params: {} }, { relatedTask: { taskId: 'task-A' } }); - - // Verify messages are queued under correct task IDs - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Verify two messages for task-A - const msg1A = await queue!.dequeue('task-A'); - const msg2A = await queue!.dequeue('task-A'); - const msg3A = await queue!.dequeue('task-A'); // Should be undefined - expect(msg1A).toBeDefined(); - expect(msg2A).toBeDefined(); - expect(msg3A).toBeUndefined(); - - // Verify one message for task-B - const msg1B = await queue!.dequeue('task-B'); - const msg2B = await queue!.dequeue('task-B'); // Should be undefined - expect(msg1B).toBeDefined(); - expect(msg2B).toBeUndefined(); - }); - }); - - describe('queue creation on first message', () => { - it('should queue messages for a task', async () => { - await protocol.connect(transport); - - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Send first message for a task - await protocol.notification({ method: 'test', params: {} }, { relatedTask: { taskId: 'new-task' } }); - - // Verify message was queued - const msg = await queue!.dequeue('new-task'); - assertQueuedNotification(msg); - expect(msg.message.method).toBe('test'); - }); - - it('should queue multiple messages for the same task', async () => { - await protocol.connect(transport); - - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Send first message - await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); - - // Send second message - await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'reuse-task' } }); - - // Verify both messages were queued in order - const msg1 = await queue!.dequeue('reuse-task'); - const msg2 = await queue!.dequeue('reuse-task'); - assertQueuedNotification(msg1); - expect(msg1.message.method).toBe('test1'); - assertQueuedNotification(msg2); - expect(msg2.message.method).toBe('test2'); - }); - - it('should queue messages for different tasks separately', async () => { - await protocol.connect(transport); - - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Send messages for different tasks - await protocol.notification({ method: 'test1', params: {} }, { relatedTask: { taskId: 'task-1' } }); - await protocol.notification({ method: 'test2', params: {} }, { relatedTask: { taskId: 'task-2' } }); - - // Verify messages are queued separately - const msg1 = await queue!.dequeue('task-1'); - const msg2 = await queue!.dequeue('task-2'); - assertQueuedNotification(msg1); - expect(msg1?.message.method).toBe('test1'); - assertQueuedNotification(msg2); - expect(msg2?.message.method).toBe('test2'); - }); - }); - - describe('metadata preservation in queued messages', () => { - it('should preserve relatedTask metadata in queued notification', async () => { - await protocol.connect(transport); - - const relatedTask = { taskId: 'task-meta-123' }; - - await protocol.notification( - { - method: 'test/notification', - params: { data: 'test' } - }, - { relatedTask } - ); - - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const queuedMessage = await queue!.dequeue('task-meta-123'); - - // Verify the metadata is preserved in the queued message - expect(queuedMessage).toBeDefined(); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.params!._meta).toBeDefined(); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); - }); - - it('should preserve relatedTask metadata in queued request', async () => { - await protocol.connect(transport); - - const relatedTask = { taskId: 'task-meta-456' }; - const mockSchema = z.object({ result: z.string() }); - - const requestPromise = testRequest( - protocol, - { - method: 'test/request', - params: { data: 'test' } - }, - mockSchema, - { relatedTask } - ); - - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const queuedMessage = await queue!.dequeue('task-meta-456'); - - // Verify the metadata is preserved in the queued message - expect(queuedMessage).toBeDefined(); - assertQueuedRequest(queuedMessage); - expect(queuedMessage.message.params!._meta).toBeDefined(); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual(relatedTask); - - // Clean up - transport.onmessage?.({ - jsonrpc: '2.0', - id: (queuedMessage!.message as JSONRPCRequest).id, - result: { result: 'success' } - }); - await requestPromise; - }); - - it('should preserve existing _meta fields when adding relatedTask', async () => { - await protocol.connect(transport); - - await protocol.notification( - { - method: 'test/notification', - params: { - data: 'test', - _meta: { - customField: 'customValue', - anotherField: 123 - } - } - }, - { - relatedTask: { taskId: 'task-preserve-meta' } - } - ); - - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const queuedMessage = await queue!.dequeue('task-preserve-meta'); - - // Verify both existing and new metadata are preserved - expect(queuedMessage).toBeDefined(); - assertQueuedNotification(queuedMessage); - expect(queuedMessage.message.params!._meta!.customField).toBe('customValue'); - expect(queuedMessage.message.params!._meta!.anotherField).toBe(123); - expect(queuedMessage.message.params!._meta![RELATED_TASK_META_KEY]).toEqual({ - taskId: 'task-preserve-meta' - }); - }); - }); -}); - -describe('Queue lifecycle management', () => { - let protocol: Protocol; - let transport: MockTransport; - let mockTaskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; - - beforeEach(() => { - transport = new MockTransport(); - mockTaskStore = createMockTaskStore(); - protocol = createTestProtocol({ taskStore: mockTaskStore, taskMessageQueue: new InMemoryTaskMessageQueue() }); - }); - - describe('queue cleanup on task completion', () => { - it('should clear queue when task reaches completed status', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue some messages for the task - await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); - await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); - - // Verify messages are queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Verify messages can be dequeued - const msg1 = await queue!.dequeue(taskId); - const msg2 = await queue!.dequeue(taskId); - expect(msg1).toBeDefined(); - expect(msg2).toBeDefined(); - - // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocolInternals)._taskManager._clearTaskQueue(taskId); - - // After cleanup, no more messages should be available - const msg3 = await queue!.dequeue(taskId); - expect(msg3).toBeUndefined(); - }); - - it('should clear queue after delivering messages on tasks/result for completed task', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue a message - await protocol.notification({ method: 'test/notification', params: { data: 'test' } }, { relatedTask: { taskId } }); - - // Mark task as completed - const completedTask = { ...task, status: 'completed' as const }; - mockTaskStore.getTask.mockResolvedValue(completedTask); - mockTaskStore.getTaskResult.mockResolvedValue({ content: [{ type: 'text', text: 'done' }] }); - - // Simulate tasks/result request - const resultPromise = new Promise(resolve => { - transport.onmessage?.({ - jsonrpc: '2.0', - id: 100, - method: 'tasks/result', - params: { taskId } - }); - setTimeout(resolve, 50); - }); - - await resultPromise; - - // Verify queue is cleared after delivery (no messages available) - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const msg = await queue!.dequeue(taskId); - expect(msg).toBeUndefined(); - }); - }); - - describe('queue cleanup on task cancellation', () => { - it('should clear queue when task is cancelled', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue some messages - await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); - - // Verify message is queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - const msg1 = await queue!.dequeue(taskId); - expect(msg1).toBeDefined(); - - // Re-queue the message for cancellation test - await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); - - // Mock task as non-terminal - mockTaskStore.getTask.mockResolvedValue(task); - - // Cancel the task - transport.onmessage?.({ - jsonrpc: '2.0', - id: 200, - method: 'tasks/cancel', - params: { taskId } - }); - - // Wait for cancellation to process - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify queue is cleared (no messages available) - const msg2 = await queue!.dequeue(taskId); - expect(msg2).toBeUndefined(); - }); - - it('should reject pending request resolvers when task is cancelled', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue a request (catch rejection to avoid unhandled promise rejection) - const requestPromise = testRequest( - protocol, - { method: 'test/request', params: { data: 'test' } }, - z.object({ result: z.string() }), - { - relatedTask: { taskId } - } - ).catch(err => err); - - // Verify request is queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Mock task as non-terminal - mockTaskStore.getTask.mockResolvedValue(task); - - // Cancel the task - transport.onmessage?.({ - jsonrpc: '2.0', - id: 201, - method: 'tasks/cancel', - params: { taskId } - }); - - // Wait for cancellation to process - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the request promise is rejected - const result = (await requestPromise) as Error; - expect(result).toBeInstanceOf(ProtocolError); - expect(result.message).toContain('Task cancelled or completed'); - - // Verify queue is cleared (no messages available) - const msg = await queue!.dequeue(taskId); - expect(msg).toBeUndefined(); - }); - }); - - describe('queue cleanup on task failure', () => { - it('should clear queue when task reaches failed status', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue some messages - await protocol.notification({ method: 'test/notification', params: { data: 'test1' } }, { relatedTask: { taskId } }); - await protocol.notification({ method: 'test/notification', params: { data: 'test2' } }, { relatedTask: { taskId } }); - - // Verify messages are queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Verify messages can be dequeued - const msg1 = await queue!.dequeue(taskId); - const msg2 = await queue!.dequeue(taskId); - expect(msg1).toBeDefined(); - expect(msg2).toBeDefined(); - - // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocolInternals)._taskManager._clearTaskQueue(taskId); - - // After cleanup, no more messages should be available - const msg3 = await queue!.dequeue(taskId); - expect(msg3).toBeUndefined(); - }); - - it('should reject pending request resolvers when task fails', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue a request (catch the rejection to avoid unhandled promise rejection) - const requestPromise = testRequest( - protocol, - { method: 'test/request', params: { data: 'test' } }, - z.object({ result: z.string() }), - { - relatedTask: { taskId } - } - ).catch(err => err); - - // Verify request is queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocolInternals)._taskManager._clearTaskQueue(taskId); - - // Verify the request promise is rejected - const result = (await requestPromise) as Error; - expect(result).toBeInstanceOf(ProtocolError); - expect(result.message).toContain('Task cancelled or completed'); - - // Verify queue is cleared (no messages available) - const msg = await queue!.dequeue(taskId); - expect(msg).toBeUndefined(); - }); - }); - - describe('resolver rejection on cleanup', () => { - it('should reject all pending request resolvers when queue is cleared', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue multiple requests (catch rejections to avoid unhandled promise rejections) - const request1Promise = testRequest( - protocol, - { method: 'test/request1', params: { data: 'test1' } }, - z.object({ result: z.string() }), - { - relatedTask: { taskId } - } - ).catch(err => err); - - const request2Promise = testRequest( - protocol, - { method: 'test/request2', params: { data: 'test2' } }, - z.object({ result: z.string() }), - { - relatedTask: { taskId } - } - ).catch(err => err); - - const request3Promise = testRequest( - protocol, - { method: 'test/request3', params: { data: 'test3' } }, - z.object({ result: z.string() }), - { - relatedTask: { taskId } - } - ).catch(err => err); - - // Verify requests are queued - const queue = (protocol as unknown as TestProtocolInternals)._taskManager._taskMessageQueue; - expect(queue).toBeDefined(); - - // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocolInternals)._taskManager._clearTaskQueue(taskId); - - // Verify all request promises are rejected - const result1 = (await request1Promise) as Error; - const result2 = (await request2Promise) as Error; - const result3 = (await request3Promise) as Error; - - expect(result1).toBeInstanceOf(ProtocolError); - expect(result1.message).toContain('Task cancelled or completed'); - expect(result2).toBeInstanceOf(ProtocolError); - expect(result2.message).toContain('Task cancelled or completed'); - expect(result3).toBeInstanceOf(ProtocolError); - expect(result3.message).toContain('Task cancelled or completed'); - - // Verify queue is cleared (no messages available) - const msg = await queue!.dequeue(taskId); - expect(msg).toBeUndefined(); - }); - - it('should clean up resolver mappings when rejecting requests', async () => { - await protocol.connect(transport); - - // Create a task - const task = await mockTaskStore.createTask({}, 1, { method: 'test', params: {} }); - const taskId = task.taskId; - - // Queue a request (catch rejection to avoid unhandled promise rejection) - const requestPromise = testRequest( - protocol, - { method: 'test/request', params: { data: 'test' } }, - z.object({ result: z.string() }), - { - relatedTask: { taskId } - } - ).catch(err => err); - - // Get the request ID that was sent - const requestResolvers = (protocol as unknown as TestProtocolInternals)._taskManager._requestResolvers; - const initialResolverCount = requestResolvers.size; - expect(initialResolverCount).toBeGreaterThan(0); - - // Complete the task (triggers cleanup) - const completedTask = { ...task, status: 'completed' as const }; - mockTaskStore.getTask.mockResolvedValue(completedTask); - - // Directly call the cleanup method (simulating what happens when task reaches terminal status) - (protocol as unknown as TestProtocolInternals)._taskManager._clearTaskQueue(taskId); - - // Verify request promise is rejected - const result = (await requestPromise) as Error; - expect(result).toBeInstanceOf(ProtocolError); - expect(result.message).toContain('Task cancelled or completed'); - - // Verify resolver mapping is cleaned up - // The resolver should be removed from the map - expect(requestResolvers.size).toBeLessThan(initialResolverCount); - }); - }); -}); - -describe('requestStream() method', () => { - const CallToolResultSchema = z.object({ - content: z.array(z.object({ type: z.string(), text: z.string() })), - _meta: z.object({}).optional() - }); - - test('should yield result immediately for non-task requests', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - // Start the request stream - const streamPromise = (async () => { - const messages = []; - const stream = (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ); - for await (const message of stream) { - messages.push(message); - } - return messages; - })(); - - // Simulate server response - await new Promise(resolve => setTimeout(resolve, 10)); - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - result: { - content: [{ type: 'text', text: 'test result' }], - _meta: {} - } - }); - - const messages = await streamPromise; - - // Should yield exactly one result message - expect(messages).toHaveLength(1); - expect(messages[0]?.type).toBe('result'); - expect(messages[0]).toHaveProperty('result'); - }); - - test('should yield error message on request failure', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - // Start the request stream - const streamPromise = (async () => { - const messages = []; - const stream = (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ); - for await (const message of stream) { - messages.push(message); - } - return messages; - })(); - - // Simulate server error response - await new Promise(resolve => setTimeout(resolve, 10)); - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - error: { - code: ProtocolErrorCode.InternalError, - message: 'Test error' - } - }); - - const messages = await streamPromise; - - // Should yield exactly one error message - expect(messages).toHaveLength(1); - expect(messages[0]?.type).toBe('error'); - expect(messages[0]).toHaveProperty('error'); - if (messages[0]?.type === 'error') { - expect(messages[0]?.error?.message).toContain('Test error'); - } - }); - - test('should handle cancellation via AbortSignal', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - const abortController = new AbortController(); - - // Abort immediately before starting the stream - abortController.abort('User cancelled'); - - // Start the request stream with already-aborted signal - const messages = []; - const stream = (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema, - { - signal: abortController.signal - } - ); - for await (const message of stream) { - messages.push(message); - } - - // Should yield error message about cancellation - expect(messages).toHaveLength(1); - expect(messages[0]?.type).toBe('error'); - if (messages[0]?.type === 'error') { - expect(messages[0]?.error?.message).toContain('cancelled'); - } - }); - - describe('Error responses', () => { - test('should yield error as terminal message for server error response', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ) - ); - - // Simulate server error response - await new Promise(resolve => setTimeout(resolve, 10)); - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - error: { - code: ProtocolErrorCode.InternalError, - message: 'Server error' - } - }); - - // Collect messages - const messages = await messagesPromise; - - // Verify error is terminal and last message - expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; - assertErrorResponse(lastMessage!); - expect(lastMessage.error).toBeDefined(); - expect(lastMessage.error.message).toContain('Server error'); - }); - - test('should yield error as terminal message for timeout', async () => { - vi.useFakeTimers(); - try { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema, - { - timeout: 100 - } - ) - ); - - // Advance time to trigger timeout - await vi.advanceTimersByTimeAsync(101); - - // Collect messages - const messages = await messagesPromise; - - // Verify error is terminal and last message - expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; - assertErrorResponse(lastMessage!); - expect(lastMessage.error).toBeDefined(); - expect(lastMessage.error).toBeInstanceOf(SdkError); - expect((lastMessage.error as SdkError).code).toBe(SdkErrorCode.RequestTimeout); - } finally { - vi.useRealTimers(); - } - }); - - test('should yield error as terminal message for cancellation', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - const abortController = new AbortController(); - abortController.abort('User cancelled'); - - // Collect messages - const messages = await toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema, - { - signal: abortController.signal - } - ) - ); - - // Verify error is terminal and last message - expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; - assertErrorResponse(lastMessage!); - expect(lastMessage.error).toBeDefined(); - expect(lastMessage.error.message).toContain('cancelled'); - }); - - test('should not yield any messages after error message', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ) - ); - - // Simulate server error response - await new Promise(resolve => setTimeout(resolve, 10)); - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - error: { - code: ProtocolErrorCode.InternalError, - message: 'Test error' - } - }); - - // Collect messages - const messages = await messagesPromise; - - // Verify only one message (the error) was yielded - expect(messages).toHaveLength(1); - expect(messages[0]?.type).toBe('error'); - - // Try to send another message (should be ignored) - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - result: { - content: [{ type: 'text', text: 'should not appear' }] - } - }); - - await new Promise(resolve => setTimeout(resolve, 10)); - - // Verify no additional messages were yielded - expect(messages).toHaveLength(1); - }); - - test('should yield error as terminal message for task failure', async () => { - const transport = new MockTransport(); - const mockTaskStore = createMockTaskStore(); - const protocol = createTestProtocol({ taskStore: mockTaskStore }); - await protocol.connect(transport); - - const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ) - ); - - // Simulate task creation response - await new Promise(resolve => setTimeout(resolve, 10)); - const taskId = 'test-task-123'; - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - result: { - _meta: { - task: { - taskId, - status: 'working', - createdAt: new Date().toISOString(), - pollInterval: 100 - } - } - } - }); - - // Wait for task creation to be processed - await new Promise(resolve => setTimeout(resolve, 20)); - - // Update task to failed status - const failedTask = { - taskId, - status: 'failed' as const, - createdAt: new Date().toISOString(), - pollInterval: 100, - ttl: null, - statusMessage: 'Task failed' - }; - mockTaskStore.getTask.mockResolvedValue(failedTask); - - // Collect messages - const messages = await messagesPromise; - - // Verify error is terminal and last message - expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; - assertErrorResponse(lastMessage!); - expect(lastMessage.error).toBeDefined(); - }); - - test('should yield error as terminal message for network error', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - // Override send to simulate network error - transport.send = vi.fn().mockRejectedValue(new Error('Network error')); - - const messages = await toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ) - ); - - // Verify error is terminal and last message - expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; - assertErrorResponse(lastMessage!); - expect(lastMessage.error).toBeDefined(); - }); - - test('should ensure error is always the final message', async () => { - const transport = new MockTransport(); - const protocol = createTestProtocol({}); - await protocol.connect(transport); - - const messagesPromise = toArrayAsync( - (protocol as unknown as TestProtocolInternals)._taskManager.requestStream( - { method: 'tools/call', params: { name: 'test', arguments: {} } }, - CallToolResultSchema - ) - ); - - // Simulate server error response - await new Promise(resolve => setTimeout(resolve, 10)); - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - error: { - code: ProtocolErrorCode.InternalError, - message: 'Test error' - } - }); - - // Collect messages - const messages = await messagesPromise; - - // Verify error is the last message - expect(messages.length).toBeGreaterThan(0); - const lastMessage = messages[messages.length - 1]; - expect(lastMessage?.type).toBe('error'); - - // Verify all messages before the last are not terminal - for (let i = 0; i < messages.length - 1; i++) { - expect(messages[i]?.type).not.toBe('error'); - expect(messages[i]?.type).not.toBe('result'); - } - }); - }); -}); - -describe('Error handling for missing resolvers', () => { - let protocol: Protocol; - let transport: MockTransport; - let taskStore: TaskStore & { [K in keyof TaskStore]: MockInstance }; - let taskMessageQueue: TaskMessageQueue; - let errorHandler: MockInstance; - - beforeEach(() => { - taskStore = createMockTaskStore(); - taskMessageQueue = new InMemoryTaskMessageQueue(); - errorHandler = vi.fn(); - - protocol = createTestProtocol({ taskStore, taskMessageQueue, defaultTaskPollInterval: 100 }); - - // @ts-expect-error deliberately overriding error handler with mock - protocol.onerror = errorHandler; - transport = new MockTransport(); - }); - - describe('Response routing with missing resolvers', () => { - it('should log error for unknown request ID without throwing', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - // Enqueue a response message without a corresponding resolver - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: 999, // Non-existent request ID - result: { content: [] } - }, - timestamp: Date.now() - }); - - // Set up the GetTaskPayloadRequest handler to process the message - const testProtocol = protocol as unknown as TestProtocolInternals; - - // Simulate dequeuing and processing the response - const queuedMessage = await taskMessageQueue.dequeue(task.taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('response'); - - // Manually trigger the response handling logic - if (queuedMessage && queuedMessage.type === 'response') { - const responseMessage = queuedMessage.message as JSONRPCResultResponse; - const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - - if (!resolver) { - // This simulates what happens in the actual handler - protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); - } - } - - // Verify error was logged - expect(errorHandler).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('Response handler missing for request 999') - }) - ); - }); - - it('should continue processing after missing resolver error', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - // Enqueue a response with missing resolver, then a valid notification - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: 999, - result: { content: [] } - }, - timestamp: Date.now() - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'notification', - message: { - jsonrpc: '2.0', - method: 'notifications/progress', - params: { progress: 50, total: 100 } - }, - timestamp: Date.now() - }); - - // Process first message (response with missing resolver) - const msg1 = await taskMessageQueue.dequeue(task.taskId); - expect(msg1?.type).toBe('response'); - - // Process second message (should work fine) - const msg2 = await taskMessageQueue.dequeue(task.taskId); - expect(msg2?.type).toBe('notification'); - expect(msg2?.message).toMatchObject({ - method: 'notifications/progress' - }); - }); - }); - - describe('Task cancellation with missing resolvers', () => { - it('should log error when resolver is missing during cleanup', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - // Enqueue a request without storing a resolver - await taskMessageQueue.enqueue(task.taskId, { - type: 'request', - message: { - jsonrpc: '2.0', - id: 42, - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }, - timestamp: Date.now() - }); - - // Clear the task queue (simulating cancellation) - const testProtocol = protocol as unknown as TestProtocolInternals; - await testProtocol._taskManager._clearTaskQueue(task.taskId); - - // Verify error was logged for missing resolver - expect(errorHandler).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('Resolver missing for request 42') - }) - ); - }); - - it('should handle cleanup gracefully when resolver exists', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - const requestId = 42; - const resolverMock = vi.fn(); - - // Store a resolver - const testProtocol = protocol as unknown as TestProtocolInternals; - testProtocol._taskManager._requestResolvers.set(requestId, resolverMock); - - // Enqueue a request - await taskMessageQueue.enqueue(task.taskId, { - type: 'request', - message: { - jsonrpc: '2.0', - id: requestId, - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }, - timestamp: Date.now() - }); - - // Clear the task queue - await testProtocol._taskManager._clearTaskQueue(task.taskId); - - // Verify resolver was called with cancellation error - expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); - - // Verify the error has the correct properties - const calledError = resolverMock.mock.calls[0]![0]; - expect(calledError.code).toBe(ProtocolErrorCode.InternalError); - expect(calledError.message).toContain('Task cancelled or completed'); - - // Verify resolver was removed - expect(testProtocol._taskManager._requestResolvers.has(requestId)).toBe(false); - }); - - it('should handle mixed messages during cleanup', async () => { - await protocol.connect(transport); - - // Create a task - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - const testProtocol = protocol as unknown as TestProtocolInternals; - - // Enqueue multiple messages: request with resolver, request without, notification - const requestId1 = 42; - const resolverMock = vi.fn(); - testProtocol._taskManager._requestResolvers.set(requestId1, resolverMock); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'request', - message: { - jsonrpc: '2.0', - id: requestId1, - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }, - timestamp: Date.now() - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'request', - message: { - jsonrpc: '2.0', - id: 43, // No resolver for this one - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }, - timestamp: Date.now() - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'notification', - message: { - jsonrpc: '2.0', - method: 'notifications/progress', - params: { progress: 50, total: 100 } - }, - timestamp: Date.now() - }); - - // Clear the task queue - await testProtocol._taskManager._clearTaskQueue(task.taskId); - - // Verify resolver was called for first request - expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); - - // Verify the error has the correct properties - const calledError = resolverMock.mock.calls[0]![0]; - expect(calledError.code).toBe(ProtocolErrorCode.InternalError); - expect(calledError.message).toContain('Task cancelled or completed'); - - // Verify error was logged for second request - expect(errorHandler).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('Resolver missing for request 43') - }) - ); - - // Verify queue is empty - const remaining = await taskMessageQueue.dequeue(task.taskId); - expect(remaining).toBeUndefined(); - }); - }); - - describe('Side-channeled request error handling', () => { - it('should log error when response handler is missing for side-channeled request', async () => { - await protocol.connect(transport); - - const testProtocol = protocol as unknown as TestProtocolInternals; - const messageId = 123; - - // Create a response resolver without a corresponding response handler - const responseResolver = (response: JSONRPCResultResponse | Error) => { - const handler = testProtocol._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - protocol.onerror?.(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - - // Simulate the resolver being called without a handler - const mockResponse: JSONRPCResultResponse = { - jsonrpc: '2.0', - id: messageId, - result: { content: [] } - }; - - responseResolver(mockResponse); - - // Verify error was logged - expect(errorHandler).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('Response handler missing for side-channeled request 123') - }) - ); - }); - }); - - describe('Error handling does not throw exceptions', () => { - it('should not throw when processing response with missing resolver', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: 999, - result: { content: [] } - }, - timestamp: Date.now() - }); - - // This should not throw - const processMessage = async () => { - const msg = await taskMessageQueue.dequeue(task.taskId); - if (msg && msg.type === 'response') { - const testProtocol = protocol as unknown as TestProtocolInternals; - const responseMessage = msg.message as JSONRPCResultResponse; - const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - if (!resolver) { - protocol.onerror?.(new Error(`Response handler missing for request ${requestId}`)); - } - } - }; - - await expect(processMessage()).resolves.not.toThrow(); - }); - - it('should not throw during task cleanup with missing resolvers', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'request', - message: { - jsonrpc: '2.0', - id: 42, - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }, - timestamp: Date.now() - }); - - const testProtocol = protocol as unknown as TestProtocolInternals; - - // This should not throw - await expect(testProtocol._taskManager._clearTaskQueue(task.taskId)).resolves.not.toThrow(); - }); - }); - - describe('Error message routing', () => { - it('should route error messages to resolvers correctly', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const requestId = 42; - const resolverMock = vi.fn(); - - // Store a resolver - const testProtocol = protocol as unknown as TestProtocolInternals; - testProtocol._taskManager._requestResolvers.set(requestId, resolverMock); - - // Enqueue an error message - await taskMessageQueue.enqueue(task.taskId, { - type: 'error', - message: { - jsonrpc: '2.0', - id: requestId, - error: { - code: ProtocolErrorCode.InvalidRequest, - message: 'Invalid request parameters' - } - }, - timestamp: Date.now() - }); - - // Simulate dequeuing and processing the error - const queuedMessage = await taskMessageQueue.dequeue(task.taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('error'); - - // Manually trigger the error handling logic - if (queuedMessage && queuedMessage.type === 'error') { - const errorMessage = queuedMessage.message as JSONRPCErrorResponse; - const reqId = errorMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(reqId); - - if (resolver) { - testProtocol._taskManager._requestResolvers.delete(reqId); - const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); - resolver(error); - } - } - - // Verify resolver was called with ProtocolError - expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); - const calledError = resolverMock.mock.calls[0]![0]; - expect(calledError.code).toBe(ProtocolErrorCode.InvalidRequest); - expect(calledError.message).toContain('Invalid request parameters'); - - // Verify resolver was removed from map - expect(testProtocol._taskManager._requestResolvers.has(requestId)).toBe(false); - }); - - it('should log error for unknown request ID in error messages', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - // Enqueue an error message without a corresponding resolver - await taskMessageQueue.enqueue(task.taskId, { - type: 'error', - message: { - jsonrpc: '2.0', - id: 999, - error: { - code: ProtocolErrorCode.InternalError, - message: 'Something went wrong' - } - }, - timestamp: Date.now() - }); - - // Simulate dequeuing and processing the error - const queuedMessage = await taskMessageQueue.dequeue(task.taskId); - expect(queuedMessage).toBeDefined(); - expect(queuedMessage?.type).toBe('error'); - - // Manually trigger the error handling logic - if (queuedMessage && queuedMessage.type === 'error') { - const testProtocol = protocol as unknown as TestProtocolInternals; - const errorMessage = queuedMessage.message as JSONRPCErrorResponse; - const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - - if (!resolver) { - protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); - } - } - - // Verify error was logged - expect(errorHandler).toHaveBeenCalledWith( - expect.objectContaining({ - message: expect.stringContaining('Error handler missing for request 999') - }) - ); - }); - - it('should handle error messages with data field', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const requestId = 42; - const resolverMock = vi.fn(); - - // Store a resolver - const testProtocol = protocol as unknown as TestProtocolInternals; - testProtocol._taskManager._requestResolvers.set(requestId, resolverMock); - - // Enqueue an error message with data field - await taskMessageQueue.enqueue(task.taskId, { - type: 'error', - message: { - jsonrpc: '2.0', - id: requestId, - error: { - code: ProtocolErrorCode.InvalidParams, - message: 'Validation failed', - data: { field: 'userName', reason: 'required' } - } - }, - timestamp: Date.now() - }); - - // Simulate dequeuing and processing the error - const queuedMessage = await taskMessageQueue.dequeue(task.taskId); - - if (queuedMessage && queuedMessage.type === 'error') { - const errorMessage = queuedMessage.message as JSONRPCErrorResponse; - const reqId = errorMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(reqId); - - if (resolver) { - testProtocol._taskManager._requestResolvers.delete(reqId); - const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); - resolver(error); - } - } - - // Verify resolver was called with ProtocolError including data - expect(resolverMock).toHaveBeenCalledWith(expect.any(ProtocolError)); - const calledError = resolverMock.mock.calls[0]![0]; - expect(calledError.code).toBe(ProtocolErrorCode.InvalidParams); - expect(calledError.message).toContain('Validation failed'); - expect(calledError.data).toEqual({ field: 'userName', reason: 'required' }); - }); - - it('should not throw when processing error with missing resolver', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'error', - message: { - jsonrpc: '2.0', - id: 999, - error: { - code: ProtocolErrorCode.InternalError, - message: 'Error occurred' - } - }, - timestamp: Date.now() - }); - - // This should not throw - const processMessage = async () => { - const msg = await taskMessageQueue.dequeue(task.taskId); - if (msg && msg.type === 'error') { - const testProtocol = protocol as unknown as TestProtocolInternals; - const errorMessage = msg.message as JSONRPCErrorResponse; - const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - if (!resolver) { - protocol.onerror?.(new Error(`Error handler missing for request ${requestId}`)); - } - } - }; - - await expect(processMessage()).resolves.not.toThrow(); - }); - }); - - describe('Response and error message routing integration', () => { - it('should handle mixed response and error messages in queue', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const testProtocol = protocol as unknown as TestProtocolInternals; - - // Set up resolvers for multiple requests - const resolver1 = vi.fn(); - const resolver2 = vi.fn(); - const resolver3 = vi.fn(); - - testProtocol._taskManager._requestResolvers.set(1, resolver1); - testProtocol._taskManager._requestResolvers.set(2, resolver2); - testProtocol._taskManager._requestResolvers.set(3, resolver3); - - // Enqueue mixed messages: response, error, response - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: 1, - result: { content: [{ type: 'text', text: 'Success' }] } - }, - timestamp: Date.now() - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'error', - message: { - jsonrpc: '2.0', - id: 2, - error: { - code: ProtocolErrorCode.InvalidRequest, - message: 'Request failed' - } - }, - timestamp: Date.now() - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { - jsonrpc: '2.0', - id: 3, - result: { content: [{ type: 'text', text: 'Another success' }] } - }, - timestamp: Date.now() - }); - - // Process all messages - let msg; - while ((msg = await taskMessageQueue.dequeue(task.taskId))) { - if (msg.type === 'response') { - const responseMessage = msg.message as JSONRPCResultResponse; - const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - if (resolver) { - testProtocol._taskManager._requestResolvers.delete(requestId); - resolver(responseMessage); - } - } else if (msg.type === 'error') { - const errorMessage = msg.message as JSONRPCErrorResponse; - const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - if (resolver) { - testProtocol._taskManager._requestResolvers.delete(requestId); - const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); - resolver(error); - } - } - } - - // Verify all resolvers were called correctly - expect(resolver1).toHaveBeenCalledWith(expect.objectContaining({ id: 1 })); - expect(resolver2).toHaveBeenCalledWith(expect.any(ProtocolError)); - expect(resolver3).toHaveBeenCalledWith(expect.objectContaining({ id: 3 })); - - // Verify error has correct properties - const error = resolver2.mock.calls[0]![0]; - expect(error.code).toBe(ProtocolErrorCode.InvalidRequest); - expect(error.message).toContain('Request failed'); - - // Verify all resolvers were removed - expect(testProtocol._taskManager._requestResolvers.size).toBe(0); - }); - - it('should maintain FIFO order when processing responses and errors', async () => { - await protocol.connect(transport); - - const task = await taskStore.createTask({ ttl: 60000 }, 1, { method: 'test', params: {} }); - const testProtocol = protocol as unknown as TestProtocolInternals; - - const callOrder: number[] = []; - const resolver1 = vi.fn(() => callOrder.push(1)); - const resolver2 = vi.fn(() => callOrder.push(2)); - const resolver3 = vi.fn(() => callOrder.push(3)); - - testProtocol._taskManager._requestResolvers.set(1, resolver1); - testProtocol._taskManager._requestResolvers.set(2, resolver2); - testProtocol._taskManager._requestResolvers.set(3, resolver3); - - // Enqueue in specific order - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { jsonrpc: '2.0', id: 1, result: {} }, - timestamp: 1000 - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'error', - message: { - jsonrpc: '2.0', - id: 2, - error: { code: -32600, message: 'Error' } - }, - timestamp: 2000 - }); - - await taskMessageQueue.enqueue(task.taskId, { - type: 'response', - message: { jsonrpc: '2.0', id: 3, result: {} }, - timestamp: 3000 - }); - - // Process all messages - let msg; - while ((msg = await taskMessageQueue.dequeue(task.taskId))) { - if (msg.type === 'response') { - const responseMessage = msg.message as JSONRPCResultResponse; - const requestId = responseMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - if (resolver) { - testProtocol._taskManager._requestResolvers.delete(requestId); - resolver(responseMessage); - } - } else if (msg.type === 'error') { - const errorMessage = msg.message as JSONRPCErrorResponse; - const requestId = errorMessage.id as RequestId; - const resolver = testProtocol._taskManager._requestResolvers.get(requestId); - if (resolver) { - testProtocol._taskManager._requestResolvers.delete(requestId); - const error = new ProtocolError(errorMessage.error.code, errorMessage.error.message, errorMessage.error.data); - resolver(error); - } - } - } - - // Verify FIFO order was maintained - expect(callOrder).toEqual([1, 2, 3]); - }); - }); -}); - -describe('Protocol without task configuration', () => { - let protocol: TestProtocolImpl; - let transport: MockTransport; - let sendSpy: MockInstance; - - beforeEach(() => { - transport = new MockTransport(); - sendSpy = vi.spyOn(transport, 'send'); - protocol = createTestProtocol(); // empty TaskManager options - }); - - test('request/response flow works normally without task config', async () => { - await protocol.connect(transport); - const mockSchema = z.object({ result: z.string() }); - - const requestPromise = testRequest(protocol, { method: 'example', params: {} }, mockSchema, { timeout: 5000 }); - - // Simulate response - transport.onmessage?.({ - jsonrpc: '2.0', - id: 0, - result: { result: 'hello' } - }); - - const result = await requestPromise; - expect(result).toEqual({ result: 'hello' }); - }); - - test('notifications are sent with proper JSONRPC wrapping without task config', async () => { - await protocol.connect(transport); - - await protocol.notification({ method: 'notifications/cancelled', params: { requestId: '1', reason: 'test' } }); - - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { requestId: '1', reason: 'test' } - }), - undefined - ); - }); - - test('onClose does not error without task config', async () => { - await protocol.connect(transport); - await expect(protocol.close()).resolves.not.toThrow(); - }); - - test('inbound requests dispatch to handlers without task config', async () => { - const handler = vi.fn().mockResolvedValue({ content: 'ok' }); - protocol.setRequestHandler('ping', handler); - - await protocol.connect(transport); - transport.onmessage?.({ jsonrpc: '2.0', method: 'ping', id: 1 }); - - // Wait for async handler - await new Promise(resolve => setTimeout(resolve, 10)); - - expect(handler).toHaveBeenCalled(); - expect(sendSpy).toHaveBeenCalledWith( - expect.objectContaining({ - jsonrpc: '2.0', - id: 1, - result: { content: 'ok' } - }) - ); - }); -}); - -describe('TaskManager lifecycle via Protocol', () => { - let protocol: TestProtocolImpl; - let transport: MockTransport; - - beforeEach(() => { - transport = new MockTransport(); - protocol = new TestProtocolImpl(); - }); - - test('bind() is called during Protocol construction', () => { - const bindSpy = vi.spyOn(TaskManager.prototype, 'bind'); - const p = new TestProtocolImpl({ tasks: {} }); - expect(bindSpy).toHaveBeenCalled(); - expect(p.taskManager).toBeInstanceOf(TaskManager); - bindSpy.mockRestore(); - }); - - test('NullTaskManager is created when no tasks config is provided', () => { - const p = new TestProtocolImpl(); - expect(p.taskManager).toBeInstanceOf(NullTaskManager); - }); - - test('onClose() is called when transport closes', async () => { - const p = createTestProtocol({}); - const onCloseSpy = vi.spyOn(p.taskManager, 'onClose'); - - await p.connect(transport); - await p.close(); - - expect(onCloseSpy).toHaveBeenCalled(); - }); -}); - -describe('TaskManager always present (NullTaskManager pattern)', () => { - test('taskManager accessor always returns a TaskManager', () => { - const mockTaskModule = { getTask: vi.fn() }; - const mockClient = { taskManager: mockTaskModule } as any; - expect(mockClient.taskManager).toBe(mockTaskModule); - }); -}); diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index 4e9c33e67d..94e415031a 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, test } from 'vitest'; import type { BaseContext } from '../../src/shared/protocol.js'; import { Protocol } from '../../src/shared/protocol.js'; +import { HandlerRegistry } from '../../src/shared/handlerRegistry.js'; import type { Transport } from '../../src/shared/transport.js'; import type { EmptyResult, JSONRPCMessage, Notification, Request, Result } from '../../src/types/index.js'; @@ -35,11 +36,11 @@ describe('Protocol transport handling bug', () => { beforeEach(() => { protocol = new (class extends Protocol { + constructor() { + super(new HandlerRegistry()); + } protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} protected buildContext(ctx: BaseContext): BaseContext { return ctx; } diff --git a/packages/core/test/shared/wrapHandler.test.ts b/packages/core/test/shared/wrapHandler.test.ts index 6a6e33fb09..ac6d56fa2f 100644 --- a/packages/core/test/shared/wrapHandler.test.ts +++ b/packages/core/test/shared/wrapHandler.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it } from 'vitest'; import { Protocol } from '../../src/shared/protocol.js'; +import { HandlerRegistry } from '../../src/shared/handlerRegistry.js'; import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js'; class TestProtocol extends Protocol { @@ -9,24 +10,22 @@ class TestProtocol extends Protocol { } protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} + + constructor(registry: HandlerRegistry) { + super(registry); + } } -describe('Protocol._wrapHandler', () => { - it('routes setRequestHandler registration through _wrapHandler', () => { +describe('HandlerRegistry wrapHandler callback', () => { + it('routes setRequestHandler registration through wrapHandler callback', () => { const seen: string[] = []; - class SpyProtocol extends TestProtocol { - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise - ): (request: JSONRPCRequest, ctx: BaseContext) => Promise { + const registry = new HandlerRegistry({ + wrapHandler: (method: string, handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise) => { seen.push(method); return handler; } - } - const p = new SpyProtocol(); + }); + const p = new TestProtocol(registry); seen.length = 0; p.setRequestHandler('tools/list', () => ({ tools: [] })); p.setRequestHandler('resources/list', () => ({ resources: [] })); diff --git a/packages/core/test/spec.types.test.ts b/packages/core/test/spec.types.test.ts index d26a4cd701..4f6128062e 100644 --- a/packages/core/test/spec.types.test.ts +++ b/packages/core/test/spec.types.test.ts @@ -225,6 +225,7 @@ const sdkTypeChecks = { spec = sdk; }, ClientNotification: (sdk: WithJSONRPC, spec: SpecTypes.ClientNotification) => { + // @ts-expect-error SDK removed TaskStatusNotification (tasks removed from SDK ahead of spec) sdk = spec; spec = sdk; }, @@ -493,10 +494,12 @@ const sdkTypeChecks = { spec = sdk; }, ClientRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.ClientRequest) => { + // @ts-expect-error SDK removed task request types (tasks removed from SDK ahead of spec) sdk = spec; spec = sdk; }, ServerRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.ServerRequest) => { + // @ts-expect-error SDK removed task request types (tasks removed from SDK ahead of spec) sdk = spec; spec = sdk; }, @@ -505,6 +508,7 @@ const sdkTypeChecks = { spec = sdk; }, ServerNotification: (sdk: WithJSONRPC, spec: SpecTypes.ServerNotification) => { + // @ts-expect-error SDK removed TaskStatusNotification (tasks removed from SDK ahead of spec) sdk = spec; spec = sdk; }, @@ -552,75 +556,10 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - TaskAugmentedRequestParams: (sdk: SDKTypes.TaskAugmentedRequestParams, spec: SpecTypes.TaskAugmentedRequestParams) => { - sdk = spec; - spec = sdk; - }, ToolExecution: (sdk: SDKTypes.ToolExecution, spec: SpecTypes.ToolExecution) => { sdk = spec; spec = sdk; }, - TaskStatus: (sdk: SDKTypes.TaskStatus, spec: SpecTypes.TaskStatus) => { - sdk = spec; - spec = sdk; - }, - TaskMetadata: (sdk: SDKTypes.TaskMetadata, spec: SpecTypes.TaskMetadata) => { - sdk = spec; - spec = sdk; - }, - RelatedTaskMetadata: (sdk: SDKTypes.RelatedTaskMetadata, spec: SpecTypes.RelatedTaskMetadata) => { - sdk = spec; - spec = sdk; - }, - Task: (sdk: SDKTypes.Task, spec: SpecTypes.Task) => { - sdk = spec; - spec = sdk; - }, - CreateTaskResult: (sdk: SDKTypes.CreateTaskResult, spec: SpecTypes.CreateTaskResult) => { - sdk = spec; - spec = sdk; - }, - GetTaskResult: (sdk: SDKTypes.GetTaskResult, spec: SpecTypes.GetTaskResult) => { - sdk = spec; - spec = sdk; - }, - GetTaskPayloadRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.GetTaskPayloadRequest) => { - sdk = spec; - spec = sdk; - }, - ListTasksRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.ListTasksRequest) => { - sdk = spec; - spec = sdk; - }, - ListTasksResult: (sdk: SDKTypes.ListTasksResult, spec: SpecTypes.ListTasksResult) => { - sdk = spec; - spec = sdk; - }, - CancelTaskRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.CancelTaskRequest) => { - sdk = spec; - spec = sdk; - }, - CancelTaskResult: (sdk: SDKTypes.CancelTaskResult, spec: SpecTypes.CancelTaskResult) => { - sdk = spec; - spec = sdk; - }, - GetTaskRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.GetTaskRequest) => { - sdk = spec; - spec = sdk; - }, - GetTaskPayloadResult: (sdk: SDKTypes.GetTaskPayloadResult, spec: SpecTypes.GetTaskPayloadResult) => { - sdk = spec; - spec = sdk; - }, - TaskStatusNotificationParams: (sdk: SDKTypes.TaskStatusNotificationParams, spec: SpecTypes.TaskStatusNotificationParams) => { - sdk = spec; - spec = sdk; - }, - TaskStatusNotification: (sdk: WithJSONRPC, spec: SpecTypes.TaskStatusNotification) => { - sdk = spec; - spec = sdk; - }, - /* JSON primitives */ JSONValue: (sdk: SDKTypes.JSONValue, spec: SpecTypes.JSONValue) => { sdk = spec; @@ -715,29 +654,6 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - CreateTaskResultResponse: (sdk: TypedResultResponse, spec: SpecTypes.CreateTaskResultResponse) => { - sdk = spec; - spec = sdk; - }, - GetTaskResultResponse: (sdk: TypedResultResponse, spec: SpecTypes.GetTaskResultResponse) => { - sdk = spec; - spec = sdk; - }, - GetTaskPayloadResultResponse: ( - sdk: TypedResultResponse, - spec: SpecTypes.GetTaskPayloadResultResponse - ) => { - sdk = spec; - spec = sdk; - }, - CancelTaskResultResponse: (sdk: TypedResultResponse, spec: SpecTypes.CancelTaskResultResponse) => { - sdk = spec; - spec = sdk; - }, - ListTasksResultResponse: (sdk: TypedResultResponse, spec: SpecTypes.ListTasksResultResponse) => { - sdk = spec; - spec = sdk; - }, SetLevelResultResponse: (sdk: TypedResultResponse, spec: SpecTypes.SetLevelResultResponse) => { sdk = spec; spec = sdk; @@ -791,7 +707,7 @@ type AssertExactKeys< type Assert = T; /* - * Excluded from key-level assertions (23 entries): + * Excluded from key-level assertions (22 entries): * * Union types — KnownKeys cannot meaningfully enumerate their members (15): * ClientRequest, ServerRequest, ClientNotification, ServerNotification, @@ -799,12 +715,12 @@ type Assert = T; * SamplingMessageContentBlock, ElicitRequestParams, PrimitiveSchemaDefinition, * SingleSelectEnumSchema, MultiSelectEnumSchema, EnumSchema * - * Primitive type aliases — no object keys to compare (8): + * Primitive type aliases — no object keys to compare (7): * JSONValue, JSONArray, Role, LoggingLevel, ProgressToken, RequestId, - * Cursor, TaskStatus + * Cursor */ -// -- Simple types (96) -- +// -- Simple types (86) -- type _K_RequestParams = Assert>; type _K_NotificationParams = Assert>; @@ -819,14 +735,18 @@ type _K_ResourceUpdatedNotificationParams = Assert< AssertExactKeys >; type _K_GetPromptRequestParams = Assert>; +// @ts-expect-error SDK removed 'task' key (tasks removed from SDK ahead of spec) type _K_CallToolRequestParams = Assert>; type _K_SetLevelRequestParams = Assert>; type _K_LoggingMessageNotificationParams = Assert< AssertExactKeys >; +// @ts-expect-error SDK removed 'task' key (tasks removed from SDK ahead of spec) type _K_CreateMessageRequestParams = Assert>; type _K_CompleteRequestParams = Assert>; +// @ts-expect-error SDK removed 'task' key (tasks removed from SDK ahead of spec) type _K_ElicitRequestFormParams = Assert>; +// @ts-expect-error SDK removed 'task' key (tasks removed from SDK ahead of spec) type _K_ElicitRequestURLParams = Assert>; type _K_PaginatedRequestParams = Assert>; type _K_BaseMetadata = Assert>; @@ -884,7 +804,9 @@ type _K_LegacyTitledEnumSchema = Assert>; type _K_JSONRPCResultResponse = Assert>; type _K_InitializeResult = Assert>; +// @ts-expect-error SDK removed 'tasks' key (tasks removed from SDK ahead of spec) type _K_ClientCapabilities = Assert>; +// @ts-expect-error SDK removed 'tasks' key (tasks removed from SDK ahead of spec) type _K_ServerCapabilities = Assert>; type _K_SamplingMessage = Assert>; type _K_Icon = Assert>; @@ -895,22 +817,9 @@ type _K_ToolChoice = Assert>; type _K_ToolResultContent = Assert>; type _K_Annotations = Assert>; -type _K_TaskAugmentedRequestParams = Assert>; type _K_ToolExecution = Assert>; -type _K_TaskMetadata = Assert>; -type _K_RelatedTaskMetadata = Assert>; -type _K_Task = Assert>; -type _K_CreateTaskResult = Assert>; -type _K_GetTaskResult = Assert>; -type _K_ListTasksResult = Assert>; -type _K_CancelTaskResult = Assert>; -type _K_GetTaskPayloadResult = Assert>; -type _K_TaskStatusNotificationParams = Assert< - AssertExactKeys ->; type _K_JSONObject = Assert>; type _K_MetaObject = Assert>; -// @ts-expect-error Genuine mismatch: SDK RequestMetaObject has extra 'io.modelcontextprotocol/related-task' not in spec type _K_RequestMetaObject = Assert>; type _K_ParseError = Assert>; type _K_InvalidRequestError = Assert>; @@ -918,7 +827,7 @@ type _K_MethodNotFoundError = Assert>; type _K_InternalError = Assert>; -// -- WithJSONRPC-wrapped notification types (11) -- +// -- WithJSONRPC-wrapped notification types (10) -- // SDK notification types do not include `jsonrpc` — the spec types do. We wrap // with WithJSONRPC<> to add the missing field before comparing keys. @@ -946,9 +855,8 @@ type _K_LoggingMessageNotification = Assert< AssertExactKeys, SpecTypes.LoggingMessageNotification> >; type _K_InitializedNotification = Assert, SpecTypes.InitializedNotification>>; -type _K_TaskStatusNotification = Assert, SpecTypes.TaskStatusNotification>>; -// -- WithJSONRPCRequest-wrapped request types (21) -- +// -- WithJSONRPCRequest-wrapped request types (17) -- // SDK request types do not include `jsonrpc` or `id` — the spec types do. We // wrap with WithJSONRPCRequest<> to add the missing fields before comparing keys. @@ -971,14 +879,8 @@ type _K_ListPromptsRequest = Assert, SpecTypes.GetPromptRequest>>; type _K_CreateMessageRequest = Assert, SpecTypes.CreateMessageRequest>>; type _K_InitializeRequest = Assert, SpecTypes.InitializeRequest>>; -type _K_GetTaskPayloadRequest = Assert< - AssertExactKeys, SpecTypes.GetTaskPayloadRequest> ->; -type _K_ListTasksRequest = Assert, SpecTypes.ListTasksRequest>>; -type _K_CancelTaskRequest = Assert, SpecTypes.CancelTaskRequest>>; -type _K_GetTaskRequest = Assert, SpecTypes.GetTaskRequest>>; -// -- TypedResultResponse-wrapped types (21) -- +// -- TypedResultResponse-wrapped types (16) -- // The spec defines typed *ResultResponse interfaces that pair JSONRPCResultResponse // with a specific result. We compare TypedResultResponse against the // spec's combined type. @@ -1004,17 +906,6 @@ type _K_ListPromptsResultResponse = Assert< type _K_GetPromptResultResponse = Assert, SpecTypes.GetPromptResultResponse>>; type _K_ListToolsResultResponse = Assert, SpecTypes.ListToolsResultResponse>>; type _K_CallToolResultResponse = Assert, SpecTypes.CallToolResultResponse>>; -type _K_CreateTaskResultResponse = Assert< - AssertExactKeys, SpecTypes.CreateTaskResultResponse> ->; -type _K_GetTaskResultResponse = Assert, SpecTypes.GetTaskResultResponse>>; -type _K_GetTaskPayloadResultResponse = Assert< - AssertExactKeys, SpecTypes.GetTaskPayloadResultResponse> ->; -type _K_CancelTaskResultResponse = Assert< - AssertExactKeys, SpecTypes.CancelTaskResultResponse> ->; -type _K_ListTasksResultResponse = Assert, SpecTypes.ListTasksResultResponse>>; type _K_SetLevelResultResponse = Assert, SpecTypes.SetLevelResultResponse>>; type _K_CreateMessageResultResponse = Assert< AssertExactKeys, SpecTypes.CreateMessageResultResponse> @@ -1048,15 +939,14 @@ const KEY_PARITY_EXCLUDED = [ 'SingleSelectEnumSchema', 'MultiSelectEnumSchema', 'EnumSchema', - // Primitive aliases (8) + // Primitive aliases (7) 'JSONValue', 'JSONArray', 'Role', 'LoggingLevel', 'ProgressToken', 'RequestId', - 'Cursor', - 'TaskStatus' + 'Cursor' ]; // This file is .gitignore'd, and fetched by `npm run fetch:spec-types` (called by `npm run test`) @@ -1069,6 +959,33 @@ const MISSING_SDK_TYPES = [ 'URLElicitationRequiredError' // In the SDK, but with a custom definition ]; +// Task types that exist in the generated spec types but were removed from the +// SDK ahead of the spec (the SDK no longer ships the TaskManager system). +// They are excluded from both the mutual-assignability and key-parity guards. +const TASK_TYPES_REMOVED_FROM_SDK = [ + 'TaskAugmentedRequestParams', + 'TaskStatus', + 'TaskMetadata', + 'RelatedTaskMetadata', + 'Task', + 'CreateTaskResult', + 'CreateTaskResultResponse', + 'GetTaskRequest', + 'GetTaskResult', + 'GetTaskResultResponse', + 'GetTaskPayloadRequest', + 'GetTaskPayloadResult', + 'GetTaskPayloadResultResponse', + 'CancelTaskRequest', + 'CancelTaskResult', + 'CancelTaskResultResponse', + 'ListTasksRequest', + 'ListTasksResult', + 'ListTasksResultResponse', + 'TaskStatusNotificationParams', + 'TaskStatusNotification' +]; + function extractExportedTypes(source: string): string[] { const matches = [...source.matchAll(/export\s+(?:interface|class|type)\s+(\w+)\b/g)]; return matches.map(m => m[1]!); @@ -1081,7 +998,7 @@ function extractKeyParityTypes(source: string): string[] { describe('Spec Types', () => { const specTypes = extractExportedTypes(fs.readFileSync(SPEC_TYPES_FILE, 'utf8')); const sdkTypes = extractExportedTypes(fs.readFileSync(SDK_TYPES_FILE, 'utf8')); - const typesToCheck = specTypes.filter(type => !MISSING_SDK_TYPES.includes(type)); + const typesToCheck = specTypes.filter(type => !MISSING_SDK_TYPES.includes(type) && !TASK_TYPES_REMOVED_FROM_SDK.includes(type)); it('should define some expected types', () => { expect(specTypes).toContain('JSONRPCNotification'); @@ -1095,6 +1012,12 @@ describe('Spec Types', () => { } }); + it('should have up to date list of task types removed from sdk', () => { + for (const typeName of TASK_TYPES_REMOVED_FROM_SDK) { + expect(sdkTypes).not.toContain(typeName); + } + }); + it('should have comprehensive compatibility tests', () => { const missingTests = []; diff --git a/packages/core/test/types/specTypeSchema.test.ts b/packages/core/test/types/specTypeSchema.test.ts index 198e104f9f..e43513e81b 100644 --- a/packages/core/test/types/specTypeSchema.test.ts +++ b/packages/core/test/types/specTypeSchema.test.ts @@ -154,9 +154,7 @@ describe('SPEC_SCHEMA_KEYS allowlist', () => { const INTERNAL_HELPER_SCHEMAS: readonly string[] = [ 'ListChangedOptionsBaseSchema', 'BaseRequestParamsSchema', - 'NotificationsParamsSchema', - 'ClientTasksCapabilitySchema', - 'ServerTasksCapabilitySchema' + 'NotificationsParamsSchema' ]; it('covers every public protocol schema in schemas.ts (drift guard)', () => { diff --git a/packages/middleware/node/src/streamableHttp.ts b/packages/middleware/node/src/streamableHttp.ts index 68a0c224f0..e4993b7fa9 100644 --- a/packages/middleware/node/src/streamableHttp.ts +++ b/packages/middleware/node/src/streamableHttp.ts @@ -10,7 +10,7 @@ import type { IncomingMessage, ServerResponse } from 'node:http'; import { getRequestListener } from '@hono/node-server'; -import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core'; +import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, ProtocolConfig, RequestId, Transport } from '@modelcontextprotocol/core'; import type { WebStandardStreamableHTTPServerTransportOptions } from '@modelcontextprotocol/server'; import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; @@ -130,6 +130,10 @@ export class NodeStreamableHTTPServerTransport implements Transport { return this._webStandardTransport.onmessage; } + setProtocolConfig(config: ProtocolConfig): void { + this._webStandardTransport.setProtocolConfig!(config); + } + /** * Starts the transport. This is required by the {@linkcode Transport} interface but is a no-op * for the Streamable HTTP transport as connections are managed per-request. diff --git a/packages/middleware/node/test/streamableHttp.test.ts b/packages/middleware/node/test/streamableHttp.test.ts index c427aa2eea..c70f85bf81 100644 --- a/packages/middleware/node/test/streamableHttp.test.ts +++ b/packages/middleware/node/test/streamableHttp.test.ts @@ -45,6 +45,16 @@ interface TestServerConfig { onsessioninitialized?: ((sessionId: string) => void | Promise) | undefined; onsessionclosed?: ((sessionId: string) => void | Promise) | undefined; retryInterval?: number; + /** Additional tools to register before connecting (needed for routing transport compatibility) */ + additionalTools?: Array<{ + name: string; + description: string; + inputSchema: z.ZodType; + handler: ( + args: Record, + ctx: { http?: { closeSSE?: () => void; closeStandaloneSSE?: () => void } } + ) => Promise; + }>; } /** @@ -172,6 +182,70 @@ describe('Zod v4', () => { } ); + // General-purpose tool for sending log notifications via tool handler context. + // With the routing transport, transport.send() and server.sendLoggingMessage() + // are not available, so tests use this tool to send server-initiated notifications. + mcpServer.registerTool( + 'send-log', + { + description: 'Sends a log notification via handler context', + inputSchema: z.object({ message: z.string() }) + }, + async ({ message }, ctx): Promise => { + ctx.mcpReq.log('info', message); + return { content: [{ type: 'text', text: 'sent' }] }; + } + ); + + // Test tool that exposes Request object info - registered before connect to + // avoid sendToolListChanged errors with the routing transport + mcpServer.registerTool( + 'test-request-info', + { + description: 'A simple test tool with request info', + inputSchema: z.object({ name: z.string().describe('Name to greet') }) + }, + async ({ name }, ctx): Promise => { + const req = ctx.http?.req; + const serializedRequestInfo = { + headers: Object.fromEntries(req?.headers ?? new Headers()), + url: req?.url, + method: req?.method + }; + return { + content: [ + { type: 'text', text: `Hello, ${name}!` }, + { type: 'text', text: `${JSON.stringify(serializedRequestInfo)}` } + ] + }; + } + ); + + // Test tool that reads query params - registered before connect + mcpServer.registerTool( + 'test-query-params', + { + description: 'A tool that reads query params', + inputSchema: z.object({}) + }, + async (_args, ctx): Promise => { + const req = ctx.http?.req; + const url = new URL(req!.url); + const params = Object.fromEntries(url.searchParams); + return { + content: [{ type: 'text', text: JSON.stringify(params) }] + }; + } + ); + + // Register any additional tools before connect (needed for routing transport + // since registerTool after connect triggers sendToolListChanged which throws) + if (config.additionalTools) { + for (const tool of config.additionalTools) { + mcpServer.registerTool(tool.name, { description: tool.description, inputSchema: tool.inputSchema }, tool.handler as never); + } + } + const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, @@ -286,12 +360,12 @@ describe('Zod v4', () => { expect(response.headers.get('mcp-session-id')).toBeDefined(); }); - it('should reject second initialization request', async () => { + it('should create a new session on second initialization request', async () => { // First initialize - const sessionId = await initializeServer(); - expect(sessionId).toBeDefined(); + const firstSessionId = await initializeServer(); + expect(firstSessionId).toBeDefined(); - // Try second initialize + // Second initialize creates a new independent session const secondInitMessage = { ...TEST_MESSAGES.initialize, id: 'second-init' @@ -299,9 +373,11 @@ describe('Zod v4', () => { const response = await sendPostRequest(baseUrl, secondInitMessage); - expect(response.status).toBe(400); - const errorData = await response.json(); - expectErrorResponse(errorData, -32_600, /Server already initialized/); + // The routing transport creates a new per-session legacy stack + expect(response.status).toBe(200); + const secondSessionId = response.headers.get('mcp-session-id'); + expect(secondSessionId).toBeDefined(); + expect(secondSessionId).not.toBe(firstSessionId); }); it('should reject batch initialize request', async () => { @@ -399,28 +475,6 @@ describe('Zod v4', () => { it('should expose the full Request object to tool handlers', async () => { sessionId = await initializeServer(); - mcpServer.registerTool( - 'test-request-info', - { - description: 'A simple test tool with request info', - inputSchema: z.object({ name: z.string().describe('Name to greet') }) - }, - async ({ name }, ctx): Promise => { - const req = ctx.http?.req; - const serializedRequestInfo = { - headers: Object.fromEntries(req?.headers ?? new Headers()), - url: req?.url, - method: req?.method - }; - return { - content: [ - { type: 'text', text: `Hello, ${name}!` }, - { type: 'text', text: `${JSON.stringify(serializedRequestInfo)}` } - ] - }; - } - ); - const toolCallMessage: JSONRPCMessage = { jsonrpc: '2.0', method: 'tools/call', @@ -474,22 +528,6 @@ describe('Zod v4', () => { it('should expose query parameters via the Request object', async () => { sessionId = await initializeServer(); - mcpServer.registerTool( - 'test-query-params', - { - description: 'A tool that reads query params', - inputSchema: z.object({}) - }, - async (_args, ctx): Promise => { - const req = ctx.http?.req; - const url = new URL(req!.url); - const params = Object.fromEntries(url.searchParams); - return { - content: [{ type: 'text', text: JSON.stringify(params) }] - }; - } - ); - const toolCallMessage: JSONRPCMessage = { jsonrpc: '2.0', method: 'tools/call', @@ -535,7 +573,7 @@ describe('Zod v4', () => { expect(response.status).toBe(404); const errorData = await response.json(); - expectErrorResponse(errorData, -32_001, /Session not found/); + expectErrorResponse(errorData, -32_000, /Session not found/); }); it('should establish standalone SSE stream and receive server-initiated messages', async () => { @@ -555,17 +593,17 @@ describe('Zod v4', () => { expect(sseResponse.status).toBe(200); expect(sseResponse.headers.get('content-type')).toBe('text/event-stream'); - // Send a notification (server-initiated message) that should appear on SSE stream - const notification: JSONRPCMessage = { + // Send a notification by calling the send-log tool (which sends via its handler context) + const toolCallMessage: JSONRPCMessage = { jsonrpc: '2.0', - method: 'notifications/message', - params: { level: 'info', data: 'Test notification' } + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'Test notification' } }, + id: 'notify-1' }; + const toolResponse = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(toolResponse.status).toBe(200); - // Send the notification via transport - await transport.send(notification); - - // Read from the stream and verify we got the notification + // Read from the standalone SSE stream and verify we got the notification const text = await readSSEEvent(sseResponse); const eventLines = text.split('\n'); @@ -596,16 +634,17 @@ describe('Zod v4', () => { expect(sseResponse.status).toBe(200); const reader = sseResponse.body?.getReader(); - // Send multiple notifications - const notification1: JSONRPCMessage = { + // Send notification via tool call using the send-log tool + const toolCallMessage: JSONRPCMessage = { jsonrpc: '2.0', - method: 'notifications/message', - params: { level: 'info', data: 'First notification' } + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'First notification' } }, + id: 'notify-sse-1' }; + const toolResponse = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(toolResponse.status).toBe(200); - // Just send one and verify it comes through - then the stream should stay open - await transport.send(notification1); - + // Just read one and verify it comes through - then the stream should stay open const { value, done } = await reader!.read(); const text = new TextDecoder().decode(value); expect(text).toContain('First notification'); @@ -794,12 +833,12 @@ describe('Zod v4', () => { }); }); - it('should reject requests to uninitialized server', async () => { + it('should reject requests with unknown session ID on fresh server', async () => { // Create a new HTTP server and transport without initializing const { server: uninitializedServer, transport: uninitializedTransport, baseUrl: uninitializedUrl } = await createTestServer(); // Transport not used in test but needed for cleanup - // No initialization, just send a request directly + // No initialization, just send a request directly with a made-up session ID const uninitializedMessage: JSONRPCMessage = { jsonrpc: '2.0', method: 'tools/list', @@ -807,12 +846,12 @@ describe('Zod v4', () => { id: 'uninitialized-test' }; - // Send a request to uninitialized server + // The routing transport returns 404 "Session not found" for unknown session IDs const response = await sendPostRequest(uninitializedUrl, uninitializedMessage, 'any-session-id'); - expect(response.status).toBe(400); + expect(response.status).toBe(404); const errorData = await response.json(); - expectErrorResponse(errorData, -32_000, /Server not initialized/); + expectErrorResponse(errorData, -32_000, /Session not found/); // Cleanup await stopTestServer({ server: uninitializedServer, transport: uninitializedTransport }); @@ -872,18 +911,28 @@ describe('Zod v4', () => { } }); - // Send several server-initiated notifications - await transport.send({ - jsonrpc: '2.0', - method: 'notifications/message', - params: { level: 'info', data: 'First notification' } - }); + // Send notifications via tool calls using the send-log tool + await sendPostRequest( + baseUrl, + { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'First notification' } }, + id: 'keep-open-1' + } as JSONRPCMessage, + sessionId + ); - await transport.send({ - jsonrpc: '2.0', - method: 'notifications/message', - params: { level: 'info', data: 'Second notification' } - }); + await sendPostRequest( + baseUrl, + { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'Second notification' } }, + id: 'keep-open-2' + } as JSONRPCMessage, + sessionId + ); // Stream should still be open - it should not close after sending notifications expect(sseResponse.bodyUsed).toBe(false); @@ -931,7 +980,7 @@ describe('Zod v4', () => { expect(response.status).toBe(404); const errorData = await response.json(); - expectErrorResponse(errorData, -32_001, /Session not found/); + expectErrorResponse(errorData, -32_000, /Session not found/); }); describe('protocol version header validation', () => { @@ -1434,15 +1483,15 @@ describe('Zod v4', () => { expect(sseResponse.status).toBe(200); expect(sseResponse.headers.get('content-type')).toBe('text/event-stream'); - // Send a notification that should be stored with an event ID - const notification: JSONRPCMessage = { + // Send a notification via the send-log tool (which uses handler context) + const toolCallMessage: JSONRPCMessage = { jsonrpc: '2.0', - method: 'notifications/message', - params: { level: 'info', data: 'Test notification with event ID' } + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'Test notification with event ID' } }, + id: 'event-id-test-1' }; - - // Send the notification via transport - await transport.send(notification); + const toolResponse = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(toolResponse.status).toBe(200); // Read from the stream and verify we got the notification with an event ID const reader = sseResponse.body?.getReader(); @@ -1462,10 +1511,26 @@ describe('Zod v4', () => { expect(storedEvents.has(eventId)).toBe(true); const storedEvent = storedEvents.get(eventId); expect(eventId.startsWith('_GET_stream')).toBe(true); - expect(storedEvent?.message).toMatchObject(notification); + expect(storedEvent?.message).toMatchObject({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification with event ID' } + }); }); it('should store and replay MCP server tool notifications', async () => { + // Helper to send a log notification via the send-log tool + async function sendLogNotification(message: string, id: string) { + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message } }, + id + }; + const resp = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(resp.status).toBe(200); + } + // Establish a standalone SSE stream const sseResponse = await fetch(baseUrl, { method: 'GET', @@ -1477,8 +1542,8 @@ describe('Zod v4', () => { }); expect(sseResponse.status).toBe(200); - // Send a server notification through the MCP server - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'First notification from MCP server' }); + // Send a server notification through the send-log tool + await sendLogNotification('First notification from MCP server', 'replay-1'); // Read the notification from the SSE stream const reader = sseResponse.body?.getReader(); @@ -1495,7 +1560,7 @@ describe('Zod v4', () => { const firstEventId = idMatch![1]!; // Send a second notification - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Second notification from MCP server' }); + await sendLogNotification('Second notification from MCP server', 'replay-2'); // Close the first SSE stream to simulate a disconnect await reader!.cancel(); @@ -1524,6 +1589,18 @@ describe('Zod v4', () => { }); it('should store and replay multiple notifications sent while client is disconnected', async () => { + // Helper to send a log notification via the send-log tool + async function sendLogNotification(message: string, id: string) { + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message } }, + id + }; + const resp = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(resp.status).toBe(200); + } + // Establish a standalone SSE stream const sseResponse = await fetch(baseUrl, { method: 'GET', @@ -1538,7 +1615,7 @@ describe('Zod v4', () => { const reader = sseResponse.body?.getReader(); // Send a notification to get an event ID - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Initial notification' }); + await sendLogNotification('Initial notification', 'multi-replay-init'); // Read the notification from the SSE stream const { value } = await reader!.read(); @@ -1553,9 +1630,9 @@ describe('Zod v4', () => { await reader!.cancel(); // Send MULTIPLE notifications while the client is disconnected - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Missed notification 1' }); - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Missed notification 2' }); - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Missed notification 3' }); + await sendLogNotification('Missed notification 1', 'multi-replay-1'); + await sendLogNotification('Missed notification 2', 'multi-replay-2'); + await sendLogNotification('Missed notification 3', 'multi-replay-3'); // Reconnect with the Last-Event-ID to get all missed messages const reconnectResponse = await fetch(baseUrl, { @@ -1613,74 +1690,94 @@ describe('Zod v4', () => { await stopTestServer({ server, transport }); }); - it('should operate without session ID validation', async () => { - // Initialize the server first + it('should create sessions even when sessionIdGenerator is undefined', async () => { + // With the routing transport, sessionIdGenerator:undefined still creates + // per-session legacy stacks with auto-generated UUIDs const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); expect(initResponse.status).toBe(200); - // Should NOT have session ID header in stateless mode - expect(initResponse.headers.get('mcp-session-id')).toBeNull(); + // The routing transport always generates session IDs for legacy sessions + const statelessSessionId = initResponse.headers.get('mcp-session-id'); + expect(statelessSessionId).toBeDefined(); - // Try request without session ID - should work in stateless mode - const toolsResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); + // Requests with the session ID work + const toolsResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, statelessSessionId!); expect(toolsResponse.status).toBe(200); }); - it('should handle POST requests with various session IDs in stateless mode', async () => { - await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + it('should reject POST requests with unknown session IDs', async () => { + // Initialize to create a session + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + const validSessionId = initResponse.headers.get('mcp-session-id')!; + expect(validSessionId).toBeDefined(); - // Try with a random session ID - should be accepted + // Requests with the valid session ID work + const validResponse = await sendPostRequest( + baseUrl, + { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 't0' + } as JSONRPCMessage, + validSessionId + ); + expect(validResponse.status).toBe(200); + + // Random session IDs are rejected as "Session not found" const response1 = await fetch(baseUrl, { method: 'POST', headers: { 'Content-Type': 'application/json', Accept: 'application/json, text/event-stream', - 'mcp-session-id': 'random-id-1' + 'mcp-session-id': 'random-id-1', + 'mcp-protocol-version': '2025-11-25' }, body: JSON.stringify({ jsonrpc: '2.0', method: 'tools/list', params: {}, id: 't1' }) }); - expect(response1.status).toBe(200); + expect(response1.status).toBe(404); - // Try with another random session ID - should also be accepted const response2 = await fetch(baseUrl, { method: 'POST', headers: { 'Content-Type': 'application/json', Accept: 'application/json, text/event-stream', - 'mcp-session-id': 'different-id-2' + 'mcp-session-id': 'different-id-2', + 'mcp-protocol-version': '2025-11-25' }, body: JSON.stringify({ jsonrpc: '2.0', method: 'tools/list', params: {}, id: 't2' }) }); - expect(response2.status).toBe(200); + expect(response2.status).toBe(404); }); - it('should reject second SSE stream even in stateless mode', async () => { - // Despite no session ID requirement, the transport still only allows - // one standalone SSE stream at a time - - // Initialize the server first - await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + it('should reject second SSE stream for the same session', async () => { + // Initialize the server to get a session ID + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + const statelessSessionId = initResponse.headers.get('mcp-session-id')!; + expect(statelessSessionId).toBeDefined(); - // Open first SSE stream + // Open first SSE stream with the session ID const stream1 = await fetch(baseUrl, { method: 'GET', headers: { Accept: 'text/event-stream', + 'mcp-session-id': statelessSessionId, 'mcp-protocol-version': '2025-11-25' } }); expect(stream1.status).toBe(200); - // Open second SSE stream - should still be rejected, stateless mode still only allows one + // Open second SSE stream with same session - should be rejected (one per session) const stream2 = await fetch(baseUrl, { method: 'GET', headers: { Accept: 'text/event-stream', + 'mcp-session-id': statelessSessionId, 'mcp-protocol-version': '2025-11-25' } }); - expect(stream2.status).toBe(409); // Conflict - only one stream allowed + expect(stream2.status).toBe(409); // Conflict - only one stream allowed per session }); }); @@ -1883,16 +1980,6 @@ describe('Zod v4', () => { }); it('should close POST SSE stream when ctx.http?.closeSSE is called', async () => { - const result = await createTestServer({ - sessionIdGenerator: () => randomUUID(), - eventStore: createEventStore(), - retryInterval: 1000 - }); - server = result.server; - transport = result.transport; - baseUrl = result.baseUrl; - mcpServer = result.mcpServer; - // Track when stream close is called and tool completes let streamCloseCalled = false; let toolResolve: () => void; @@ -1900,16 +1987,31 @@ describe('Zod v4', () => { toolResolve = resolve; }); - // Register a tool that closes its own SSE stream via ctx callback - mcpServer.registerTool('close-stream-tool', { description: 'Closes its own stream' }, async ctx => { - // Close the SSE stream for this request - ctx.http?.closeSSE?.(); - streamCloseCalled = true; - - // Wait before returning so we can observe the stream closure - await toolCompletePromise; - return { content: [{ type: 'text', text: 'Done' }] }; + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + eventStore: createEventStore(), + retryInterval: 1000, + additionalTools: [ + { + name: 'close-stream-tool', + description: 'Closes its own stream', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + // Close the SSE stream for this request + ctx.http?.closeSSE?.(); + streamCloseCalled = true; + + // Wait before returning so we can observe the stream closure + await toolCompletePromise; + return { content: [{ type: 'text', text: 'Done' }] }; + } + } + ] }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + mcpServer = result.mcpServer; // Initialize to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); @@ -1955,25 +2057,30 @@ describe('Zod v4', () => { }); it('should provide closeSSEStream callback in ctx when eventStore is configured', async () => { + // Track whether closeSSEStream callback was provided + let receivedCloseSSEStream: (() => void) | undefined; + const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), eventStore: createEventStore(), - retryInterval: 1000 + retryInterval: 1000, + additionalTools: [ + { + name: 'test-callback-tool', + description: 'Test tool', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + receivedCloseSSEStream = ctx.http?.closeSSE; + return { content: [{ type: 'text', text: 'Done' }] }; + } + } + ] }); server = result.server; transport = result.transport; baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Track whether closeSSEStream callback was provided - let receivedCloseSSEStream: (() => void) | undefined; - - // Register a tool that captures the ctx.http?.closeSSE callback - mcpServer.registerTool('test-callback-tool', { description: 'Test tool' }, async ctx => { - receivedCloseSSEStream = ctx.http?.closeSSE; - return { content: [{ type: 'text', text: 'Done' }] }; - }); - // Initialize to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); sessionId = initResponse.headers.get('mcp-session-id') as string; @@ -2013,27 +2120,32 @@ describe('Zod v4', () => { }); it('should NOT provide closeSSEStream callback for old protocol versions (backwards compatibility)', async () => { + // Track whether closeSSEStream callback was provided + let receivedCloseSSEStream: (() => void) | undefined; + let receivedCloseStandaloneSSEStream: (() => void) | undefined; + const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), eventStore: createEventStore(), - retryInterval: 1000 + retryInterval: 1000, + additionalTools: [ + { + name: 'test-old-version-tool', + description: 'Test tool', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + receivedCloseSSEStream = ctx.http?.closeSSE; + receivedCloseStandaloneSSEStream = ctx.http?.closeStandaloneSSE; + return { content: [{ type: 'text', text: 'Done' }] }; + } + } + ] }); server = result.server; transport = result.transport; baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Track whether closeSSEStream callback was provided - let receivedCloseSSEStream: (() => void) | undefined; - let receivedCloseStandaloneSSEStream: (() => void) | undefined; - - // Register a tool that captures the ctx.http?.closeSSE callback - mcpServer.registerTool('test-old-version-tool', { description: 'Test tool' }, async ctx => { - receivedCloseSSEStream = ctx.http?.closeSSE; - receivedCloseStandaloneSSEStream = ctx.http?.closeStandaloneSSE; - return { content: [{ type: 'text', text: 'Done' }] }; - }); - // Initialize with OLD protocol version to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initializeOldVersion); sessionId = initResponse.headers.get('mcp-session-id') as string; @@ -2074,24 +2186,29 @@ describe('Zod v4', () => { }); it('should NOT provide closeSSEStream callback when eventStore is NOT configured', async () => { + // Track whether closeSSEStream callback was provided + let receivedCloseSSEStream: (() => void) | undefined; + const result = await createTestServer({ - sessionIdGenerator: () => randomUUID() + sessionIdGenerator: () => randomUUID(), // No eventStore + additionalTools: [ + { + name: 'test-no-callback-tool', + description: 'Test tool', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + receivedCloseSSEStream = ctx.http?.closeSSE; + return { content: [{ type: 'text', text: 'Done' }] }; + } + } + ] }); server = result.server; transport = result.transport; baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Track whether closeSSEStream callback was provided - let receivedCloseSSEStream: (() => void) | undefined; - - // Register a tool that captures the ctx.http?.closeSSE callback - mcpServer.registerTool('test-no-callback-tool', { description: 'Test tool' }, async ctx => { - receivedCloseSSEStream = ctx.http?.closeSSE; - return { content: [{ type: 'text', text: 'Done' }] }; - }); - // Initialize to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); sessionId = initResponse.headers.get('mcp-session-id') as string; @@ -2130,25 +2247,30 @@ describe('Zod v4', () => { }); it('should provide closeStandaloneSSEStream callback in ctx when eventStore is configured', async () => { + // Track whether closeStandaloneSSEStream callback was provided + let receivedCloseStandaloneSSEStream: (() => void) | undefined; + const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), eventStore: createEventStore(), - retryInterval: 1000 + retryInterval: 1000, + additionalTools: [ + { + name: 'test-standalone-callback-tool', + description: 'Test tool', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + receivedCloseStandaloneSSEStream = ctx.http?.closeStandaloneSSE; + return { content: [{ type: 'text', text: 'Done' }] }; + } + } + ] }); server = result.server; transport = result.transport; baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Track whether closeStandaloneSSEStream callback was provided - let receivedCloseStandaloneSSEStream: (() => void) | undefined; - - // Register a tool that captures the ctx.http?.closeStandaloneSSE callback - mcpServer.registerTool('test-standalone-callback-tool', { description: 'Test tool' }, async ctx => { - receivedCloseStandaloneSSEStream = ctx.http?.closeStandaloneSSE; - return { content: [{ type: 'text', text: 'Done' }] }; - }); - // Initialize to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); sessionId = initResponse.headers.get('mcp-session-id') as string; @@ -2191,19 +2313,24 @@ describe('Zod v4', () => { const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), eventStore: createEventStore(), - retryInterval: 1000 + retryInterval: 1000, + additionalTools: [ + { + name: 'close-standalone-stream-tool', + description: 'Closes standalone stream', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + ctx.http?.closeStandaloneSSE?.(); + return { content: [{ type: 'text', text: 'Stream closed' }] }; + } + } + ] }); server = result.server; transport = result.transport; baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Register a tool that closes the standalone SSE stream via ctx callback - mcpServer.registerTool('close-standalone-stream-tool', { description: 'Closes standalone stream' }, async ctx => { - ctx.http?.closeStandaloneSSE?.(); - return { content: [{ type: 'text', text: 'Stream closed' }] }; - }); - // Initialize to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); sessionId = initResponse.headers.get('mcp-session-id') as string; @@ -2222,8 +2349,15 @@ describe('Zod v4', () => { const getReader = sseResponse.body?.getReader(); - // Send a notification to confirm GET stream is established - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Stream established' }); + // Send a notification to confirm GET stream is established (via the send-log tool) + const sendLogMsg: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'Stream established' } }, + id: 'confirm-sse-1' + }; + const logResp = await sendPostRequest(baseUrl, sendLogMsg, sessionId); + expect(logResp.status).toBe(200); // Read the notification to confirm stream is working const { value } = await getReader!.read(); @@ -2272,19 +2406,24 @@ describe('Zod v4', () => { const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), eventStore: createEventStore(), - retryInterval: 1000 + retryInterval: 1000, + additionalTools: [ + { + name: 'close-standalone-for-reconnect', + description: 'Closes standalone stream', + inputSchema: z.object({}), + handler: async (_args, ctx) => { + ctx.http?.closeStandaloneSSE?.(); + return { content: [{ type: 'text', text: 'Stream closed' }] }; + } + } + ] }); server = result.server; transport = result.transport; baseUrl = result.baseUrl; mcpServer = result.mcpServer; - // Register a tool that closes the standalone SSE stream - mcpServer.registerTool('close-standalone-for-reconnect', { description: 'Closes standalone stream' }, async ctx => { - ctx.http?.closeStandaloneSSE?.(); - return { content: [{ type: 'text', text: 'Stream closed' }] }; - }); - // Initialize to get session ID const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); sessionId = initResponse.headers.get('mcp-session-id') as string; @@ -2303,8 +2442,15 @@ describe('Zod v4', () => { const getReader = sseResponse.body?.getReader(); - // Send a notification to get an event ID - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Initial message' }); + // Send a notification to get an event ID (via the send-log tool) + const sendLogMsg: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'Initial message' } }, + id: 'reconnect-init-1' + }; + const logResp = await sendPostRequest(baseUrl, sendLogMsg, sessionId); + expect(logResp.status).toBe(200); // Read the notification to get the event ID const { value } = await getReader!.read(); @@ -2353,8 +2499,15 @@ describe('Zod v4', () => { // timestamp, the UUID suffix ordering is random and may not preserve creation order. await new Promise(resolve => setTimeout(resolve, 5)); - // Send a notification while client is disconnected - await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Missed while disconnected' }); + // Send a notification while client is disconnected (via the send-log tool) + const sendMissedMsg: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'send-log', arguments: { message: 'Missed while disconnected' } }, + id: 'reconnect-missed-1' + }; + const missedResp = await sendPostRequest(baseUrl, sendMissedMsg, sessionId); + expect(missedResp.status).toBe(200); // Client reconnects with Last-Event-ID const reconnectResponse = await fetch(baseUrl, { @@ -2817,7 +2970,10 @@ describe('Zod v4', () => { expect(body.error.message).toContain('Invalid Host header:'); }); - it('should reject GET requests with disallowed host headers', async () => { + it('should reject GET requests without session ID before DNS check', async () => { + // With the routing transport, GET requests without a session ID + // are rejected with 400 "Missing Mcp-Session-Id header" before + // DNS rebinding checks are reached (those happen in per-session stacks) const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedHosts: ['example.com:3001'], @@ -2834,7 +2990,7 @@ describe('Zod v4', () => { } }); - expect(response.status).toBe(403); + expect(response.status).toBe(400); }); }); diff --git a/packages/server/src/experimental/index.ts b/packages/server/src/experimental/index.ts deleted file mode 100644 index 55dd44ed08..0000000000 --- a/packages/server/src/experimental/index.ts +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Experimental MCP SDK features. - * WARNING: These APIs are experimental and may change without notice. - * - * Import experimental features from this module: - * ```typescript - * import { TaskStore, InMemoryTaskStore } from '@modelcontextprotocol/sdk/experimental'; - * ``` - * - * @experimental - */ - -export * from './tasks/index.js'; diff --git a/packages/server/src/experimental/tasks/index.ts b/packages/server/src/experimental/tasks/index.ts deleted file mode 100644 index 6917fe61af..0000000000 --- a/packages/server/src/experimental/tasks/index.ts +++ /dev/null @@ -1,10 +0,0 @@ -/** - * Experimental task features for MCP SDK. - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - -export * from './interfaces.js'; -export * from './mcpServer.js'; -export * from './server.js'; diff --git a/packages/server/src/experimental/tasks/interfaces.ts b/packages/server/src/experimental/tasks/interfaces.ts deleted file mode 100644 index 2aef91a8c0..0000000000 --- a/packages/server/src/experimental/tasks/interfaces.ts +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Experimental task interfaces for MCP SDK. - * WARNING: These APIs are experimental and may change without notice. - */ - -import type { - CallToolResult, - CreateTaskResult, - CreateTaskServerContext, - GetTaskResult, - Result, - StandardSchemaWithJSON, - TaskServerContext -} from '@modelcontextprotocol/core'; - -import type { BaseToolCallback } from '../../server/mcp.js'; - -// ============================================================================ -// Task Handler Types (for registerToolTask) -// ============================================================================ - -/** - * Handler for creating a task. - * @experimental - */ -export type CreateTaskRequestHandler< - SendResultT extends Result, - Args extends StandardSchemaWithJSON | undefined = undefined -> = BaseToolCallback; - -/** - * Handler for task operations (`get`, `getResult`). - * @experimental - */ -export type TaskRequestHandler = BaseToolCallback< - SendResultT, - TaskServerContext, - Args ->; - -/** - * Interface for task-based tool handlers. - * - * Task-based tools split a long-running operation into three phases: - * `createTask`, `getTask`, and `getTaskResult`. - * - * @see {@linkcode @modelcontextprotocol/server!experimental/tasks/mcpServer.ExperimentalMcpServerTasks#registerToolTask | registerToolTask} for registration. - * @experimental - */ -export interface ToolTaskHandler { - /** - * Called on the initial `tools/call` request. - * - * Creates a task via `ctx.task.store.createTask(...)`, starts any - * background work, and returns the task object. - */ - createTask: CreateTaskRequestHandler; - /** - * Handler for `tasks/get` requests. - */ - getTask: TaskRequestHandler; - /** - * Handler for `tasks/result` requests. - */ - getTaskResult: TaskRequestHandler; -} diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts deleted file mode 100644 index b7c28c40d3..0000000000 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Experimental {@linkcode McpServer} task features for MCP SDK. - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - -import type { StandardSchemaWithJSON, TaskToolExecution, ToolAnnotations, ToolExecution } from '@modelcontextprotocol/core'; - -import type { AnyToolHandler, McpServer, RegisteredTool } from '../../server/mcp.js'; -import type { ToolTaskHandler } from './interfaces.js'; - -/** - * Internal interface for accessing {@linkcode McpServer}'s private _createRegisteredTool method. - * @internal - */ -interface McpServerInternal { - _createRegisteredTool( - name: string, - title: string | undefined, - description: string | undefined, - inputSchema: StandardSchemaWithJSON | undefined, - outputSchema: StandardSchemaWithJSON | undefined, - annotations: ToolAnnotations | undefined, - execution: ToolExecution | undefined, - _meta: Record | undefined, - handler: AnyToolHandler - ): RegisteredTool; -} - -/** - * Experimental task features for {@linkcode McpServer}. - * - * Access via `server.experimental.tasks`: - * ```typescript - * server.experimental.tasks.registerToolTask('long-running', config, handler); - * ``` - * - * @experimental - */ -export class ExperimentalMcpServerTasks { - constructor(private readonly _mcpServer: McpServer) {} - - /** - * Registers a task-based tool with a config object and handler. - * - * Task-based tools support long-running operations that can be polled for status - * and results. The handler must implement {@linkcode ToolTaskHandler.createTask | createTask}, {@linkcode ToolTaskHandler.getTask | getTask}, and {@linkcode ToolTaskHandler.getTaskResult | getTaskResult} - * methods. - * - * @example - * ```typescript - * server.experimental.tasks.registerToolTask('long-computation', { - * description: 'Performs a long computation', - * inputSchema: z.object({ input: z.string() }), - * execution: { taskSupport: 'required' } - * }, { - * createTask: async (args, ctx) => { - * const task = await ctx.task.store.createTask({ ttl: 300000 }); - * startBackgroundWork(task.taskId, args); - * return { task }; - * }, - * getTask: async (args, ctx) => { - * return ctx.task.store.getTask(ctx.task.id); - * }, - * getTaskResult: async (args, ctx) => { - * return ctx.task.store.getTaskResult(ctx.task.id); - * } - * }); - * ``` - * - * @param name - The tool name - * @param config - Tool configuration (description, schemas, etc.) - * @param handler - Task handler with {@linkcode ToolTaskHandler.createTask | createTask}, {@linkcode ToolTaskHandler.getTask | getTask}, {@linkcode ToolTaskHandler.getTaskResult | getTaskResult} methods - * @returns {@linkcode server/mcp.RegisteredTool | RegisteredTool} for managing the tool's lifecycle - * - * @experimental - */ - registerToolTask( - name: string, - config: { - title?: string; - description?: string; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - execution?: TaskToolExecution; - _meta?: Record; - }, - handler: ToolTaskHandler - ): RegisteredTool; - - registerToolTask( - name: string, - config: { - title?: string; - description?: string; - inputSchema: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - execution?: TaskToolExecution; - _meta?: Record; - }, - handler: ToolTaskHandler - ): RegisteredTool; - - registerToolTask( - name: string, - config: { - title?: string; - description?: string; - inputSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - execution?: TaskToolExecution; - _meta?: Record; - }, - handler: ToolTaskHandler - ): RegisteredTool { - // Validate that taskSupport is not 'forbidden' for task-based tools - const execution: ToolExecution = { taskSupport: 'required', ...config.execution }; - if (execution.taskSupport === 'forbidden') { - throw new Error(`Cannot register task-based tool '${name}' with taskSupport 'forbidden'. Use registerTool() instead.`); - } - - // Access McpServer's internal _createRegisteredTool method - const mcpServerInternal = this._mcpServer as unknown as McpServerInternal; - return mcpServerInternal._createRegisteredTool( - name, - config.title, - config.description, - config.inputSchema, - config.outputSchema, - config.annotations, - execution, - config._meta, - handler as AnyToolHandler - ); - } -} diff --git a/packages/server/src/experimental/tasks/server.ts b/packages/server/src/experimental/tasks/server.ts deleted file mode 100644 index 2e7b205fd6..0000000000 --- a/packages/server/src/experimental/tasks/server.ts +++ /dev/null @@ -1,298 +0,0 @@ -/** - * Experimental server task features for MCP SDK. - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - -import type { - AnyObjectSchema, - CancelTaskResult, - CreateMessageRequestParams, - CreateMessageResult, - ElicitRequestFormParams, - ElicitRequestURLParams, - ElicitResult, - GetTaskPayloadResult, - GetTaskResult, - ListTasksResult, - Request, - RequestMethod, - RequestOptions, - ResponseMessage, - ResultTypeMap -} from '@modelcontextprotocol/core'; -import { getResultSchema, GetTaskPayloadResultSchema, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; - -import type { Server } from '../../server/server.js'; - -/** - * Experimental task features for low-level MCP servers. - * - * Access via `server.experimental.tasks`: - * ```typescript - * const stream = server.experimental.tasks.requestStream(request, options); - * ``` - * - * For high-level server usage with task-based tools, use {@linkcode index.McpServer | McpServer}.experimental.tasks instead. - * - * @experimental - */ -export class ExperimentalServerTasks { - constructor(private readonly _server: Server) {} - - private get _module() { - return this._server.taskManager; - } - - /** - * Sends a request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a `'result'` or `'error'` message. - * - * This method provides streaming access to request processing, allowing you to - * observe intermediate task status updates for task-augmented requests. - * - * @param request - The request to send (method name determines the result schema) - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields {@linkcode ResponseMessage} objects - * - * @experimental - */ - requestStream( - request: { method: M; params?: Record }, - options?: RequestOptions - ): AsyncGenerator, void, void> { - const resultSchema = getResultSchema(request.method) as unknown as AnyObjectSchema; - return this._module.requestStream(request as Request, resultSchema, options) as AsyncGenerator< - ResponseMessage, - void, - void - >; - } - - /** - * Sends a sampling request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * For task-augmented requests, yields 'taskCreated' and 'taskStatus' messages - * before the final result. - * - * @example - * ```typescript - * const stream = server.experimental.tasks.createMessageStream({ - * messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - * maxTokens: 100 - * }, { - * onprogress: (progress) => { - * // Handle streaming tokens via progress notifications - * console.log('Progress:', progress.message); - * } - * }); - * - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Task created:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Task status:', message.task.status); - * break; - * case 'result': - * console.log('Final result:', message.result); - * break; - * case 'error': - * console.error('Error:', message.error); - * break; - * } - * } - * ``` - * - * @param params - The sampling request parameters - * @param options - Optional request options (timeout, signal, task creation params, onprogress, etc.) - * @returns AsyncGenerator that yields ResponseMessage objects - * - * @experimental - */ - createMessageStream( - params: CreateMessageRequestParams, - options?: RequestOptions - ): AsyncGenerator, void, void> { - // Access client capabilities via the server - const clientCapabilities = this._server.getClientCapabilities(); - - // Capability check - only required when tools/toolChoice are provided - if ((params.tools || params.toolChoice) && !clientCapabilities?.sampling?.tools) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support sampling tools capability.'); - } - - // Message structure validation - always validate tool_use/tool_result pairs. - // These may appear even without tools/toolChoice in the current request when - // a previous sampling request returned tool_use and this is a follow-up with results. - if (params.messages.length > 0) { - const lastMessage = params.messages.at(-1)!; - const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; - const hasToolResults = lastContent.some(c => c.type === 'tool_result'); - - const previousMessage = params.messages.length > 1 ? params.messages.at(-2) : undefined; - const previousContent = previousMessage - ? Array.isArray(previousMessage.content) - ? previousMessage.content - : [previousMessage.content] - : []; - const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use'); - - if (hasToolResults) { - if (lastContent.some(c => c.type !== 'tool_result')) { - throw new Error('The last message must contain only tool_result content if any is present'); - } - if (!hasPreviousToolUse) { - throw new Error('tool_result blocks are not matching any tool_use from the previous message'); - } - } - if (hasPreviousToolUse) { - const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => c.id)); - const toolResultIds = new Set(lastContent.filter(c => c.type === 'tool_result').map(c => c.toolUseId)); - if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) { - throw new Error('ids of tool_result blocks and tool_use blocks from previous message do not match'); - } - } - } - - return this.requestStream( - { - method: 'sampling/createMessage', - params - }, - options - ) as AsyncGenerator, void, void>; - } - - /** - * Sends an elicitation request and returns an AsyncGenerator that yields response messages. - * The generator is guaranteed to end with either a 'result' or 'error' message. - * - * For task-augmented requests (especially URL-based elicitation), yields 'taskCreated' - * and 'taskStatus' messages before the final result. - * - * @example - * ```typescript - * const stream = server.experimental.tasks.elicitInputStream({ - * mode: 'url', - * message: 'Please authenticate', - * elicitationId: 'auth-123', - * url: 'https://example.com/auth' - * }, { - * task: { ttl: 300000 } // Task-augmented for long-running auth flow - * }); - * - * for await (const message of stream) { - * switch (message.type) { - * case 'taskCreated': - * console.log('Task created:', message.task.taskId); - * break; - * case 'taskStatus': - * console.log('Task status:', message.task.status); - * break; - * case 'result': - * console.log('User action:', message.result.action); - * break; - * case 'error': - * console.error('Error:', message.error); - * break; - * } - * } - * ``` - * - * @param params - The elicitation request parameters - * @param options - Optional request options (timeout, signal, task creation params, etc.) - * @returns AsyncGenerator that yields ResponseMessage objects - * - * @experimental - */ - elicitInputStream( - params: ElicitRequestFormParams | ElicitRequestURLParams, - options?: RequestOptions - ): AsyncGenerator, void, void> { - // Access client capabilities via the server - const clientCapabilities = this._server.getClientCapabilities(); - const mode = params.mode ?? 'form'; - - // Capability check based on mode - switch (mode) { - case 'url': { - if (!clientCapabilities?.elicitation?.url) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support url elicitation.'); - } - break; - } - case 'form': { - if (!clientCapabilities?.elicitation?.form) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support form elicitation.'); - } - break; - } - } - - // Normalize params to ensure mode is set - const normalizedParams = mode === 'form' && params.mode !== 'form' ? { ...params, mode: 'form' } : params; - return this.requestStream( - { - method: 'elicitation/create', - params: normalizedParams - }, - options - ) as AsyncGenerator, void, void>; - } - - /** - * Gets the current status of a task. - * - * @param taskId - The task identifier - * @param options - Optional request options - * @returns The task status - * - * @experimental - */ - async getTask(taskId: string, options?: RequestOptions): Promise { - return this._module.getTask({ taskId }, options); - } - - /** - * Retrieves the result of a completed task. - * - * @param taskId - The task identifier - * @param options - Optional request options - * @returns The task result. The payload structure matches the result type of the - * original request (e.g., a `tools/call` task returns a `CallToolResult`). - * - * @experimental - */ - async getTaskResult(taskId: string, options?: RequestOptions): Promise { - return this._module.getTaskResult({ taskId }, GetTaskPayloadResultSchema, options); - } - - /** - * Lists tasks with optional pagination. - * - * @param cursor - Optional pagination cursor - * @param options - Optional request options - * @returns List of tasks with optional next cursor - * - * @experimental - */ - async listTasks(cursor?: string, options?: RequestOptions): Promise { - return this._module.listTasks(cursor ? { cursor } : undefined, options); - } - - /** - * Cancels a running task. - * - * @param taskId - The task identifier - * @param options - Optional request options - * - * @experimental - */ - async cancelTask(taskId: string, options?: RequestOptions): Promise { - return this._module.cancelTask({ taskId }, options); - } -} diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 95566bbb4d..e738e4b0bd 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -31,6 +31,9 @@ export { Server } from './server/server.js'; // StdioServerTransport is exported from the './stdio' subpath — server stdio has only type-level Node // imports (erased at compile time), but matching the client's `./stdio` subpath gives consumers a // consistent shape across packages. +export type { ModernHandlerOptions } from './server/modernHandler.js'; +export { ModernProtocolHandler } from './server/modernHandler.js'; +export { WebStandardStreamableHTTPServerTransport } from './server/modernStreamableHttp.js'; export type { EventId, EventStore, @@ -38,12 +41,6 @@ export type { StreamId, WebStandardStreamableHTTPServerTransportOptions } from './server/streamableHttp.js'; -export { WebStandardStreamableHTTPServerTransport } from './server/streamableHttp.js'; - -// experimental exports -export type { CreateTaskRequestHandler, TaskRequestHandler, ToolTaskHandler } from './experimental/tasks/interfaces.js'; -export { ExperimentalMcpServerTasks } from './experimental/tasks/mcpServer.js'; -export { ExperimentalServerTasks } from './experimental/tasks/server.js'; // runtime-aware wrapper (shadows core/public's fromJsonSchema with optional validator) export { fromJsonSchema } from './fromJsonSchema.js'; diff --git a/packages/server/src/server/mcp.examples.ts b/packages/server/src/server/mcp.examples.ts index 740c1bf186..414435399f 100644 --- a/packages/server/src/server/mcp.examples.ts +++ b/packages/server/src/server/mcp.examples.ts @@ -11,7 +11,7 @@ import type { CallToolResult } from '@modelcontextprotocol/core'; import * as z from 'zod/v4'; import { McpServer } from './mcp.js'; -import { StdioServerTransport } from './stdio.js'; +import { StdioServerTransport } from './modernStdio.js'; /** * Example: Creating a new McpServer. diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index fb45fd5db6..d56f35a361 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -1,12 +1,9 @@ import type { BaseMetadata, - CallToolRequest, CallToolResult, CompleteRequestPrompt, CompleteRequestResourceTemplate, CompleteResult, - CreateTaskResult, - CreateTaskServerContext, GetPromptResult, Implementation, ListPromptsResult, @@ -23,7 +20,6 @@ import type { StandardSchemaWithJSON, Tool, ToolAnnotations, - ToolExecution, Transport, Variables } from '@modelcontextprotocol/core'; @@ -41,8 +37,6 @@ import { } from '@modelcontextprotocol/core'; import type * as z from 'zod/v4'; -import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; -import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; import { getCompleter, isCompletable } from './completable.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; @@ -72,28 +66,11 @@ export class McpServer { } = {}; private _registeredTools: { [name: string]: RegisteredTool } = {}; private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; - private _experimental?: { tasks: ExperimentalMcpServerTasks }; constructor(serverInfo: Implementation, options?: ServerOptions) { this.server = new Server(serverInfo, options); } - /** - * Access experimental features. - * - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - get experimental(): { tasks: ExperimentalMcpServerTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalMcpServerTasks(this) - }; - } - return this._experimental; - } - /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -147,7 +124,6 @@ export class McpServer { ? (standardSchemaToJsonSchema(tool.inputSchema, 'input') as Tool['inputSchema']) : EMPTY_OBJECT_JSON_SCHEMA, annotations: tool.annotations, - execution: tool.execution, _meta: tool._meta }; @@ -160,7 +136,7 @@ export class McpServer { }) ); - this.server.setRequestHandler('tools/call', async (request, ctx): Promise => { + this.server.setRequestHandler('tools/call', async (request, ctx): Promise => { const tool = this._registeredTools[request.params.name]; if (!tool) { throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Tool ${request.params.name} not found`); @@ -170,41 +146,9 @@ export class McpServer { } try { - const isTaskRequest = !!request.params.task; - const taskSupport = tool.execution?.taskSupport; - const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); - - // Validate task hint configuration - if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { - throw new ProtocolError( - ProtocolErrorCode.InternalError, - `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` - ); - } - - // Handle taskSupport 'required' without task augmentation - if (taskSupport === 'required' && !isTaskRequest) { - throw new ProtocolError( - ProtocolErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` - ); - } - - // Handle taskSupport 'optional' without task augmentation - automatic polling - if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, ctx); - } - - // Normal execution path const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const result = await this.executeToolHandler(tool, args, ctx); - // Return CreateTaskResult immediately for task requests - if (isTaskRequest) { - return result; - } - - // Validate output schema for non-task requests await this.validateToolOutput(tool, result, request.params.name); return result; } catch (error) { @@ -265,16 +209,11 @@ export class McpServer { /** * Validates tool output against the tool's output schema. */ - private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { + private async validateToolOutput(tool: RegisteredTool, result: CallToolResult, toolName: string): Promise { if (!tool.outputSchema) { return; } - // Only validate CallToolResult, not CreateTaskResult - if (!('content' in result)) { - return; - } - if (result.isError) { return; } @@ -297,47 +236,13 @@ export class McpServer { } /** - * Executes a tool handler (either regular or task-based). + * Executes a tool handler. */ - private async executeToolHandler(tool: RegisteredTool, args: unknown, ctx: ServerContext): Promise { + private async executeToolHandler(tool: RegisteredTool, args: unknown, ctx: ServerContext): Promise { // Executor encapsulates handler invocation with proper types return tool.executor(args, ctx); } - /** - * Handles automatic task polling for tools with `taskSupport` `'optional'`. - */ - private async handleAutomaticTaskPolling( - tool: RegisteredTool, - request: RequestT, - ctx: ServerContext - ): Promise { - if (!ctx.task?.store) { - throw new Error('No task store provided for task-capable tool.'); - } - - // Validate input and create task using the executor - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const createTaskResult = (await tool.executor(args, ctx)) as CreateTaskResult; - - // Poll until completion - const taskId = createTaskResult.task.taskId; - let task = createTaskResult.task; - const pollInterval = task.pollInterval ?? 5000; - - while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { - await new Promise(resolve => setTimeout(resolve, pollInterval)); - const updatedTask = await ctx.task.store.getTask(taskId); - if (!updatedTask) { - throw new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} not found during polling`); - } - task = updatedTask; - } - - // Return the final result - return (await ctx.task.store.getTaskResult(taskId)) as CallToolResult; - } - private _completionHandlerInitialized = false; private setCompletionRequestHandler() { @@ -773,15 +678,14 @@ export class McpServer { inputSchema: StandardSchemaWithJSON | undefined, outputSchema: StandardSchemaWithJSON | undefined, annotations: ToolAnnotations | undefined, - execution: ToolExecution | undefined, _meta: Record | undefined, - handler: AnyToolHandler + handler: ToolCallback ): RegisteredTool { // Validate tool name according to SEP specification validateAndWarnToolName(name); // Track current handler for executor regeneration - let currentHandler = handler; + let currentHandler: ToolCallback = handler; const registeredTool: RegisteredTool = { title, @@ -789,7 +693,6 @@ export class McpServer { inputSchema, outputSchema, annotations, - execution, _meta, handler: handler, executor: createToolExecutor(inputSchema, handler), @@ -816,7 +719,7 @@ export class McpServer { } if (updates.callback !== undefined) { registeredTool.handler = updates.callback; - currentHandler = updates.callback as AnyToolHandler; + currentHandler = updates.callback as ToolCallback; needsExecutorRegen = true; } if (needsExecutorRegen) { @@ -914,7 +817,6 @@ export class McpServer { normalizeRawShapeSchema(inputSchema), normalizeRawShapeSchema(outputSchema), annotations, - { taskSupport: 'forbidden' }, _meta, cb as ToolCallback ); @@ -1148,14 +1050,14 @@ export type ToolCallback; /** - * Supertype that can handle both regular tools (simple callback) and task-based tools (task handler object). + * Supertype for tool handlers. */ -export type AnyToolHandler = ToolCallback | ToolTaskHandler; +export type AnyToolHandler = ToolCallback; /** * Internal executor type that encapsulates handler invocation with proper types. */ -type ToolExecutor = (args: unknown, ctx: ServerContext) => Promise; +type ToolExecutor = (args: unknown, ctx: ServerContext) => Promise; export type RegisteredTool = { title?: string; @@ -1163,9 +1065,8 @@ export type RegisteredTool = { inputSchema?: StandardSchemaWithJSON; outputSchema?: StandardSchemaWithJSON; annotations?: ToolAnnotations; - execution?: ToolExecution; _meta?: Record; - handler: AnyToolHandler; + handler: ToolCallback; /** @hidden */ executor: ToolExecutor; enabled: boolean; @@ -1192,25 +1093,8 @@ export type RegisteredTool = { */ function createToolExecutor( inputSchema: StandardSchemaWithJSON | undefined, - handler: AnyToolHandler + handler: ToolCallback ): ToolExecutor { - const isTaskHandler = 'createTask' in handler; - - if (isTaskHandler) { - const taskHandler = handler as TaskHandlerInternal; - return async (args, ctx) => { - if (!ctx.task?.store) { - throw new Error('No task store provided.'); - } - const taskCtx: CreateTaskServerContext = { ...ctx, task: { store: ctx.task.store, requestedTtl: ctx.task?.requestedTtl } }; - if (inputSchema) { - return taskHandler.createTask(args, taskCtx); - } - // When no inputSchema, call with just ctx (the handler expects (ctx) signature) - return (taskHandler.createTask as (ctx: CreateTaskServerContext) => CreateTaskResult | Promise)(taskCtx); - }; - } - if (inputSchema) { const callback = handler as ToolCallbackInternal; return async (args, ctx) => callback(args, ctx); @@ -1300,10 +1184,6 @@ type PromptHandler = (args: Record | undefined, ctx: ServerCont type ToolCallbackInternal = (args: unknown, ctx: ServerContext) => CallToolResult | Promise; -type TaskHandlerInternal = { - createTask: (args: unknown, ctx: CreateTaskServerContext) => CreateTaskResult | Promise; -}; - export type RegisteredPrompt = { title?: string; description?: string; diff --git a/packages/server/src/server/modernHandler.ts b/packages/server/src/server/modernHandler.ts new file mode 100644 index 0000000000..ee1488862b --- /dev/null +++ b/packages/server/src/server/modernHandler.ts @@ -0,0 +1,123 @@ +import type { + AuthInfo, + Implementation, + JSONRPCErrorResponse, + JSONRPCRequest, + JSONRPCResponse, + Result, + ServerCapabilities, + ServerContext +} from '@modelcontextprotocol/core'; +import { ProtocolErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; + +export interface ModernHandlerOptions { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + requestHandlers: ReadonlyMap Promise>; + serverInfo: Implementation; + capabilities: ServerCapabilities; + instructions?: string; +} + +export class ModernProtocolHandler { + constructor(private options: ModernHandlerOptions) {} + + async handleRequest( + request: JSONRPCRequest, + extra?: { authInfo?: AuthInfo; request?: globalThis.Request } + ): Promise { + const method = request.method; + + if (method === 'server/discover') { + return this.handleDiscover(request); + } + + const handler = this.options.requestHandlers.get(method); + if (!handler) { + return this.jsonRpcError(request.id, ProtocolErrorCode.MethodNotFound, `Method not found: ${method}`); + } + + const meta = request.params?._meta; + if (!meta?.protocolVersion) { + return this.jsonRpcError(request.id, ProtocolErrorCode.InvalidRequest, 'Missing _meta.protocolVersion'); + } + + const ctx = this.buildContext(request, extra); + try { + const result = await handler(request, ctx); + return { + jsonrpc: '2.0', + id: request.id, + result: { ...result, result_type: 'complete' } + }; + } catch (error: unknown) { + const err = error as Record; + return { + jsonrpc: '2.0', + id: request.id, + error: { + code: Number.isSafeInteger(err['code']) ? (err['code'] as number) : ProtocolErrorCode.InternalError, + message: (err as unknown as Error).message ?? 'Internal error', + ...(err['data'] !== undefined && { data: err['data'] }) + } + }; + } + } + + private handleDiscover(request: JSONRPCRequest): JSONRPCResponse { + return { + jsonrpc: '2.0', + id: request.id, + result: { + supportedVersions: ['2026-06-30'], + capabilities: this.options.capabilities, + serverInfo: this.options.serverInfo, + ...(this.options.instructions && { instructions: this.options.instructions }) + } + }; + } + + private buildContext(request: JSONRPCRequest, extra?: { authInfo?: AuthInfo; request?: globalThis.Request }): ServerContext { + const abortController = new AbortController(); + return { + sessionId: undefined, + mcpReq: { + id: request.id, + method: request.method, + _meta: request.params?._meta, + signal: abortController.signal, + send: (async () => { + throw new SdkError( + SdkErrorCode.UnsupportedOperation, + 'Server-to-client requests are not supported on the stateless 2026-06 path' + ); + }) as ServerContext['mcpReq']['send'], + notify: async () => { + /* no-op: notifications deferred on modern path */ + }, + log: async () => { + /* no-op: in-band logging deferred on modern path */ + }, + elicitInput: async () => { + throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Elicitation is not supported on the stateless 2026-06 path'); + }, + requestSampling: async () => { + throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Sampling is not supported on the stateless 2026-06 path'); + } + }, + http: extra + ? { + authInfo: extra.authInfo, + req: extra.request + } + : undefined + }; + } + + private jsonRpcError(id: JSONRPCRequest['id'], code: number, message: string): JSONRPCErrorResponse { + return { + jsonrpc: '2.0', + id, + error: { code, message } + }; + } +} diff --git a/packages/server/src/server/modernStdio.ts b/packages/server/src/server/modernStdio.ts new file mode 100644 index 0000000000..fa2d0c5c57 --- /dev/null +++ b/packages/server/src/server/modernStdio.ts @@ -0,0 +1,188 @@ +import type { Readable, Writable } from 'node:stream'; + +import type { JSONRPCMessage, ProtocolConfig, Transport, TransportSendOptions } from '@modelcontextprotocol/core'; +import { isJSONRPCRequest, ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core'; + +import { ModernProtocolHandler } from './modernHandler.js'; +import { LegacyServer } from './server.js'; +import { LegacyStdioServerTransport } from './stdio.js'; + +type ProtocolGeneration = 'legacy' | 'modern'; + +class VirtualStdioTransport implements Transport { + onmessage?: Transport['onmessage']; + onclose?: () => void; + onerror?: (error: Error) => void; + sessionId?: string; + + constructor(private _realSend: (msg: JSONRPCMessage) => Promise) {} + + async start(): Promise { + // No-op — the real I/O transport is already started + } + + async send(message: JSONRPCMessage, _options?: TransportSendOptions): Promise { + return this._realSend(message); + } + + async close(): Promise { + this.onclose?.(); + } + + pushMessage(msg: JSONRPCMessage): void { + this.onmessage?.(msg); + } +} + +/** + * Dual-protocol stdio server transport with automatic version detection. + * + * Detects the client's protocol version from the first message and locks + * for the connection lifetime. Modern clients (2026-06) are dispatched to + * a stateless ModernProtocolHandler. Legacy clients (2025-11) get a full + * LegacyServer connected via a VirtualStdioTransport adapter. + * + * The routing transport always owns inner.onmessage — both paths go + * through _routeMessage(). This ensures symmetric behavior and prevents + * race conditions. + * + * Drop-in replacement for LegacyStdioServerTransport with no changes to + * McpServer, Server, or tool handlers. + */ +export class StdioServerTransport implements Transport { + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: Transport['onmessage']; + sessionId?: string; + + private _inner: LegacyStdioServerTransport; + private _protocolConfig?: ProtocolConfig; + private _modernHandler?: ModernProtocolHandler; + private _legacyServer?: LegacyServer; + private _virtualTransport?: VirtualStdioTransport; + private _lockedMode: ProtocolGeneration | null = null; + private _modernQueue: Promise = Promise.resolve(); + + constructor(stdin?: Readable, stdout?: Writable) { + this._inner = new LegacyStdioServerTransport(stdin, stdout); + } + + setProtocolConfig(config: ProtocolConfig): void { + this._protocolConfig = config; + this._modernHandler = new ModernProtocolHandler({ + requestHandlers: config.requestHandlers, + serverInfo: config.serverInfo!, + capabilities: config.capabilities!, + instructions: config.instructions + }); + } + + async start(): Promise { + this._inner.onerror = error => this.onerror?.(error); + this._inner.onclose = () => this.onclose?.(); + this._inner.onmessage = msg => this._routeMessage(msg); + await this._inner.start(); + } + + async send(message: JSONRPCMessage, _options?: TransportSendOptions): Promise { + return this._inner.send(message); + } + + async close(): Promise { + if (this._legacyServer) { + await this._legacyServer.close(); + } + await this._inner.close(); + } + + // ------------------------------------------------------------------- + // Version detection and routing + // ------------------------------------------------------------------- + + private _routeMessage(msg: JSONRPCMessage): void { + if (this._lockedMode === null) { + this._lockedMode = this._detectVersion(msg); + if (this._lockedMode === 'legacy') { + this._initLegacyPath(); + } + } + + if (this._lockedMode === 'modern') { + this._handleModernMessage(msg); + } else { + this._virtualTransport!.pushMessage(msg); + } + } + + private _detectVersion(msg: JSONRPCMessage): ProtocolGeneration { + if (!isJSONRPCRequest(msg)) { + return 'legacy'; + } + if (msg.method === 'initialize') { + return 'legacy'; + } + if (msg.method === 'server/discover') { + return 'modern'; + } + if ( + (msg.params as Record | undefined)?._meta && + ((msg.params as Record)._meta as Record)?.protocolVersion + ) { + return 'modern'; + } + return 'legacy'; + } + + /** + * Synchronous initialization of the legacy path. + * + * Protocol.connect() sets virtualTransport.onmessage synchronously + * at the start of connect(), so pushMessage() works immediately — + * even though connect() itself is async (its awaited start() is a no-op). + */ + private _initLegacyPath(): void { + const config = this._protocolConfig!; + + this._virtualTransport = new VirtualStdioTransport(msg => this._inner.send(msg)); + + this._legacyServer = config.createServer + ? (config.createServer() as LegacyServer) + : new LegacyServer(config.serverInfo!, { + capabilities: config.capabilities + }); + + this._legacyServer.fallbackRequestHandler = async (request, ctx) => { + const handler = config.requestHandlers.get(request.method); + if (!handler) { + throw new ProtocolError(ProtocolErrorCode.MethodNotFound, `Method not found: ${request.method}`); + } + return handler(request, ctx); + }; + + this._legacyServer.connect(this._virtualTransport).catch(error => { + this.onerror?.(error instanceof Error ? error : new Error(String(error))); + }); + } + + /** + * Dispatches a modern-path message. Requests go to ModernProtocolHandler; + * responses are written to stdout via inner.send(). + * + * Processing is serialized to prevent interleaved stdout writes from + * concurrent async handlers. + */ + private _handleModernMessage(msg: JSONRPCMessage): void { + this._modernQueue = this._modernQueue + .then(async () => { + if (!this._modernHandler) return; + + if (isJSONRPCRequest(msg)) { + const response = await this._modernHandler.handleRequest(msg); + await this._inner.send(response); + } + }) + .catch(error => { + this.onerror?.(error instanceof Error ? error : new Error(String(error))); + }); + } +} diff --git a/packages/server/src/server/modernStreamableHttp.ts b/packages/server/src/server/modernStreamableHttp.ts new file mode 100644 index 0000000000..32ed3a4c10 --- /dev/null +++ b/packages/server/src/server/modernStreamableHttp.ts @@ -0,0 +1,221 @@ +import type { + AuthInfo, + JSONRPCMessage, + ProtocolConfig, + RequestId, + ServerCapabilities, + Transport, + TransportSendOptions +} from '@modelcontextprotocol/core'; +import { isJSONRPCRequest, JSONRPCMessageSchema, ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core'; + +import { ModernProtocolHandler } from './modernHandler.js'; +import { LegacyServer } from './server.js'; +import type { HandleRequestOptions, WebStandardStreamableHTTPServerTransportOptions } from './streamableHttp.js'; +import { LegacyWebStandardStreamableHTTPServerTransport } from './streamableHttp.js'; + +interface LegacySessionEntry { + transport: LegacyWebStandardStreamableHTTPServerTransport; + server: LegacyServer; +} + +/** + * Dual-protocol HTTP server transport that transparently serves both legacy (2025-11) + * and modern (2026-06) MCP clients on a single endpoint. + * + * Modern clients are detected via the `Mcp-Method` header and dispatched to a stateless + * handler. Legacy clients are routed to per-session transport stacks. + * + * Accepts the same options as the legacy transport — all legacy-specific options + * (sessionIdGenerator, eventStore, etc.) are forwarded to per-session legacy stacks. + */ +export class WebStandardStreamableHTTPServerTransport implements Transport { + onmessage?: Transport['onmessage']; + onclose?: Transport['onclose']; + onerror?: Transport['onerror']; + sessionId?: string; + + private protocolConfig?: ProtocolConfig; + private modernHandler?: ModernProtocolHandler; + private legacySessions = new Map(); + private options: WebStandardStreamableHTTPServerTransportOptions; + + constructor(options?: WebStandardStreamableHTTPServerTransportOptions) { + this.options = options ?? {}; + } + + setProtocolConfig(config: ProtocolConfig): void { + this.protocolConfig = config; + this.modernHandler = new ModernProtocolHandler({ + requestHandlers: config.requestHandlers, + serverInfo: config.serverInfo!, + capabilities: config.capabilities!, + instructions: config.instructions + }); + } + + async start(): Promise { + // Nothing to do — we handle requests on demand + } + + async close(): Promise { + for (const [id, entry] of this.legacySessions) { + await entry.server.close(); + this.legacySessions.delete(id); + } + } + + async send(_message: JSONRPCMessage, _options?: TransportSendOptions): Promise { + throw new Error( + 'WebStandardStreamableHTTPServerTransport.send() should never be called. ' + + 'All dispatch goes through ModernProtocolHandler or per-session legacy transports.' + ); + } + + closeSSEStream(requestId: RequestId): void { + for (const entry of this.legacySessions.values()) { + entry.transport.closeSSEStream(requestId); + } + } + + closeStandaloneSSEStream(): void { + for (const entry of this.legacySessions.values()) { + entry.transport.closeStandaloneSSEStream(); + } + } + + async handleRequest(req: Request, options?: HandleRequestOptions): Promise { + return this.isStatelessProtocolRequest(req) ? this.handleModernRequest(req, options) : this.handleLegacyRequest(req, options); + } + + private isStatelessProtocolRequest(req: Request): boolean { + return !this.options.forceLegacy && req.headers.has('mcp-method'); + } + + private async handleModernRequest(req: Request, options?: HandleRequestOptions): Promise { + if (!this.modernHandler) { + return this.jsonErrorResponse(500, ProtocolErrorCode.InternalError, 'Modern handler not initialized'); + } + + if (req.method !== 'POST') { + return new Response(null, { status: 405, headers: { Allow: 'POST' } }); + } + + const ct = req.headers.get('content-type'); + if (!ct || !ct.includes('application/json')) { + return this.jsonErrorResponse(415, -32_000, 'Unsupported Media Type: expected application/json'); + } + + let rawMessage: unknown; + if (options?.parsedBody === undefined) { + try { + rawMessage = await req.json(); + } catch { + return this.jsonErrorResponse(400, -32_700, 'Parse error: Invalid JSON'); + } + } else { + rawMessage = options.parsedBody; + } + + if (Array.isArray(rawMessage)) { + return this.jsonErrorResponse(400, -32_600, 'Batch requests not supported on 2026-06 path'); + } + + let message; + try { + message = JSONRPCMessageSchema.parse(rawMessage); + } catch { + return this.jsonErrorResponse(400, -32_700, 'Parse error: Invalid JSON-RPC message'); + } + + if (!isJSONRPCRequest(message)) { + return this.jsonErrorResponse(400, -32_600, 'Expected JSON-RPC request'); + } + + const authInfo: AuthInfo | undefined = options?.authInfo; + const response = await this.modernHandler.handleRequest(message, { + authInfo, + request: req + }); + + return Response.json(response, { + status: 200, + headers: { 'Content-Type': 'application/json' } + }); + } + + private async handleLegacyRequest(req: Request, options?: HandleRequestOptions): Promise { + const sessionId = req.headers.get('mcp-session-id'); + + if (sessionId) { + const entry = this.legacySessions.get(sessionId); + if (!entry) { + return this.jsonErrorResponse(404, -32_000, 'Session not found'); + } + return entry.transport.handleRequest(req, options); + } + + if (req.method === 'POST') { + return this.handleLegacyInitialize(req, options); + } + + return this.jsonErrorResponse(400, -32_600, 'Missing Mcp-Session-Id header'); + } + + private async handleLegacyInitialize(req: Request, options?: HandleRequestOptions): Promise { + const innerServer = this.protocolConfig!.createServer + ? (this.protocolConfig!.createServer() as LegacyServer) + : new LegacyServer(this.protocolConfig!.serverInfo!, { + capabilities: this.protocolConfig!.capabilities as ServerCapabilities, + instructions: this.protocolConfig!.instructions + }); + + innerServer.fallbackRequestHandler = async (request, ctx) => { + const handler = this.protocolConfig!.requestHandlers.get(request.method); + if (!handler) { + throw new ProtocolError(ProtocolErrorCode.MethodNotFound, `Method not found: ${request.method}`); + } + return handler(request, ctx); + }; + + const transportOptions: WebStandardStreamableHTTPServerTransportOptions = { + sessionIdGenerator: this.options.sessionIdGenerator ?? (() => crypto.randomUUID()), + onsessioninitialized: async (sid: string) => { + this.legacySessions.set(sid, { transport: innerTransport, server: innerServer }); + await this.options.onsessioninitialized?.(sid); + }, + onsessionclosed: this.options.onsessionclosed, + enableJsonResponse: this.options.enableJsonResponse, + eventStore: this.options.eventStore, + allowedHosts: this.options.allowedHosts, + allowedOrigins: this.options.allowedOrigins, + enableDnsRebindingProtection: this.options.enableDnsRebindingProtection, + retryInterval: this.options.retryInterval, + supportedProtocolVersions: this.options.supportedProtocolVersions + }; + + const innerTransport = new LegacyWebStandardStreamableHTTPServerTransport(transportOptions); + + innerTransport.onclose = () => { + const sid = innerTransport.sessionId; + if (sid) this.legacySessions.delete(sid); + }; + + await innerServer.connect(innerTransport); + return innerTransport.handleRequest(req, options); + } + + private jsonErrorResponse(httpStatus: number, code: number, message: string): Response { + return Response.json( + { + jsonrpc: '2.0', + error: { code, message }, + id: null + }, + { + status: httpStatus, + headers: { 'Content-Type': 'application/json' } + } + ); + } +} diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index f6a34f02da..b22870e0b3 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -19,34 +19,35 @@ import type { LoggingLevel, LoggingMessageNotification, MessageExtraInfo, + Notification, NotificationMethod, NotificationOptions, + NotificationTypeMap, + ProtocolConfig, ProtocolOptions, RequestMethod, RequestOptions, + RequestTypeMap, ResourceUpdatedNotification, Result, + ResultTypeMap, ServerCapabilities, ServerContext, - TaskManagerOptions, + StandardSchemaV1, ToolResultContent, - ToolUseContent + ToolUseContent, + Transport } from '@modelcontextprotocol/core'; import { - assertClientRequestTaskCapability, - assertToolsCallTaskCapability, - CallToolRequestSchema, CallToolResultSchema, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, - CreateTaskResultSchema, ElicitResultSchema, EmptyResultSchema, - extractTaskManagerOptions, + HandlerRegistry, LATEST_PROTOCOL_VERSION, ListRootsResultSchema, LoggingLevelSchema, - mergeCapabilities, parseSchema, Protocol, ProtocolError, @@ -56,89 +57,154 @@ import { } from '@modelcontextprotocol/core'; import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/server/_shims'; -import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; - -/** - * Extended tasks capability that includes runtime configuration (store, messageQueue). - * The runtime-only fields are stripped before advertising capabilities to clients. - */ -export type ServerTasksCapabilityWithRuntime = NonNullable & TaskManagerOptions; - export type ServerOptions = ProtocolOptions & { - /** - * Capabilities to advertise as being supported by this server. - */ - capabilities?: Omit & { - tasks?: ServerTasksCapabilityWithRuntime; - }; - - /** - * Optional instructions describing how to use the server and its features. - */ + capabilities?: ServerCapabilities; instructions?: string; - + jsonSchemaValidator?: jsonSchemaValidator; /** - * JSON Schema validator for elicitation response validation. - * - * The validator is used to validate user input returned from elicitation - * requests against the requested schema. - * - * @default {@linkcode DefaultJsonSchemaValidator} ({@linkcode index.AjvJsonSchemaValidator | AjvJsonSchemaValidator} on Node.js, `CfWorkerJsonSchemaValidator` on Cloudflare Workers) + * Optional pre-built HandlerRegistry. When supplied (e.g., by Server wrapper), + * LegacyServer will use this registry instead of creating its own. + * @internal */ - jsonSchemaValidator?: jsonSchemaValidator; + registry?: HandlerRegistry; }; +// --------------------------------------------------------------------------- +// Standalone functions extracted from LegacyServer for use as callbacks +// --------------------------------------------------------------------------- + +function assertServerHandlerCapability(method: string, capabilities: ServerCapabilities): void { + switch (method) { + case 'completion/complete': { + if (!capabilities.completions) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support completions (required for ${method})`); + } + break; + } + + case 'logging/setLevel': { + if (!capabilities.logging) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); + } + break; + } + + case 'prompts/get': + case 'prompts/list': { + if (!capabilities.prompts) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support prompts (required for ${method})`); + } + break; + } + + case 'resources/list': + case 'resources/templates/list': + case 'resources/read': { + if (!capabilities.resources) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support resources (required for ${method})`); + } + break; + } + + case 'tools/call': + case 'tools/list': { + if (!capabilities.tools) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support tools (required for ${method})`); + } + break; + } + + case 'ping': + case 'initialize': { + break; + } + } +} + +function serverWrapHandler( + method: string, + handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise +): (request: JSONRPCRequest, ctx: ServerContext) => Promise { + if (method !== 'tools/call') { + return handler; + } + return async (request, ctx) => { + const result = await handler(request, ctx); + + const validationResult = parseSchema(CallToolResultSchema, result); + if (!validationResult.success) { + const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); + } + + return validationResult.data; + }; +} + /** - * An MCP server on top of a pluggable transport. - * - * This server will automatically respond to the initialization flow as initiated from the client. + * Creates a server HandlerRegistry with server-specific callbacks. + * @internal + */ +export function createServerRegistry(capabilities?: ServerCapabilities): HandlerRegistry { + const registry = new HandlerRegistry({ + capabilities, + assertRequestHandlerCapability: method => assertServerHandlerCapability(method, registry.getCapabilities()), + wrapHandler: serverWrapHandler + }); + return registry; +} + +/** + * The Protocol-based MCP server implementation. Handles JSON-RPC dispatch, + * request/response correlation, and bidirectional session management. * - * @deprecated Use {@linkcode server/mcp.McpServer | McpServer} instead for the high-level API. Only use `Server` for advanced use cases. + * Used internally by {@linkcode Server} for legacy transport connections and + * by the routing transport for per-session legacy stacks. */ -export class Server extends Protocol { +export class LegacyServer extends Protocol { private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; - private _capabilities: ServerCapabilities; private _instructions?: string; + private _serverInfo: Implementation; private _jsonSchemaValidator: jsonSchemaValidator; - private _experimental?: { tasks: ExperimentalServerTasks }; - /** - * Callback for when initialization has fully completed (i.e., the client has sent an `notifications/initialized` notification). - */ oninitialized?: () => void; - /** - * Initializes this server with the given name and version information. - */ - constructor( - private _serverInfo: Implementation, - options?: ServerOptions - ) { - super({ - ...options, - tasks: extractTaskManagerOptions(options?.capabilities?.tasks) - }); - this._capabilities = options?.capabilities ? { ...options.capabilities } : {}; + constructor(serverInfo: Implementation, options?: ServerOptions) { + const registry = options?.registry ?? createServerRegistry(options?.capabilities); + super(registry, options); + this._serverInfo = serverInfo; this._instructions = options?.instructions; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); - // Strip runtime-only fields from advertised capabilities - if (options?.capabilities?.tasks) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize, ...wireCapabilities } = - options.capabilities.tasks; - this._capabilities.tasks = wireCapabilities; + // Only register default handlers if they haven't been registered already + // (e.g., the Server wrapper may have pre-populated the shared registry) + if (!this._registry.requestHandlers.has('initialize')) { + this.setRequestHandler('initialize', request => this._oninitialize(request)); + } + if (!this._registry.notificationHandlers.has('notifications/initialized')) { + this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); } - this.setRequestHandler('initialize', request => this._oninitialize(request)); - this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); - - if (this._capabilities.logging) { + if (this._registry.getCapabilities().logging) { this._registerLoggingHandler(); } } + getProtocolConfig(): ProtocolConfig { + return { + requestHandlers: this._registry.requestHandlers, + serverInfo: this._serverInfo, + capabilities: this._registry.getCapabilities(), + instructions: this._instructions, + createServer: () => + new LegacyServer(this._serverInfo, { + capabilities: this._registry.getCapabilities(), + instructions: this._instructions + }) + }; + } + private _registerLoggingHandler(): void { this.setRequestHandler('logging/setLevel', async (request, ctx) => { const transportSessionId: string | undefined = @@ -153,7 +219,6 @@ export class Server extends Protocol { } protected override buildContext(ctx: BaseContext, transportInfo?: MessageExtraInfo): ServerContext { - // Only create http when there's actual HTTP transport info or auth info const hasHttpInfo = ctx.http || transportInfo?.request || transportInfo?.closeSSEStream || transportInfo?.closeStandaloneSSEStream; return { ...ctx, @@ -174,98 +239,25 @@ export class Server extends Protocol { }; } - /** - * Access experimental features. - * - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - get experimental(): { tasks: ExperimentalServerTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalServerTasks(this) - }; - } - return this._experimental; - } - - // Map log levels by session id private _loggingLevels = new Map(); - - // Map LogLevelSchema to severity index private readonly LOG_LEVEL_SEVERITY = new Map(LoggingLevelSchema.options.map((level, index) => [level, index])); - // Is a message with the given level ignored in the log level set for the given session id? private isMessageIgnored = (level: LoggingLevel, sessionId?: string): boolean => { const currentLevel = this._loggingLevels.get(sessionId); return currentLevel ? this.LOG_LEVEL_SEVERITY.get(level)! < this.LOG_LEVEL_SEVERITY.get(currentLevel)! : false; }; - /** - * Registers new capabilities. This can only be called before connecting to a transport. - * - * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). - */ public registerCapabilities(capabilities: ServerCapabilities): void { if (this.transport) { throw new SdkError(SdkErrorCode.AlreadyConnected, 'Cannot register capabilities after connecting to transport'); } - const hadLogging = !!this._capabilities.logging; - this._capabilities = mergeCapabilities(this._capabilities, capabilities); - if (!hadLogging && this._capabilities.logging) { + const hadLogging = !!this._registry.getCapabilities().logging; + this._registry.registerCapabilities(capabilities); + if (!hadLogging && this._registry.getCapabilities().logging) { this._registerLoggingHandler(); } } - /** - * Enforces server-side validation for `tools/call` results regardless of how the - * handler was registered. - */ - protected override _wrapHandler( - method: string, - handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise - ): (request: JSONRPCRequest, ctx: ServerContext) => Promise { - if (method !== 'tools/call') { - return handler; - } - return async (request, ctx) => { - const validatedRequest = parseSchema(CallToolRequestSchema, request); - if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); - } - - const { params } = validatedRequest.data; - - const result = await handler(request, ctx); - - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } - - // For non-task requests, validate against CallToolResultSchema - const validationResult = parseSchema(CallToolResultSchema, result); - if (!validationResult.success) { - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); - } - - return validationResult.data; - }; - } - protected assertCapabilityForMethod(method: RequestMethod | string): void { switch (method) { case 'sampling/createMessage': { @@ -293,7 +285,6 @@ export class Server extends Protocol { } case 'ping': { - // No specific capability required for ping break; } } @@ -302,7 +293,7 @@ export class Server extends Protocol { protected assertNotificationCapability(method: NotificationMethod | string): void { switch (method) { case 'notifications/message': { - if (!this._capabilities.logging) { + if (!this._registry.getCapabilities().logging) { throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); } break; @@ -310,7 +301,7 @@ export class Server extends Protocol { case 'notifications/resources/updated': case 'notifications/resources/list_changed': { - if (!this._capabilities.resources) { + if (!this._registry.getCapabilities().resources) { throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Server does not support notifying about resources (required for ${method})` @@ -320,7 +311,7 @@ export class Server extends Protocol { } case 'notifications/tools/list_changed': { - if (!this._capabilities.tools) { + if (!this._registry.getCapabilities().tools) { throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Server does not support notifying of tool list changes (required for ${method})` @@ -330,7 +321,7 @@ export class Server extends Protocol { } case 'notifications/prompts/list_changed': { - if (!this._capabilities.prompts) { + if (!this._registry.getCapabilities().prompts) { throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Server does not support notifying of prompt list changes (required for ${method})` @@ -349,75 +340,13 @@ export class Server extends Protocol { break; } - case 'notifications/cancelled': { - // Cancellation notifications are always allowed - break; - } - + case 'notifications/cancelled': case 'notifications/progress': { - // Progress notifications are always allowed - break; - } - } - } - - protected assertRequestHandlerCapability(method: string): void { - switch (method) { - case 'completion/complete': { - if (!this._capabilities.completions) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support completions (required for ${method})`); - } - break; - } - - case 'logging/setLevel': { - if (!this._capabilities.logging) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); - } - break; - } - - case 'prompts/get': - case 'prompts/list': { - if (!this._capabilities.prompts) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support prompts (required for ${method})`); - } - break; - } - - case 'resources/list': - case 'resources/templates/list': - case 'resources/read': { - if (!this._capabilities.resources) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support resources (required for ${method})`); - } - break; - } - - case 'tools/call': - case 'tools/list': { - if (!this._capabilities.tools) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support tools (required for ${method})`); - } - break; - } - - case 'ping': - case 'initialize': { - // No specific capability required for these methods break; } } } - protected assertTaskCapability(method: string): void { - assertClientRequestTaskCapability(this._clientCapabilities?.tasks?.requests, method, 'Client'); - } - - protected assertTaskHandlerCapability(method: string): void { - assertToolsCallTaskCapability(this._capabilities?.tasks?.requests, method, 'Server'); - } - private async _oninitialize(request: InitializeRequest): Promise { const requestedVersion = request.params.protocolVersion; @@ -438,65 +367,36 @@ export class Server extends Protocol { }; } - /** - * After initialization has completed, this will be populated with the client's reported capabilities. - */ getClientCapabilities(): ClientCapabilities | undefined { return this._clientCapabilities; } - /** - * After initialization has completed, this will be populated with information about the client's name and version. - */ getClientVersion(): Implementation | undefined { return this._clientVersion; } - /** - * Returns the current server capabilities. - */ public getCapabilities(): ServerCapabilities { - return this._capabilities; + return this._registry.getCapabilities(); } async ping() { return this._requestWithSchema({ method: 'ping' }, EmptyResultSchema); } - /** - * Request LLM sampling from the client (without tools). - * Returns single content block for backwards compatibility. - */ async createMessage(params: CreateMessageRequestParamsBase, options?: RequestOptions): Promise; - - /** - * Request LLM sampling from the client with tool support. - * Returns content that may be a single block or array (for parallel tool calls). - */ async createMessage(params: CreateMessageRequestParamsWithTools, options?: RequestOptions): Promise; - - /** - * Request LLM sampling from the client. - * When tools may or may not be present, returns the union type. - */ async createMessage( params: CreateMessageRequest['params'], options?: RequestOptions ): Promise; - - // Implementation async createMessage( params: CreateMessageRequest['params'], options?: RequestOptions ): Promise { - // Capability check - only required when tools/toolChoice are provided if ((params.tools || params.toolChoice) && !this._clientCapabilities?.sampling?.tools) { throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support sampling tools capability.'); } - // Message structure validation - always validate tool_use/tool_result pairs. - // These may appear even without tools/toolChoice in the current request when - // a previous sampling request returned tool_use and this is a follow-up with results. if (params.messages.length > 0) { const lastMessage = params.messages.at(-1)!; const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; @@ -538,20 +438,12 @@ export class Server extends Protocol { } } - // Use different schemas based on whether tools are provided if (params.tools) { return this._requestWithSchema({ method: 'sampling/createMessage', params }, CreateMessageResultWithToolsSchema, options); } return this._requestWithSchema({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options); } - /** - * Creates an elicitation request for the given parameters. - * For backwards compatibility, `mode` may be omitted for form requests and will default to `"form"`. - * @param params The parameters for the elicitation request. - * @param options Optional request options. - * @returns The result of the elicitation request. - */ async elicitInput(params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions): Promise { const mode = (params.mode ?? 'form') as 'form' | 'url'; @@ -604,14 +496,6 @@ export class Server extends Protocol { } } - /** - * Creates a reusable callback that, when invoked, will send a `notifications/elicitation/complete` - * notification for the specified elicitation ID. - * - * @param elicitationId The ID of the elicitation to mark as complete. - * @param options Optional notification options. Useful when the completion notification should be related to a prior request. - * @returns A function that emits the completion notification when awaited. - */ createElicitationCompletionNotifier(elicitationId: string, options?: NotificationOptions): () => Promise { if (!this._clientCapabilities?.elicitation?.url) { throw new SdkError( @@ -624,9 +508,7 @@ export class Server extends Protocol { this.notification( { method: 'notifications/elicitation/complete', - params: { - elicitationId - } + params: { elicitationId } }, options ); @@ -636,30 +518,18 @@ export class Server extends Protocol { return this._requestWithSchema({ method: 'roots/list', params }, ListRootsResultSchema, options); } - /** - * Sends a logging message to the client, if connected. - * Note: You only need to send the parameters object, not the entire JSON-RPC message. - * @see {@linkcode LoggingMessageNotification} - * @param params - * @param sessionId Optional for stateless transports and backward compatibility. - */ async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { - if (this._capabilities.logging && !this.isMessageIgnored(params.level, sessionId)) { + if (this._registry.getCapabilities().logging && !this.isMessageIgnored(params.level, sessionId)) { return this.notification({ method: 'notifications/message', params }); } } async sendResourceUpdated(params: ResourceUpdatedNotification['params']) { - return this.notification({ - method: 'notifications/resources/updated', - params - }); + return this.notification({ method: 'notifications/resources/updated', params }); } async sendResourceListChanged() { - return this.notification({ - method: 'notifications/resources/list_changed' - }); + return this.notification({ method: 'notifications/resources/list_changed' }); } async sendToolListChanged() { @@ -670,3 +540,229 @@ export class Server extends Protocol { return this.notification({ method: 'notifications/prompts/list_changed' }); } } + +/** + * An MCP server on top of a pluggable transport. + * + * Owns a `HandlerRegistry` directly for handler registration and + * capability management. For routing transports, passes registry and config + * directly. For regular transports, creates a {@linkcode LegacyServer} that + * shares the same registry. + */ +export class Server { + private _registry: HandlerRegistry; + private _impl?: LegacyServer; + private _transport?: Transport; + private _serverInfo: Implementation; + private _instructions?: string; + private _options?: ServerOptions; + + oninitialized?: () => void; + onclose?: () => void; + onerror?: (error: Error) => void; + + get fallbackRequestHandler() { + return this._registry.fallbackRequestHandler; + } + set fallbackRequestHandler(h) { + this._registry.fallbackRequestHandler = h; + } + + constructor(serverInfo: Implementation, options?: ServerOptions) { + this._serverInfo = serverInfo; + this._instructions = options?.instructions; + this._options = options; + + this._registry = createServerRegistry(options?.capabilities); + } + + private _createLegacyServer(): LegacyServer { + return new LegacyServer(this._serverInfo, { + ...this._options, + registry: this._registry + }); + } + + async connect(transport: Transport): Promise { + this._transport = transport; + + if (transport.setProtocolConfig) { + transport.setProtocolConfig({ + requestHandlers: this._registry.requestHandlers, + serverInfo: this._serverInfo, + capabilities: this._registry.getCapabilities(), + instructions: this._instructions, + createServer: () => this._createLegacyServer() + }); + await transport.start(); + } else { + this._impl = this._createLegacyServer(); + if (this.oninitialized) this._impl.oninitialized = this.oninitialized; + if (this.onclose) this._impl.onclose = this.onclose; + if (this.onerror) this._impl.onerror = this.onerror; + await this._impl.connect(transport); + } + } + + async close(): Promise { + await (this._impl?.transport ? this._impl.close() : this._transport?.close()); + } + + get transport(): Transport | undefined { + return this._impl?.transport ?? this._transport; + } + + // Handler registration — delegates to shared registry + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise + ): void; + setRequestHandler

( + method: string, + schemas: { params: P; result?: R }, + handler: ( + params: StandardSchemaV1.InferOutput

, + ctx: ServerContext + ) => + | (R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result) + | Promise : Result> + ): void; + setRequestHandler(method: string, ...args: unknown[]): void { + (this._registry.setRequestHandler as (...a: unknown[]) => void).call(this._registry, method, ...args); + } + + setNotificationHandler( + method: M, + handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void; + setNotificationHandler

( + method: string, + schemas: { params: P }, + handler: (params: StandardSchemaV1.InferOutput

, notification: Notification) => void | Promise + ): void; + setNotificationHandler(method: string, ...args: unknown[]): void { + (this._registry.setNotificationHandler as (...a: unknown[]) => void).call(this._registry, method, ...args); + } + + removeRequestHandler(method: RequestMethod | string): void { + this._registry.removeRequestHandler(method); + } + + removeNotificationHandler(method: NotificationMethod | string): void { + this._registry.removeNotificationHandler(method); + } + + assertCanSetRequestHandler(method: RequestMethod | string): void { + this._registry.assertCanSetRequestHandler(method); + } + + registerCapabilities(capabilities: ServerCapabilities): void { + if (this._impl?.transport || this._transport) { + throw new SdkError(SdkErrorCode.AlreadyConnected, 'Cannot register capabilities after connecting to transport'); + } + this._registry.registerCapabilities(capabilities); + } + + getCapabilities(): ServerCapabilities { + return this._registry.getCapabilities(); + } + + getClientCapabilities(): ClientCapabilities | undefined { + return this._impl?.getClientCapabilities(); + } + + getClientVersion(): Implementation | undefined { + return this._impl?.getClientVersion(); + } + + // Server-to-client methods — only work when connected to a regular transport + async createMessage(params: CreateMessageRequestParamsBase, options?: RequestOptions): Promise; + async createMessage(params: CreateMessageRequestParamsWithTools, options?: RequestOptions): Promise; + async createMessage( + params: CreateMessageRequest['params'], + options?: RequestOptions + ): Promise; + async createMessage( + params: CreateMessageRequest['params'], + options?: RequestOptions + ): Promise { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.createMessage(params, options); + } + + async elicitInput(params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions): Promise { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.elicitInput(params, options); + } + + createElicitationCompletionNotifier(elicitationId: string, options?: NotificationOptions): () => Promise { + if (!this._impl?.getClientCapabilities()?.elicitation?.url) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + 'Client does not support URL elicitation (required for notifications/elicitation/complete)' + ); + } + return () => + this.notification( + { + method: 'notifications/elicitation/complete', + params: { elicitationId } + }, + options + ); + } + + async listRoots(params?: ListRootsRequest['params'], options?: RequestOptions) { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.listRoots(params, options); + } + + async ping() { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.ping(); + } + + request( + request: { method: M; params?: Record }, + options?: RequestOptions + ): Promise; + request( + request: { method: string; params?: Record }, + resultSchema: T, + options?: RequestOptions + ): Promise>; + request(request: { method: string; params?: Record }, ...args: unknown[]): Promise { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return (this._impl.request as (...a: unknown[]) => Promise).call(this._impl, request, ...args); + } + + async notification(notification: Notification, options?: NotificationOptions): Promise { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.notification(notification, options); + } + + async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.sendLoggingMessage(params, sessionId); + } + + async sendResourceUpdated(params: ResourceUpdatedNotification['params']) { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.sendResourceUpdated(params); + } + + async sendResourceListChanged() { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.sendResourceListChanged(); + } + + async sendToolListChanged() { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.sendToolListChanged(); + } + + async sendPromptListChanged() { + if (!this._impl) throw new SdkError(SdkErrorCode.UnsupportedOperation, 'Not connected to a legacy transport'); + return this._impl.sendPromptListChanged(); + } +} diff --git a/packages/server/src/server/stdio.examples.ts b/packages/server/src/server/stdio.examples.ts index de4603eaa7..a49ebe2c65 100644 --- a/packages/server/src/server/stdio.examples.ts +++ b/packages/server/src/server/stdio.examples.ts @@ -8,15 +8,15 @@ */ import { McpServer } from './mcp.js'; -import { StdioServerTransport } from './stdio.js'; +import { LegacyStdioServerTransport } from './stdio.js'; /** * Example: Basic stdio transport usage. */ -async function StdioServerTransport_basicUsage() { - //#region StdioServerTransport_basicUsage +async function LegacyStdioServerTransport_basicUsage() { + //#region LegacyStdioServerTransport_basicUsage const server = new McpServer({ name: 'my-server', version: '1.0.0' }); - const transport = new StdioServerTransport(); + const transport = new LegacyStdioServerTransport(); await server.connect(transport); - //#endregion StdioServerTransport_basicUsage + //#endregion LegacyStdioServerTransport_basicUsage } diff --git a/packages/server/src/server/stdio.ts b/packages/server/src/server/stdio.ts index ac2dd3f784..8da32b5024 100644 --- a/packages/server/src/server/stdio.ts +++ b/packages/server/src/server/stdio.ts @@ -10,13 +10,13 @@ import { process } from '@modelcontextprotocol/server/_shims'; * This transport is only available in Node.js environments. * * @example - * ```ts source="./stdio.examples.ts#StdioServerTransport_basicUsage" + * ```ts source="./stdio.examples.ts#LegacyStdioServerTransport_basicUsage" * const server = new McpServer({ name: 'my-server', version: '1.0.0' }); - * const transport = new StdioServerTransport(); + * const transport = new LegacyStdioServerTransport(); * await server.connect(transport); * ``` */ -export class StdioServerTransport implements Transport { +export class LegacyStdioServerTransport implements Transport { private _readBuffer: ReadBuffer = new ReadBuffer(); private _started = false; private _closed = false; @@ -51,7 +51,7 @@ export class StdioServerTransport implements Transport { async start(): Promise { if (this._started) { throw new Error( - 'StdioServerTransport already started! If using Server class, note that connect() calls start() automatically.' + 'LegacyStdioServerTransport already started! If using Server class, note that connect() calls start() automatically.' ); } @@ -102,7 +102,7 @@ export class StdioServerTransport implements Transport { send(message: JSONRPCMessage): Promise { if (this._closed) { - return Promise.reject(new Error('StdioServerTransport is closed')); + return Promise.reject(new Error('LegacyStdioServerTransport is closed')); } return new Promise((resolve, reject) => { const json = serializeMessage(message); diff --git a/packages/server/src/server/streamableHttp.examples.ts b/packages/server/src/server/streamableHttp.examples.ts index a805c1dcee..26ac395745 100644 --- a/packages/server/src/server/streamableHttp.examples.ts +++ b/packages/server/src/server/streamableHttp.examples.ts @@ -8,7 +8,7 @@ */ import { McpServer } from './mcp.js'; -import { WebStandardStreamableHTTPServerTransport } from './streamableHttp.js'; +import { WebStandardStreamableHTTPServerTransport } from './modernStreamableHttp.js'; /** * Example: Stateful Streamable HTTP transport (Web Standard). diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index fd3563a077..be733e2d01 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -68,9 +68,16 @@ interface StreamMapping { } /** - * Configuration options for {@linkcode WebStandardStreamableHTTPServerTransport} + * Configuration options for {@linkcode LegacyWebStandardStreamableHTTPServerTransport} */ export interface WebStandardStreamableHTTPServerTransportOptions { + /** + * When `true`, ignore the `Mcp-Method` header and treat all requests as legacy (2025-11) + * protocol. Modern clients that probe with `server/discover` will fail and fall back + * to the legacy protocol automatically. + */ + forceLegacy?: boolean; + /** * Function that generates a session ID for the transport. * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) @@ -94,7 +101,7 @@ export interface WebStandardStreamableHTTPServerTransportOptions { * Useful in cases when you need to clean up resources associated with the session. * Note that this is different from the transport closing, if you are handling * HTTP requests from multiple nodes you might want to close each - * {@linkcode WebStandardStreamableHTTPServerTransport} after a request is completed while still keeping the + * {@linkcode LegacyWebStandardStreamableHTTPServerTransport} after a request is completed while still keeping the * session open/running. * @param sessionId The session ID that was closed */ @@ -221,7 +228,7 @@ export interface HandleRequestOptions { * }; * ``` */ -export class WebStandardStreamableHTTPServerTransport implements Transport { +export class LegacyWebStandardStreamableHTTPServerTransport implements Transport { // when sessionId is not set (undefined), it means the transport is in stateless mode private sessionIdGenerator: (() => string) | undefined; private _started: boolean = false; diff --git a/packages/server/src/stdio.ts b/packages/server/src/stdio.ts index 7865c9cedc..015aeb8206 100644 --- a/packages/server/src/stdio.ts +++ b/packages/server/src/stdio.ts @@ -5,4 +5,5 @@ // subpath gives consumers a consistent shape across packages. Import from // `@modelcontextprotocol/server/stdio` only in process-stdio runtimes (Node.js, Bun, Deno). -export { StdioServerTransport } from './server/stdio.js'; +export { StdioServerTransport } from './server/modernStdio.js'; +export { LegacyStdioServerTransport } from './server/stdio.js'; diff --git a/packages/server/test/server/httpVersionRouting.test.ts b/packages/server/test/server/httpVersionRouting.test.ts new file mode 100644 index 0000000000..28b5a17ffe --- /dev/null +++ b/packages/server/test/server/httpVersionRouting.test.ts @@ -0,0 +1,618 @@ +import { describe, it, expect, beforeEach } from 'vitest'; +import { z } from 'zod/v4'; +import type { + CallToolResult, + GetPromptResult, + JSONRPCErrorResponse, + ListPromptsResult, + ListResourcesResult, + ListToolsResult, + ReadResourceResult +} from '@modelcontextprotocol/core'; +import { McpServer } from '../../src/server/mcp.js'; +import { WebStandardStreamableHTTPServerTransport } from '../../src/server/modernStreamableHttp.js'; + +interface DiscoverResult { + supportedVersions: string[]; + serverInfo: { name: string; version: string }; + capabilities: Record; +} +interface JsonRpcOk { + jsonrpc: '2.0'; + id: number; + result: T & { result_type?: string }; +} +type JsonRpcErr = JSONRPCErrorResponse; + +describe('WebStandardStreamableHTTPServerTransport', () => { + let server: McpServer; + let transport: WebStandardStreamableHTTPServerTransport; + + beforeEach(async () => { + server = new McpServer({ name: 'test-server', version: '1.0.0' }); + + server.registerTool('greet', { description: 'Greet someone', inputSchema: { name: z.string() } }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + })); + + server.registerResource('test-resource', 'test://doc', { description: 'A test resource' }, async () => ({ + contents: [{ uri: 'test://doc', text: 'Resource content here' }] + })); + + server.registerPrompt('test-prompt', { description: 'A test prompt' }, async () => ({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello from prompt' } }] + })); + + transport = new WebStandardStreamableHTTPServerTransport({ + sessionIdGenerator: () => crypto.randomUUID() + }); + + await server.connect(transport); + }); + + describe('modern 2026-06 path', () => { + it('handles server/discover', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'server/discover', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'server/discover', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.supportedVersions).toContain('2026-06-30'); + expect(body.result.serverInfo.name).toBe('test-server'); + expect(body.result.capabilities).toBeDefined(); + }); + + it('handles tools/call', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'tools/call', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'World' }, + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.result_type).toBe('complete'); + expect(body.result.content).toMatchObject([{ type: 'text', text: 'Hello, World!' }]); + }); + + it('handles tools/list', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'tools/list', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/list', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.result_type).toBe('complete'); + expect(body.result.tools).toHaveLength(1); + expect(body.result.tools).toMatchObject([{ name: 'greet' }]); + }); + + it('handles resources/list', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'resources/list', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'resources/list', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.result_type).toBe('complete'); + expect(body.result.resources).toMatchObject([{ uri: 'test://doc', name: 'test-resource' }]); + }); + + it('handles resources/read', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'resources/read', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'resources/read', + params: { + uri: 'test://doc', + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.result_type).toBe('complete'); + expect(body.result.contents).toMatchObject([{ uri: 'test://doc', text: 'Resource content here' }]); + }); + + it('handles prompts/list', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'prompts/list', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'prompts/list', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.result_type).toBe('complete'); + expect(body.result.prompts).toMatchObject([{ name: 'test-prompt' }]); + }); + + it('handles prompts/get', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'prompts/get', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'prompts/get', + params: { + name: 'test-prompt', + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcOk; + expect(body.result.result_type).toBe('complete'); + expect(body.result.messages).toMatchObject([{ role: 'user', content: { type: 'text', text: 'Hello from prompt' } }]); + }); + + it('returns method not found for unknown methods', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'unknown/method', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'unknown/method', + params: { + _meta: { protocolVersion: '2026-06-30' } + } + }) + }) + ); + + expect(response.status).toBe(200); + const body = (await response.json()) as JsonRpcErr; + expect(body.error.code).toBe(-32601); + }); + + it('rejects wrong Content-Type', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'text/plain', + 'Mcp-Method': 'tools/call' + }, + body: 'not json' + }) + ); + + expect(response.status).toBe(415); + }); + + it('rejects non-POST methods', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'GET', + headers: { 'Mcp-Method': 'server/discover' } + }) + ); + + expect(response.status).toBe(405); + }); + + it('rejects batch requests', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'tools/call' + }, + body: JSON.stringify([ + { jsonrpc: '2.0', id: 1, method: 'tools/list', params: {} }, + { jsonrpc: '2.0', id: 2, method: 'tools/list', params: {} } + ]) + }) + ); + + expect(response.status).toBe(400); + const body = (await response.json()) as JsonRpcErr; + expect(body.error.message).toContain('Batch'); + }); + }); + + describe('legacy 2025-11 path', () => { + it('handles initialize + tools/call', async () => { + // Step 1: Initialize + const initResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }) + }) + ); + + const sessionId = initResponse.headers.get('mcp-session-id'); + expect(sessionId).toBeDefined(); + + // Step 2: Send initialized notification + await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'Mcp-Session-Id': sessionId! + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'notifications/initialized' + }) + }) + ); + + // Step 3: Call tool + const toolResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'Mcp-Session-Id': sessionId! + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'World' } + } + }) + }) + ); + + // The response could be SSE or JSON depending on transport config + // For SSE, we need to parse the event stream + const contentType = toolResponse.headers.get('content-type'); + if (contentType?.includes('text/event-stream')) { + const text = await toolResponse.text(); + const dataLines = text.split('\n').filter(line => line.startsWith('data: ')); + const lastData = dataLines[dataLines.length - 1]!; + const parsed = JSON.parse(lastData.replace('data: ', '')) as JsonRpcOk; + expect(parsed.result.content).toMatchObject([{ type: 'text', text: 'Hello, World!' }]); + } else { + const body = (await toolResponse.json()) as JsonRpcOk; + expect(body.result.content).toMatchObject([{ type: 'text', text: 'Hello, World!' }]); + } + }); + + it('returns 404 for unknown session ID', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'Mcp-Session-Id': 'nonexistent-session' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/list', + params: {} + }) + }) + ); + + expect(response.status).toBe(404); + }); + + it('handles DELETE for session termination', async () => { + // Initialize a session + const initResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }) + }) + ); + const sessionId = initResponse.headers.get('mcp-session-id')!; + + // Send DELETE + const deleteResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'DELETE', + headers: { 'Mcp-Session-Id': sessionId } + }) + ); + + expect(deleteResponse.status).toBe(200); + + // Session should be gone — subsequent request returns 404 + const afterDelete = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'Mcp-Session-Id': sessionId + }, + body: JSON.stringify({ jsonrpc: '2.0', id: 2, method: 'tools/list', params: {} }) + }) + ); + + expect(afterDelete.status).toBe(404); + }); + + it('rejects GET without session ID', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'GET', + headers: {} + }) + ); + + expect(response.status).toBe(400); + }); + + it('rejects DELETE without session ID', async () => { + const response = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'DELETE', + headers: {} + }) + ); + + expect(response.status).toBe(400); + }); + }); + + describe('same tool on both paths', () => { + it('returns identical content for the same tool call', async () => { + // Modern path + const modernResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Mcp-Method': 'tools/call', + 'MCP-Protocol-Version': '2026-06-30' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Alice' }, + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }) + }) + ); + const modernBody = (await modernResponse.json()) as JsonRpcOk; + + // Legacy path: initialize first + const initResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }) + }) + ); + const sessionId = initResponse.headers.get('mcp-session-id')!; + + await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'Mcp-Session-Id': sessionId + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'notifications/initialized' + }) + }) + ); + + const legacyResponse = await transport.handleRequest( + new Request('http://localhost/mcp', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'Mcp-Session-Id': sessionId + }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Alice' } + } + }) + }) + ); + + // Extract legacy result (may be SSE or JSON) + let legacyContent; + const contentType = legacyResponse.headers.get('content-type'); + if (contentType?.includes('text/event-stream')) { + const text = await legacyResponse.text(); + const dataLines = text.split('\n').filter(line => line.startsWith('data: ')); + const lastData = dataLines[dataLines.length - 1]!; + const parsed = JSON.parse(lastData.replace('data: ', '')) as JsonRpcOk; + legacyContent = parsed.result.content; + } else { + const body = (await legacyResponse.json()) as JsonRpcOk; + legacyContent = body.result.content; + } + + // Both paths should return the same content + expect(modernBody.result.content).toEqual(legacyContent); + expect(modernBody.result.content).toMatchObject([{ type: 'text', text: 'Hello, Alice!' }]); + }); + }); +}); diff --git a/packages/server/test/server/stdio.test.ts b/packages/server/test/server/stdio.test.ts index 92671cacd9..83e0bf1e9f 100644 --- a/packages/server/test/server/stdio.test.ts +++ b/packages/server/test/server/stdio.test.ts @@ -3,7 +3,7 @@ import { Readable, Writable } from 'node:stream'; import type { JSONRPCMessage } from '@modelcontextprotocol/core'; import { ReadBuffer, serializeMessage } from '@modelcontextprotocol/core'; -import { StdioServerTransport } from '../../src/server/stdio.js'; +import { LegacyStdioServerTransport } from '../../src/server/stdio.js'; let input: Readable; let outputBuffer: ReadBuffer; @@ -25,7 +25,7 @@ beforeEach(() => { }); test('should start then close cleanly', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); server.onerror = error => { throw error; }; @@ -42,7 +42,7 @@ test('should start then close cleanly', async () => { }); test('should not read until started', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); server.onerror = error => { throw error; }; @@ -68,7 +68,7 @@ test('should not read until started', async () => { }); test('should read multiple messages', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); server.onerror = error => { throw error; }; @@ -104,7 +104,7 @@ test('should read multiple messages', async () => { }); test('should close and fire onerror when stdout errors', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); let receivedError: Error | undefined; server.onerror = err => { @@ -123,7 +123,7 @@ test('should close and fire onerror when stdout errors', async () => { }); test('should not fire onclose twice when close() is called after stdout error', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); server.onerror = () => {}; let closeCount = 0; @@ -147,7 +147,7 @@ test('should reject send() when stdout errors before drain', async () => { } }); - const server = new StdioServerTransport(input, slowOutput); + const server = new LegacyStdioServerTransport(input, slowOutput); server.onerror = () => {}; await server.start(); @@ -160,7 +160,7 @@ test('should reject send() when stdout errors before drain', async () => { }); test('should reject send() after transport is closed', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); await server.start(); await server.close(); @@ -168,7 +168,7 @@ test('should reject send() after transport is closed', async () => { }); test('should fire onerror before onclose on stdout error', async () => { - const server = new StdioServerTransport(input, output); + const server = new LegacyStdioServerTransport(input, output); const events: string[] = []; server.onerror = () => events.push('error'); diff --git a/packages/server/test/server/stdioVersionRouting.test.ts b/packages/server/test/server/stdioVersionRouting.test.ts new file mode 100644 index 0000000000..2ee3bdf435 --- /dev/null +++ b/packages/server/test/server/stdioVersionRouting.test.ts @@ -0,0 +1,480 @@ +import { Readable, Writable } from 'node:stream'; + +import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import { ReadBuffer, serializeMessage } from '@modelcontextprotocol/core'; +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { z } from 'zod/v4'; + +import { McpServer } from '../../src/server/mcp.js'; +import { StdioServerTransport } from '../../src/server/modernStdio.js'; + +interface DiscoverResult { + supportedVersions: string[]; + serverInfo: { name: string; version: string }; + capabilities: Record; +} +interface JsonRpcOk { + jsonrpc: '2.0'; + id: number; + result: T & { result_type?: string }; +} +interface JsonRpcErr { + jsonrpc: '2.0'; + id: number; + error: { code: number; message: string }; +} + +function createMockStreams() { + const input = new Readable({ read() {} }); + const messageResolvers: ((msg: JSONRPCMessage) => void)[] = []; + const bufferedMessages: JSONRPCMessage[] = []; + const outputBuffer = new ReadBuffer(); + + const output = new Writable({ + write(chunk, _encoding, callback) { + outputBuffer.append(chunk); + while (true) { + const msg = outputBuffer.readMessage(); + if (!msg) break; + const resolver = messageResolvers.shift(); + if (resolver) { + resolver(msg); + } else { + bufferedMessages.push(msg); + } + } + callback(); + } + }); + + function nextMessage(): Promise { + const buffered = bufferedMessages.shift(); + if (buffered) return Promise.resolve(buffered); + return new Promise(resolve => messageResolvers.push(resolve)); + } + + function sendToStdin(msg: JSONRPCMessage): void { + input.push(serializeMessage(msg)); + } + + return { input, output, nextMessage, sendToStdin }; +} + +describe('StdioServerTransport (routing)', () => { + let server: McpServer; + let transport: StdioServerTransport; + let nextMessage: () => Promise; + let sendToStdin: (msg: JSONRPCMessage) => void; + + beforeEach(async () => { + server = new McpServer({ name: 'test-server', version: '1.0.0' }); + server.registerTool('greet', { description: 'Greet someone', inputSchema: { name: z.string() } }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + })); + + const streams = createMockStreams(); + nextMessage = streams.nextMessage; + sendToStdin = streams.sendToStdin; + transport = new StdioServerTransport(streams.input, streams.output); + await server.connect(transport); + }); + + describe('version detection', () => { + it('detects modern from server/discover', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'server/discover', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + const response = (await nextMessage()) as unknown as JsonRpcOk; + expect(response.result.supportedVersions).toContain('2026-06-30'); + expect(response.result.serverInfo.name).toBe('test-server'); + expect(response.result.capabilities).toBeDefined(); + }); + + it('detects modern from _meta.protocolVersion', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'tools/list', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + const response = (await nextMessage()) as JsonRpcOk<{ tools: unknown[]; result_type: string }>; + expect(response.result.result_type).toBe('complete'); + }); + + it('detects legacy from initialize', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }); + + const response = (await nextMessage()) as JsonRpcOk<{ + protocolVersion: string; + capabilities: Record; + serverInfo: { name: string }; + }>; + expect(response.result.protocolVersion).toBeDefined(); + expect(response.result.capabilities).toBeDefined(); + expect(response.result.serverInfo.name).toBe('test-server'); + }); + + it('locks mode on first message', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'server/discover', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + await nextMessage(); + + sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }); + + const response = (await nextMessage()) as JsonRpcErr; + expect(response.error).toBeDefined(); + }); + }); + + describe('modern path', () => { + it('modern: tools/call', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'server/discover', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + await nextMessage(); + + sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'World' }, + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + const response = (await nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[]; result_type: string }>; + expect(response.result.result_type).toBe('complete'); + expect(response.result.content).toMatchObject([{ type: 'text', text: 'Hello, World!' }]); + }); + + it('modern: tools/list', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'server/discover', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + await nextMessage(); + + sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'tools/list', + params: { + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + const response = (await nextMessage()) as JsonRpcOk<{ tools: { name: string }[]; result_type: string }>; + expect(response.result.result_type).toBe('complete'); + expect(response.result.tools).toHaveLength(1); + expect(response.result.tools).toMatchObject([{ name: 'greet' }]); + }); + }); + + describe('legacy path', () => { + it('legacy: initialize + tools/call', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }); + + const initResponse = (await nextMessage()) as JsonRpcOk<{ protocolVersion: string }>; + expect(initResponse.result.protocolVersion).toBeDefined(); + + sendToStdin({ + jsonrpc: '2.0', + method: 'notifications/initialized' + }); + + await new Promise(r => setTimeout(r, 10)); + + sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'World' } + } + }); + + const response = (await nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[] }>; + expect(response.result.content).toMatchObject([{ type: 'text', text: 'Hello, World!' }]); + }); + + it('rapid messages during legacy init', async () => { + sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }); + + sendToStdin({ + jsonrpc: '2.0', + method: 'notifications/initialized' + }); + + const initResponse = (await nextMessage()) as JsonRpcOk<{ protocolVersion: string }>; + expect(initResponse.result.protocolVersion).toBeDefined(); + + await new Promise(r => setTimeout(r, 10)); + + sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Rapid' } + } + }); + + const response = (await nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[] }>; + expect(response.result.content).toMatchObject([{ type: 'text', text: 'Hello, Rapid!' }]); + }); + }); + + describe('cross-path', () => { + it('same tool returns identical content on both paths', async () => { + const modernStreams = createMockStreams(); + const modernServer = new McpServer({ name: 'test-server', version: '1.0.0' }); + modernServer.registerTool('greet', { description: 'Greet someone', inputSchema: { name: z.string() } }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + })); + const modernTransport = new StdioServerTransport(modernStreams.input, modernStreams.output); + await modernServer.connect(modernTransport); + + modernStreams.sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Alice' }, + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + const modernResponse = (await modernStreams.nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[] }>; + + const legacyStreams = createMockStreams(); + const legacyServer = new McpServer({ name: 'test-server', version: '1.0.0' }); + legacyServer.registerTool('greet', { description: 'Greet someone', inputSchema: { name: z.string() } }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + })); + const legacyTransport = new StdioServerTransport(legacyStreams.input, legacyStreams.output); + await legacyServer.connect(legacyTransport); + + legacyStreams.sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }); + + await legacyStreams.nextMessage(); + + legacyStreams.sendToStdin({ + jsonrpc: '2.0', + method: 'notifications/initialized' + }); + + await new Promise(r => setTimeout(r, 10)); + + legacyStreams.sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Alice' } + } + }); + + const legacyResponse = (await legacyStreams.nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[] }>; + + expect(modernResponse.result.content).toEqual(legacyResponse.result.content); + expect(modernResponse.result.content).toMatchObject([{ type: 'text', text: 'Hello, Alice!' }]); + }); + + it('handler registered after connect() is available on both paths', async () => { + const modernStreams = createMockStreams(); + const sharedServer = new McpServer({ name: 'test-server', version: '1.0.0' }); + sharedServer.registerTool('seed', { description: 'Seed tool', inputSchema: z.object({}) }, async () => ({ + content: [{ type: 'text', text: 'seed' }] + })); + const modernTransport = new StdioServerTransport(modernStreams.input, modernStreams.output); + await sharedServer.connect(modernTransport); + + vi.spyOn(sharedServer, 'sendToolListChanged').mockImplementation(() => {}); + sharedServer.registerTool( + 'late-tool', + { description: 'Registered after connect', inputSchema: { x: z.number() } }, + async ({ x }) => ({ + content: [{ type: 'text', text: `Result: ${x * 2}` }] + }) + ); + + modernStreams.sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { + name: 'late-tool', + arguments: { x: 21 }, + _meta: { + protocolVersion: '2026-06-30', + clientCapabilities: {}, + clientInfo: { name: 'test-client', version: '1.0.0' } + } + } + }); + + const modernResponse = (await modernStreams.nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[] }>; + expect(modernResponse.result.content).toMatchObject([{ type: 'text', text: 'Result: 42' }]); + + const legacyStreams = createMockStreams(); + const legacyServer = new McpServer({ name: 'test-server', version: '1.0.0' }); + legacyServer.registerTool('seed', { description: 'Seed tool', inputSchema: z.object({}) }, async () => ({ + content: [{ type: 'text', text: 'seed' }] + })); + const legacyTransport = new StdioServerTransport(legacyStreams.input, legacyStreams.output); + await legacyServer.connect(legacyTransport); + + vi.spyOn(legacyServer, 'sendToolListChanged').mockImplementation(() => {}); + legacyServer.registerTool( + 'late-tool', + { description: 'Registered after connect', inputSchema: { x: z.number() } }, + async ({ x }) => ({ + content: [{ type: 'text', text: `Result: ${x * 2}` }] + }) + ); + + legacyStreams.sendToStdin({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: '2025-11-25', + capabilities: {}, + clientInfo: { name: 'legacy-client', version: '1.0.0' } + } + }); + + await legacyStreams.nextMessage(); + + legacyStreams.sendToStdin({ + jsonrpc: '2.0', + method: 'notifications/initialized' + }); + + await new Promise(r => setTimeout(r, 10)); + + legacyStreams.sendToStdin({ + jsonrpc: '2.0', + id: 2, + method: 'tools/call', + params: { + name: 'late-tool', + arguments: { x: 21 } + } + }); + + const legacyResponse = (await legacyStreams.nextMessage()) as JsonRpcOk<{ content: { type: string; text: string }[] }>; + expect(legacyResponse.result.content).toMatchObject([{ type: 'text', text: 'Result: 42' }]); + }); + }); +}); diff --git a/packages/server/test/server/streamableHttp.test.ts b/packages/server/test/server/streamableHttp.test.ts index 7a23dd56bb..4c95b8a4b9 100644 --- a/packages/server/test/server/streamableHttp.test.ts +++ b/packages/server/test/server/streamableHttp.test.ts @@ -5,7 +5,7 @@ import * as z from 'zod/v4'; import { McpServer } from '../../src/server/mcp.js'; import type { EventId, EventStore, StreamId } from '../../src/server/streamableHttp.js'; -import { WebStandardStreamableHTTPServerTransport } from '../../src/server/streamableHttp.js'; +import { LegacyWebStandardStreamableHTTPServerTransport } from '../../src/server/streamableHttp.js'; /** * Common test messages @@ -119,7 +119,7 @@ function expectErrorResponse(data: unknown, expectedCode: number, expectedMessag describe('Zod v4', () => { describe('HTTPServerTransport', () => { - let transport: WebStandardStreamableHTTPServerTransport; + let transport: LegacyWebStandardStreamableHTTPServerTransport; let mcpServer: McpServer; let sessionId: string; @@ -137,7 +137,7 @@ describe('Zod v4', () => { } ); - transport = new WebStandardStreamableHTTPServerTransport({ + transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); @@ -428,7 +428,7 @@ describe('Zod v4', () => { }); describe('HTTPServerTransport - Stateless Mode', () => { - let transport: WebStandardStreamableHTTPServerTransport; + let transport: LegacyWebStandardStreamableHTTPServerTransport; let mcpServer: McpServer; beforeEach(async () => { @@ -442,7 +442,7 @@ describe('Zod v4', () => { } ); - transport = new WebStandardStreamableHTTPServerTransport({ + transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: undefined }); @@ -475,7 +475,7 @@ describe('Zod v4', () => { }); describe('HTTPServerTransport - JSON Response Mode', () => { - let transport: WebStandardStreamableHTTPServerTransport; + let transport: LegacyWebStandardStreamableHTTPServerTransport; let mcpServer: McpServer; let sessionId: string; @@ -490,7 +490,7 @@ describe('Zod v4', () => { } ); - transport = new WebStandardStreamableHTTPServerTransport({ + transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), enableJsonResponse: true }); @@ -566,7 +566,7 @@ describe('Zod v4', () => { const onInitialized = vi.fn(); const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); - const transport = new WebStandardStreamableHTTPServerTransport({ + const transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => 'test-session-123', onsessioninitialized: onInitialized }); @@ -585,7 +585,7 @@ describe('Zod v4', () => { const onClosed = vi.fn(); const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); - const transport = new WebStandardStreamableHTTPServerTransport({ + const transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => 'test-session-456', onsessionclosed: onClosed }); @@ -605,7 +605,7 @@ describe('Zod v4', () => { }); describe('HTTPServerTransport - Event Store (Resumability)', () => { - let transport: WebStandardStreamableHTTPServerTransport; + let transport: LegacyWebStandardStreamableHTTPServerTransport; let mcpServer: McpServer; let eventStore: EventStore; let storedEvents: Map; @@ -662,7 +662,7 @@ describe('Zod v4', () => { } ); - transport = new WebStandardStreamableHTTPServerTransport({ + transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), eventStore }); @@ -708,14 +708,14 @@ describe('Zod v4', () => { }); describe('HTTPServerTransport - Protocol Version Validation', () => { - let transport: WebStandardStreamableHTTPServerTransport; + let transport: LegacyWebStandardStreamableHTTPServerTransport; let mcpServer: McpServer; let sessionId: string; beforeEach(async () => { mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); - transport = new WebStandardStreamableHTTPServerTransport({ + transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); @@ -756,7 +756,7 @@ describe('Zod v4', () => { describe('HTTPServerTransport - start() method', () => { it('should throw error when started twice', async () => { - const transport = new WebStandardStreamableHTTPServerTransport({ + const transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); @@ -767,7 +767,7 @@ describe('Zod v4', () => { }); describe('HTTPServerTransport - onerror callback', () => { - let transport: WebStandardStreamableHTTPServerTransport; + let transport: LegacyWebStandardStreamableHTTPServerTransport; let mcpServer: McpServer; let errors: Error[]; @@ -775,7 +775,7 @@ describe('Zod v4', () => { errors = []; mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); - transport = new WebStandardStreamableHTTPServerTransport({ + transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); @@ -959,7 +959,7 @@ describe('Zod v4', () => { describe('close() re-entrancy guard', () => { it('should not recurse when onclose triggers a second close()', async () => { - const transport = new WebStandardStreamableHTTPServerTransport({ sessionIdGenerator: randomUUID }); + const transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: randomUUID }); let closeCallCount = 0; transport.onclose = () => { @@ -975,7 +975,7 @@ describe('Zod v4', () => { }); it('should clean up all streams exactly once even when close() is called concurrently', async () => { - const transport = new WebStandardStreamableHTTPServerTransport({ sessionIdGenerator: randomUUID }); + const transport = new LegacyWebStandardStreamableHTTPServerTransport({ sessionIdGenerator: randomUUID }); const cleanupCalls: string[] = []; diff --git a/test/helpers/src/helpers/tasks.ts b/test/helpers/src/helpers/tasks.ts deleted file mode 100644 index 4db3231a67..0000000000 --- a/test/helpers/src/helpers/tasks.ts +++ /dev/null @@ -1,33 +0,0 @@ -import type { Task } from '@modelcontextprotocol/core'; - -/** - * Polls the provided getTask function until the task reaches the desired status or times out. - */ -export async function waitForTaskStatus( - getTask: (taskId: string) => Promise, - taskId: string, - desiredStatus: Task['status'], - { - intervalMs = 100, - timeoutMs = 10_000 - }: { - intervalMs?: number; - timeoutMs?: number; - } = {} -): Promise { - const start = Date.now(); - - // eslint-disable-next-line no-constant-condition - while (true) { - const task = await getTask(taskId); - if (task && task.status === desiredStatus) { - return task; - } - - if (Date.now() - start > timeoutMs) { - throw new Error(`Timed out waiting for task ${taskId} to reach status ${desiredStatus}`); - } - - await new Promise(resolve => setTimeout(resolve, intervalMs)); - } -} diff --git a/test/helpers/src/index.ts b/test/helpers/src/index.ts index 1ecfa8e24a..1fd7ce2b9b 100644 --- a/test/helpers/src/index.ts +++ b/test/helpers/src/index.ts @@ -1,3 +1,2 @@ export * from './helpers/http.js'; export * from './helpers/oauth.js'; -export * from './helpers/tasks.js'; diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index 52d151bddb..6f6487963f 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -1,8 +1,6 @@ import { Client, getSupportedElicitationModes } from '@modelcontextprotocol/client'; import type { Prompt, Resource, Tool, Transport } from '@modelcontextprotocol/core'; import { - CallToolResultSchema, - ElicitResultSchema, InMemoryTransport, LATEST_PROTOCOL_VERSION, ProtocolErrorCode, @@ -10,8 +8,7 @@ import { SdkErrorCode, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; -import { InMemoryTaskStore, McpServer, Server } from '@modelcontextprotocol/server'; -import * as z from 'zod/v4'; +import { McpServer, Server } from '@modelcontextprotocol/server'; /*** * Test: Initialize with Matching Protocol Version @@ -1784,20 +1781,7 @@ describe('outputSchema validation', () => { version: '1.0.0' }, { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - }, - tasks: { - get: true, - list: {}, - result: true - } - } - } - } + capabilities: {} } ); @@ -1877,20 +1861,7 @@ describe('outputSchema validation', () => { version: '1.0.0' }, { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - }, - tasks: { - get: true, - list: {}, - result: true - } - } - } - } + capabilities: {} } ); @@ -1967,20 +1938,7 @@ describe('outputSchema validation', () => { version: '1.0.0' }, { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - }, - tasks: { - get: true, - list: {}, - result: true - } - } - } - } + capabilities: {} } ); @@ -2053,20 +2011,7 @@ describe('outputSchema validation', () => { version: '1.0.0' }, { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - }, - tasks: { - get: true, - list: {}, - result: true - } - } - } - } + capabilities: {} } ); @@ -2166,20 +2111,7 @@ describe('outputSchema validation', () => { version: '1.0.0' }, { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - }, - tasks: { - get: true, - list: {}, - result: true - } - } - } - } + capabilities: {} } ); @@ -2263,7 +2195,7 @@ describe('outputSchema validation', () => { name: 'test-client', version: '1.0.0' }, - { capabilities: { tasks: { requests: { tools: { call: {} } } } } } + { capabilities: {} } ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2280,1813 +2212,6 @@ describe('outputSchema validation', () => { }); }); -describe('Task-based execution', () => { - describe('Client calling server', () => { - let serverTaskStore: InMemoryTaskStore; - - beforeEach(() => { - serverTaskStore = new InMemoryTaskStore(); - }); - - afterEach(() => { - serverTaskStore?.cleanup(); - }); - - test('should create task on server via tool call', async () => { - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - const result = { - content: [{ type: 'text', text: 'Tool executed successfully!' }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { capabilities: { tasks: { requests: { tools: { call: {} } } } } } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Client creates task on server via tool call - await client.callTool( - { name: 'test-tool', arguments: {} }, - { - task: { - ttl: 60_000 - } - } - ); - - // Verify task was created successfully by listing tasks - const taskList = await client.experimental.tasks.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const task = taskList.tasks[0]!; - expect(task.status).toBe('completed'); - }); - - test('should query task status from server using getTask', async () => { - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - const result = { - content: [{ type: 'text', text: 'Success!' }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { capabilities: { tasks: { requests: { tools: { call: {} } } } } } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create a task - await client.callTool( - { name: 'test-tool', arguments: {} }, - { - task: { ttl: 60_000 } - } - ); - - // Query task status by listing tasks and getting the first one - const taskList = await client.experimental.tasks.listTasks(); - expect(taskList.tasks.length).toBeGreaterThan(0); - const task = taskList.tasks[0]!; - expect(task).toBeDefined(); - expect(task.taskId).toBeDefined(); - expect(task.status).toBe('completed'); - }); - - test('should query task result from server using getTaskResult', async () => { - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {}, - list: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - const result = { - content: [{ type: 'text', text: 'Result data!' }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { capabilities: { tasks: { requests: { tools: { call: {} } } } } } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create a task using callToolStream to capture the task ID - let taskId: string | undefined; - const stream = client.experimental.tasks.callToolStream( - { name: 'test-tool', arguments: {} }, - { - task: { ttl: 60_000 } - } - ); - - for await (const message of stream) { - if (message.type === 'taskCreated') { - taskId = message.task.taskId; - } - } - - expect(taskId).toBeDefined(); - - // Query task result using the captured task ID - const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); - expect(result.content).toEqual([{ type: 'text', text: 'Result data!' }]); - }); - - test('should query task list from server using listTasks', async () => { - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - const result = { - content: [{ type: 'text', text: 'Success!' }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { capabilities: { tasks: { requests: { tools: { call: {} } } } } } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create multiple tasks - const createdTaskIds: string[] = []; - - for (let i = 0; i < 2; i++) { - await client.callTool( - { name: 'test-tool', arguments: {} }, - { - task: { ttl: 60_000 } - } - ); - - // Get the task ID from the task list - const taskList = await client.experimental.tasks.listTasks(); - const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); - if (newTask) { - createdTaskIds.push(newTask.taskId); - } - } - - // Query task list - const taskList = await client.experimental.tasks.listTasks(); - expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); - for (const taskId of createdTaskIds) { - expect(taskList.tasks).toContainEqual( - expect.objectContaining({ - taskId, - status: 'completed' - }) - ); - } - }); - }); - - describe('Server calling client', () => { - let clientTaskStore: InMemoryTaskStore; - - beforeEach(() => { - clientTaskStore = new InMemoryTaskStore(); - }); - - afterEach(() => { - clientTaskStore?.cleanup(); - }); - - test('should create task on client via server elicitation', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Server creates task on client via elicitation - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Please provide your username', - requestedSchema: { - type: 'object', - properties: { - username: { type: 'string' } - }, - required: ['username'] - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - // Verify task was created - const task = await server.experimental.tasks.getTask(taskId); - expect(task.status).toBe('completed'); - }); - - test('should query task status from client using getTask', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create a task on client and wait for CreateTaskResult - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Please provide info', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - // Query task status - const task = await server.experimental.tasks.getTask(taskId); - expect(task).toBeDefined(); - expect(task.taskId).toBe(taskId); - expect(task.status).toBe('completed'); - }); - - test('should query task result from client using getTaskResult', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'result-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create a task on client and wait for CreateTaskResult - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Please provide info', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - // Query task result using getTaskResult - const taskResult = await server.experimental.tasks.getTaskResult(taskId, ElicitResultSchema); - expect(taskResult.action).toBe('accept'); - expect(taskResult.content).toEqual({ username: 'result-user' }); - }); - - test('should query task list from client using listTasks', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create multiple tasks on client - const createdTaskIds: string[] = []; - for (let i = 0; i < 2; i++) { - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Please provide info', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure and capture taskId - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - createdTaskIds.push(createTaskResult.task.taskId); - } - - // Query task list - const taskList = await server.experimental.tasks.listTasks(); - expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); - for (const taskId of createdTaskIds) { - expect(taskList.tasks).toContainEqual( - expect.objectContaining({ - taskId, - status: 'completed' - }) - ); - } - }); - }); - - test('should list tasks from server with pagination', async () => { - const serverTaskStore = new InMemoryTaskStore(); - - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({ - id: z.string() - }) - }, - { - async createTask({ id }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - const result = { - content: [{ type: 'text', text: `Result for ${id || 'unknown'}` }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create multiple tasks - const createdTaskIds: string[] = []; - - for (let i = 0; i < 3; i++) { - await client.callTool( - { name: 'test-tool', arguments: { id: `task-${i + 1}` } }, - { - task: { ttl: 60_000 } - } - ); - - // Get the task ID from the task list - const taskList = await client.experimental.tasks.listTasks(); - const newTask = taskList.tasks.find(t => !createdTaskIds.includes(t.taskId)); - if (newTask) { - createdTaskIds.push(newTask.taskId); - } - } - - // List all tasks without cursor - const firstPage = await client.experimental.tasks.listTasks(); - expect(firstPage.tasks.length).toBeGreaterThan(0); - expect(firstPage.tasks.map(t => t.taskId)).toEqual(expect.arrayContaining(createdTaskIds)); - - // If there's a cursor, test pagination - if (firstPage.nextCursor) { - const secondPage = await client.experimental.tasks.listTasks(firstPage.nextCursor); - expect(secondPage.tasks).toBeDefined(); - } - - serverTaskStore.cleanup(); - }); - - describe('Error scenarios', () => { - let serverTaskStore: InMemoryTaskStore; - let clientTaskStore: InMemoryTaskStore; - - beforeEach(() => { - serverTaskStore = new InMemoryTaskStore(); - clientTaskStore = new InMemoryTaskStore(); - }); - - afterEach(() => { - serverTaskStore?.cleanup(); - clientTaskStore?.cleanup(); - }); - - test('should throw error when querying non-existent task from server', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Try to get a task that doesn't exist - await expect(client.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); - }); - - test('should throw error when querying result of non-existent task from server', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Try to get result of a task that doesn't exist - await expect(client.experimental.tasks.getTaskResult('non-existent-task', CallToolResultSchema)).rejects.toThrow(); - }); - - test('should throw error when server queries non-existent task from client', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async () => ({ - action: 'accept', - content: { username: 'test' } - })); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Try to query a task that doesn't exist on client - await expect(server.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); - }); - }); -}); - -test('should respect server task capabilities', async () => { - const serverTaskStore = new InMemoryTaskStore(); - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore: serverTaskStore - } - } - } - ); - - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - const result = { - content: [{ type: 'text', text: 'Success!' }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - enforceStrictCapabilities: true, - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Server supports task creation for tools/call - expect(client.getServerCapabilities()).toEqual({ - tools: { - listChanged: true - }, - tasks: { - requests: { - tools: { - call: {} - } - } - } - }); - - // These should work because server supports tasks - await expect( - client.callTool( - { name: 'test-tool', arguments: {} }, - { - task: { ttl: 60_000 } - } - ) - ).resolves.not.toThrow(); - await expect(client.experimental.tasks.listTasks()).resolves.not.toThrow(); - - // tools/list doesn't support task creation, but it shouldn't throw - it should just ignore the task metadata - await expect( - client.request({ - method: 'tools/list', - params: {} - }) - ).resolves.not.toThrow(); - - serverTaskStore.cleanup(); -}); - -/** - * Test: requestStream() method - */ -test('should expose requestStream() method for streaming responses', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/call', async () => { - return { - content: [{ type: 'text', text: 'Tool result' }] - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // First verify that regular request() works - const regularResult = await client.callTool({ name: 'test-tool', arguments: {} }); - expect(regularResult.content).toEqual([{ type: 'text', text: 'Tool result' }]); - - // Test requestStream with non-task request (should yield only result) - const stream = client.experimental.tasks.requestStream({ - method: 'tools/call', - params: { name: 'test-tool', arguments: {} } - }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - // Should have received only a result message (no task messages) - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('result'); - if (messages[0]!.type === 'result') { - expect(messages[0]!.result.content).toEqual([{ type: 'text', text: 'Tool result' }]); - } - - await client.close(); - await server.close(); -}); - -/** - * Test: callToolStream() method - */ -test('should expose callToolStream() method for streaming tool calls', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/call', async () => { - return { - content: [{ type: 'text', text: 'Tool result' }] - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Test callToolStream - const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - // Should have received messages ending with result - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('result'); - if (messages[0]!.type === 'result') { - expect(messages[0]!.result.content).toEqual([{ type: 'text', text: 'Tool result' }]); - } - - await client.close(); - await server.close(); -}); - -/** - * Test: callToolStream() with output schema validation - */ -test('should validate structured output in callToolStream()', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => { - return { - tools: [ - { - name: 'structured-tool', - description: 'A tool with output schema', - inputSchema: { - type: 'object', - properties: {} - }, - outputSchema: { - type: 'object', - properties: { - value: { type: 'number' } - }, - required: ['value'] - } - } - ] - }; - }); - - server.setRequestHandler('tools/call', async () => { - return { - content: [{ type: 'text', text: 'Result' }], - structuredContent: { value: 42 } - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // List tools to cache the output schema - await client.listTools(); - - // Test callToolStream with valid structured output - const stream = client.experimental.tasks.callToolStream({ name: 'structured-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - // Should have received result with validated structured content - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('result'); - if (messages[0]!.type === 'result') { - expect(messages[0]!.result.structuredContent).toEqual({ value: 42 }); - } - - await client.close(); - await server.close(); -}); - -test('callToolStream() should yield error when structuredContent does not match schema', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} - }, - outputSchema: { - type: 'object', - properties: { - result: { type: 'string' }, - count: { type: 'number' } - }, - required: ['result', 'count'], - additionalProperties: false - } - } - ] - })); - - server.setRequestHandler('tools/call', async () => { - // Return invalid structured content (count is string instead of number) - return { - structuredContent: { result: 'success', count: 'not a number' } - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // List tools to cache the schemas - await client.listTools(); - - const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('error'); - if (messages[0]!.type === 'error') { - expect(messages[0]!.error.message).toMatch(/Structured content does not match the tool's output schema/); - } - - await client.close(); - await server.close(); -}); - -test('callToolStream() should yield error when tool with outputSchema returns no structuredContent', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} - }, - outputSchema: { - type: 'object', - properties: { - result: { type: 'string' } - }, - required: ['result'] - } - } - ] - })); - - server.setRequestHandler('tools/call', async () => { - return { - content: [{ type: 'text', text: 'This should be structured content' }] - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - await client.listTools(); - - const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('error'); - if (messages[0]!.type === 'error') { - expect(messages[0]!.error.message).toMatch(/Tool test-tool has an output schema but did not return structured content/); - } - - await client.close(); - await server.close(); -}); - -test('callToolStream() should handle tools without outputSchema normally', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} - } - } - ] - })); - - server.setRequestHandler('tools/call', async () => { - return { - content: [{ type: 'text', text: 'Normal response' }] - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - await client.listTools(); - - const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('result'); - if (messages[0]!.type === 'result') { - expect(messages[0]!.result.content).toEqual([{ type: 'text', text: 'Normal response' }]); - } - - await client.close(); - await server.close(); -}); - -test('callToolStream() should handle complex JSON schema validation', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'complex-tool', - description: 'A tool with complex schema', - inputSchema: { - type: 'object', - properties: {} - }, - outputSchema: { - type: 'object', - properties: { - name: { type: 'string', minLength: 3 }, - age: { type: 'integer', minimum: 0, maximum: 120 }, - active: { type: 'boolean' }, - tags: { - type: 'array', - items: { type: 'string' }, - minItems: 1 - }, - metadata: { - type: 'object', - properties: { - created: { type: 'string' } - }, - required: ['created'] - } - }, - required: ['name', 'age', 'active', 'tags', 'metadata'], - additionalProperties: false - } - } - ] - })); - - server.setRequestHandler('tools/call', async () => { - return { - structuredContent: { - name: 'John Doe', - age: 30, - active: true, - tags: ['user', 'admin'], - metadata: { - created: '2023-01-01T00:00:00Z' - } - } - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - await client.listTools(); - - const stream = client.experimental.tasks.callToolStream({ name: 'complex-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('result'); - if (messages[0]!.type === 'result') { - expect(messages[0]!.result.structuredContent).toBeDefined(); - const structuredContent = messages[0]!.result.structuredContent as { name: string; age: number }; - expect(structuredContent.name).toBe('John Doe'); - expect(structuredContent.age).toBe(30); - } - - await client.close(); - await server.close(); -}); - -test('callToolStream() should yield error with additional properties when not allowed', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'strict-tool', - description: 'A tool with strict schema', - inputSchema: { - type: 'object', - properties: {} - }, - outputSchema: { - type: 'object', - properties: { - name: { type: 'string' } - }, - required: ['name'], - additionalProperties: false - } - } - ] - })); - - server.setRequestHandler('tools/call', async () => { - return { - structuredContent: { - name: 'John', - extraField: 'not allowed' - } - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - await client.listTools(); - - const stream = client.experimental.tasks.callToolStream({ name: 'strict-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('error'); - if (messages[0]!.type === 'error') { - expect(messages[0]!.error.message).toMatch(/Structured content does not match the tool's output schema/); - } - - await client.close(); - await server.close(); -}); - -test('callToolStream() should not validate structuredContent when isError is true', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - } - ); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} - }, - outputSchema: { - type: 'object', - properties: { - result: { type: 'string' } - }, - required: ['result'] - } - } - ] - })); - - server.setRequestHandler('tools/call', async () => { - // Return isError with content (no structuredContent) - should NOT trigger validation error - return { - isError: true, - content: [{ type: 'text', text: 'Something went wrong' }] - }; - }); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: { requests: { tools: { call: {} } } } } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - await client.listTools(); - - const stream = client.experimental.tasks.callToolStream({ name: 'test-tool', arguments: {} }); - - const messages = []; - for await (const message of stream) { - messages.push(message); - } - - // Should have received result (not error), with isError flag set - expect(messages.length).toBe(1); - expect(messages[0]!.type).toBe('result'); - if (messages[0]!.type === 'result') { - expect(messages[0]!.result.isError).toBe(true); - expect(messages[0]!.result.content).toEqual([{ type: 'text', text: 'Something went wrong' }]); - } - - await client.close(); - await server.close(); -}); - describe('getSupportedElicitationModes', () => { test('should support nothing when capabilities are undefined', () => { const result = getSupportedElicitationModes(undefined); diff --git a/test/integration/test/experimental/tasks/task.test.ts b/test/integration/test/experimental/tasks/task.test.ts deleted file mode 100644 index d2aca2cc07..0000000000 --- a/test/integration/test/experimental/tasks/task.test.ts +++ /dev/null @@ -1,144 +0,0 @@ -import type { Task } from '@modelcontextprotocol/core'; -import { isTerminal, TaskCreationParamsSchema } from '@modelcontextprotocol/core'; -import { describe, expect, it } from 'vitest'; - -describe('Task utility functions', () => { - describe('isTerminal', () => { - it('should return true for completed status', () => { - expect(isTerminal('completed')).toBe(true); - }); - - it('should return true for failed status', () => { - expect(isTerminal('failed')).toBe(true); - }); - - it('should return true for cancelled status', () => { - expect(isTerminal('cancelled')).toBe(true); - }); - - it('should return false for working status', () => { - expect(isTerminal('working')).toBe(false); - }); - - it('should return false for input_required status', () => { - expect(isTerminal('input_required')).toBe(false); - }); - }); -}); - -describe('Task Schema Validation', () => { - it('should validate task with ttl field', () => { - const createdAt = new Date().toISOString(); - const task: Task = { - taskId: 'test-123', - status: 'working', - ttl: 60_000, - createdAt, - lastUpdatedAt: createdAt, - pollInterval: 1000 - }; - - expect(task.ttl).toBe(60_000); - expect(task.createdAt).toBeDefined(); - expect(typeof task.createdAt).toBe('string'); - }); - - it('should validate task with null ttl', () => { - const createdAt = new Date().toISOString(); - const task: Task = { - taskId: 'test-456', - status: 'completed', - ttl: null, - createdAt, - lastUpdatedAt: createdAt - }; - - expect(task.ttl).toBeNull(); - }); - - it('should validate task with statusMessage field', () => { - const createdAt = new Date().toISOString(); - const task: Task = { - taskId: 'test-789', - status: 'failed', - ttl: null, - createdAt, - lastUpdatedAt: createdAt, - statusMessage: 'Operation failed due to timeout' - }; - - expect(task.statusMessage).toBe('Operation failed due to timeout'); - }); - - it('should validate task with createdAt in ISO 8601 format', () => { - const now = new Date(); - const createdAt = now.toISOString(); - const task: Task = { - taskId: 'test-iso', - status: 'working', - ttl: 30_000, - createdAt, - lastUpdatedAt: createdAt - }; - - expect(task.createdAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); - expect(new Date(task.createdAt).getTime()).toBe(now.getTime()); - }); - - it('should validate task with lastUpdatedAt in ISO 8601 format', () => { - const now = new Date(); - const createdAt = now.toISOString(); - const task: Task = { - taskId: 'test-iso', - status: 'working', - ttl: 30_000, - createdAt, - lastUpdatedAt: createdAt - }; - - expect(task.lastUpdatedAt).toMatch(/^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/); - }); - - it('should validate all task statuses', () => { - const statuses: Task['status'][] = ['working', 'input_required', 'completed', 'failed', 'cancelled']; - - const createdAt = new Date().toISOString(); - for (const status of statuses) { - const task: Task = { - taskId: `test-${status}`, - status, - ttl: null, - createdAt, - lastUpdatedAt: createdAt - }; - expect(task.status).toBe(status); - } - }); -}); - -describe('TaskCreationParams Schema Validation', () => { - it('should accept ttl as a number', () => { - const result = TaskCreationParamsSchema.safeParse({ ttl: 60_000 }); - expect(result.success).toBe(true); - }); - - it('should accept missing ttl (optional)', () => { - const result = TaskCreationParamsSchema.safeParse({}); - expect(result.success).toBe(true); - }); - - it('should reject null ttl (not allowed in request, only response)', () => { - const result = TaskCreationParamsSchema.safeParse({ ttl: null }); - expect(result.success).toBe(false); - }); - - it('should accept pollInterval as a number', () => { - const result = TaskCreationParamsSchema.safeParse({ pollInterval: 1000 }); - expect(result.success).toBe(true); - }); - - it('should accept both ttl and pollInterval', () => { - const result = TaskCreationParamsSchema.safeParse({ ttl: 60_000, pollInterval: 1000 }); - expect(result.success).toBe(true); - }); -}); diff --git a/test/integration/test/experimental/tasks/taskListing.test.ts b/test/integration/test/experimental/tasks/taskListing.test.ts deleted file mode 100644 index 2b21e99d51..0000000000 --- a/test/integration/test/experimental/tasks/taskListing.test.ts +++ /dev/null @@ -1,129 +0,0 @@ -import { ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core'; -import { afterEach, beforeEach, describe, expect, it } from 'vitest'; - -import { createInMemoryTaskEnvironment } from '../../helpers/mcp.js'; - -describe('Task Listing with Pagination', () => { - let client: Awaited>['client']; - let server: Awaited>['server']; - let taskStore: Awaited>['taskStore']; - - beforeEach(async () => { - const env = await createInMemoryTaskEnvironment(); - client = env.client; - server = env.server; - taskStore = env.taskStore; - }); - - afterEach(async () => { - taskStore.cleanup(); - await client.close(); - await server.close(); - }); - - it('should return empty list when no tasks exist', async () => { - const result = await client.experimental.tasks.listTasks(); - - expect(result.tasks).toEqual([]); - expect(result.nextCursor).toBeUndefined(); - }); - - it('should return all tasks when less than page size', async () => { - // Create 3 tasks - for (let i = 0; i < 3; i++) { - await taskStore.createTask({}, i, { - method: 'tools/call', - params: { name: 'test-tool' } - }); - } - - const result = await client.experimental.tasks.listTasks(); - - expect(result.tasks).toHaveLength(3); - expect(result.nextCursor).toBeUndefined(); - }); - - it('should paginate when more than page size exists', async () => { - // Create 15 tasks (page size is 10 in InMemoryTaskStore) - for (let i = 0; i < 15; i++) { - await taskStore.createTask({}, i, { - method: 'tools/call', - params: { name: 'test-tool' } - }); - } - - // Get first page - const page1 = await client.experimental.tasks.listTasks(); - expect(page1.tasks).toHaveLength(10); - expect(page1.nextCursor).toBeDefined(); - - // Get second page using cursor - const page2 = await client.experimental.tasks.listTasks(page1.nextCursor); - expect(page2.tasks).toHaveLength(5); - expect(page2.nextCursor).toBeUndefined(); - }); - - it('should treat cursor as opaque token', async () => { - // Create 5 tasks - for (let i = 0; i < 5; i++) { - await taskStore.createTask({}, i, { - method: 'tools/call', - params: { name: 'test-tool' } - }); - } - - // Get all tasks to get a valid cursor - const allTasks = taskStore.getAllTasks(); - const validCursor = allTasks[2]!.taskId; - - // Use the cursor - should work even though we don't know its internal structure - const result = await client.experimental.tasks.listTasks(validCursor); - expect(result.tasks).toHaveLength(2); - }); - - it('should return error code -32602 for invalid cursor', async () => { - await taskStore.createTask({}, 1, { - method: 'tools/call', - params: { name: 'test-tool' } - }); - - // Try to use an invalid cursor - should return -32602 (Invalid params) per MCP spec - await expect(client.experimental.tasks.listTasks('invalid-cursor')).rejects.toSatisfy((error: ProtocolError) => { - expect(error).toBeInstanceOf(ProtocolError); - expect(error.code).toBe(ProtocolErrorCode.InvalidParams); - expect(error.message).toContain('Invalid cursor'); - return true; - }); - }); - - it('should ensure tasks accessible via tasks/get are also accessible via tasks/list', async () => { - // Create a task - const task = await taskStore.createTask({}, 1, { - method: 'tools/call', - params: { name: 'test-tool' } - }); - - // Verify it's accessible via tasks/get - const getResult = await client.experimental.tasks.getTask(task.taskId); - expect(getResult.taskId).toBe(task.taskId); - - // Verify it's also accessible via tasks/list - const listResult = await client.experimental.tasks.listTasks(); - expect(listResult.tasks).toHaveLength(1); - expect(listResult.tasks[0]!.taskId).toBe(task.taskId); - }); - - it('should not include related-task metadata in list response', async () => { - // Create a task - await taskStore.createTask({}, 1, { - method: 'tools/call', - params: { name: 'test-tool' } - }); - - const result = await client.experimental.tasks.listTasks(); - - // The response should have _meta but not include related-task metadata - expect(result._meta).toBeDefined(); - expect(result._meta?.['io.modelcontextprotocol/related-task']).toBeUndefined(); - }); -}); diff --git a/test/integration/test/helpers/mcp.ts b/test/integration/test/helpers/mcp.ts deleted file mode 100644 index 1fe0b33912..0000000000 --- a/test/integration/test/helpers/mcp.ts +++ /dev/null @@ -1,70 +0,0 @@ -import { Client } from '@modelcontextprotocol/client'; -import { InMemoryTransport } from '@modelcontextprotocol/core'; -import type { ClientCapabilities, ServerCapabilities } from '@modelcontextprotocol/server'; -import { InMemoryTaskMessageQueue, InMemoryTaskStore, Server } from '@modelcontextprotocol/server'; - -export interface InMemoryTaskEnvironment { - client: Client; - server: Server; - taskStore: InMemoryTaskStore; - clientTransport: InMemoryTransport; - serverTransport: InMemoryTransport; -} - -export async function createInMemoryTaskEnvironment(options?: { - clientCapabilities?: ClientCapabilities; - serverCapabilities?: ServerCapabilities; -}): Promise { - const taskStore = new InMemoryTaskStore(); - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: options?.clientCapabilities ?? { - tasks: { - list: {}, - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: options?.serverCapabilities ?? { - tasks: { - list: {}, - requests: { - tools: { - call: {} - } - }, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() - } - } - } - ); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - return { - client, - server, - taskStore, - clientTransport, - serverTransport - }; -} diff --git a/test/integration/test/processCleanup.test.ts b/test/integration/test/processCleanup.test.ts index 3554d99361..afcbd0f483 100644 --- a/test/integration/test/processCleanup.test.ts +++ b/test/integration/test/processCleanup.test.ts @@ -86,7 +86,8 @@ describe('Process cleanup', () => { const transport = new StdioClientTransport({ command: 'node', args: ['--import', 'tsx', 'serverThatHangs.ts'], - cwd: FIXTURES_DIR + cwd: FIXTURES_DIR, + forceLegacy: true }); await client.connect(transport); diff --git a/test/integration/test/server.test.ts b/test/integration/test/server.test.ts index 825af7ea45..ff573ef6d1 100644 --- a/test/integration/test/server.test.ts +++ b/test/integration/test/server.test.ts @@ -1,29 +1,22 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ import { Client } from '@modelcontextprotocol/client'; import type { - CreateMessageResult, - ElicitRequestSchema, - ElicitResult, JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, LoggingMessageNotification, ResponseMessage, - Task, Transport } from '@modelcontextprotocol/core'; import { - CallToolResultSchema, - ElicitResultSchema, InMemoryTransport, LATEST_PROTOCOL_VERSION, SdkError, SdkErrorCode, - SUPPORTED_PROTOCOL_VERSIONS, - toArrayAsync + SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; import { createMcpExpressApp } from '@modelcontextprotocol/express'; -import { InMemoryTaskStore, McpServer, Server } from '@modelcontextprotocol/server'; +import { McpServer, Server } from '@modelcontextprotocol/server'; import type { Request, Response } from 'express'; import supertest from 'supertest'; import * as z from 'zod/v4'; @@ -1825,249 +1818,6 @@ describe('createMessage validation', () => { }); }); -describe('createMessageStream', () => { - test('should throw when tools are provided without sampling.tools capability', async () => { - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); - - client.setRequestHandler('sampling/createMessage', async () => ({ - role: 'assistant', - content: { type: 'text', text: 'Response' }, - model: 'test-model' - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - expect(() => { - server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100, - tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] - }); - }).toThrow('Client does not support sampling tools capability'); - }); - - test('should throw when tool_result has no matching tool_use in previous message', async () => { - const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); - const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); - - client.setRequestHandler('sampling/createMessage', async () => ({ - role: 'assistant', - content: { type: 'text', text: 'Response' }, - model: 'test-model' - })); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - expect(() => { - server.experimental.tasks.createMessageStream({ - messages: [ - { role: 'user', content: { type: 'text', text: 'Hello' } }, - { - role: 'user', - content: [{ type: 'tool_result', toolUseId: 'test-id', content: [{ type: 'text', text: 'result' }] }] - } - ], - maxTokens: 100 - }); - }).toThrow('tool_result blocks are not matching any tool_use from the previous message'); - }); - - describe('with tasks', () => { - let server: Server; - let client: Client; - let clientTransport: ReturnType[0]; - let serverTransport: ReturnType[1]; - - beforeEach(async () => { - server = new Server( - { name: 'test server', version: '1.0' }, - { - capabilities: { - tasks: { - taskStore: new InMemoryTaskStore() - } - } - } - ); - - client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); - - [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - }); - - afterEach(async () => { - await server.close().catch(() => {}); - await client.close().catch(() => {}); - }); - - describe('terminal message guarantees', () => { - test('should yield exactly one terminal message for successful request', async () => { - client.setRequestHandler('sampling/createMessage', async () => ({ - role: 'assistant', - content: { type: 'text', text: 'Response' }, - model: 'test-model' - })); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100 - }); - - const allMessages = await toArrayAsync(stream); - - expect(allMessages.length).toBe(1); - expect(allMessages[0].type).toBe('result'); - - const taskMessages = allMessages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); - expect(taskMessages.length).toBe(0); - }); - - test('should yield error as terminal message when client returns error', async () => { - client.setRequestHandler('sampling/createMessage', async () => { - throw new Error('Simulated client error'); - }); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100 - }); - - const allMessages = await toArrayAsync(stream); - - expect(allMessages.length).toBe(1); - expect(allMessages[0].type).toBe('error'); - }); - - test('should yield exactly one terminal message with result', async () => { - client.setRequestHandler('sampling/createMessage', () => ({ - model: 'test-model', - role: 'assistant' as const, - content: { type: 'text' as const, text: 'Response' } - })); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], - maxTokens: 100 - }); - - const messages = await toArrayAsync(stream); - const terminalMessages = messages.filter(m => m.type === 'result' || m.type === 'error'); - - expect(terminalMessages.length).toBe(1); - - const lastMessage = messages.at(-1); - expect(lastMessage.type === 'result' || lastMessage.type === 'error').toBe(true); - - if (lastMessage.type === 'result') { - expect((lastMessage.result as CreateMessageResult).content).toBeDefined(); - } - }); - }); - - describe('non-task request minimality', () => { - test('should yield only result message for non-task request', async () => { - client.setRequestHandler('sampling/createMessage', () => ({ - model: 'test-model', - role: 'assistant' as const, - content: { type: 'text' as const, text: 'Response' } - })); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - const stream = server.experimental.tasks.createMessageStream({ - messages: [{ role: 'user', content: { type: 'text', text: 'Message' } }], - maxTokens: 100 - }); - - const messages = await toArrayAsync(stream); - - const taskMessages = messages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); - expect(taskMessages.length).toBe(0); - - const resultMessages = messages.filter(m => m.type === 'result'); - expect(resultMessages.length).toBe(1); - - expect(messages.length).toBe(1); - }); - }); - - describe('task-augmented request handling', () => { - test('should yield taskCreated and result for task-augmented request', async () => { - const clientTaskStore = new InMemoryTaskStore(); - const taskClient = new Client( - { name: 'test client', version: '1.0' }, - { - capabilities: { - sampling: {}, - tasks: { - taskStore: clientTaskStore, - requests: { - sampling: { createMessage: {} } - } - } - } - } - ); - - taskClient.setRequestHandler('sampling/createMessage', async (request, extra) => { - const result = { - model: 'test-model', - role: 'assistant' as const, - content: { type: 'text' as const, text: 'Task response' } - }; - - if (request.params.task && extra.task?.store) { - const task = await extra.task.store.createTask({ ttl: extra.task.requestedTtl }); - await extra.task.store.storeTaskResult(task.taskId, 'completed', result); - return { task }; - } - return result; - }); - - const [taskClientTransport, taskServerTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([taskClient.connect(taskClientTransport), server.connect(taskServerTransport)]); - - const stream = server.experimental.tasks.createMessageStream( - { - messages: [{ role: 'user', content: { type: 'text', text: 'Task-augmented message' } }], - maxTokens: 100 - }, - { task: { ttl: 60_000 } } - ); - - const messages = await toArrayAsync(stream); - - // Should have taskCreated and result - expect(messages.length).toBeGreaterThanOrEqual(2); - - // First message should be taskCreated - expect(messages[0].type).toBe('taskCreated'); - const taskCreated = messages[0] as { type: 'taskCreated'; task: Task }; - expect(taskCreated.task.taskId).toBeDefined(); - - // Last message should be result - const lastMessage = messages.at(-1); - expect(lastMessage.type).toBe('result'); - if (lastMessage.type === 'result') { - expect((lastMessage.result as CreateMessageResult).model).toBe('test-model'); - } - - clientTaskStore.cleanup(); - await taskClient.close().catch(() => {}); - }); - }); - }); -}); - describe('createMessage backwards compatibility', () => { test('createMessage without tools returns single content (backwards compat)', async () => { const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); @@ -2359,1420 +2109,6 @@ describe('createMcpExpressApp', () => { }); }); -describe('Task-based execution', () => { - test('server with TaskStore should handle task-based tool execution', async () => { - const taskStore = new InMemoryTaskStore(); - - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - // Register a tool using registerToolTask - server.experimental.tasks.registerToolTask( - 'test-tool', - { - description: 'A test tool', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - // Simulate some async work - (async () => { - await new Promise(resolve => setTimeout(resolve, 10)); - const result = { - content: [{ type: 'text', text: 'Tool executed successfully!' }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Use callToolStream to create a task and capture the task ID - let taskId: string | undefined; - const stream = client.experimental.tasks.callToolStream( - { name: 'test-tool', arguments: {} }, - { - task: { - ttl: 60_000 - } - } - ); - - for await (const message of stream) { - if (message.type === 'taskCreated') { - taskId = message.task.taskId; - } - } - - expect(taskId).toBeDefined(); - - // Wait for the task to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify we can retrieve the task - const task = await client.experimental.tasks.getTask(taskId!); - expect(task).toBeDefined(); - expect(task.status).toBe('completed'); - - // Verify we can retrieve the result - const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); - expect(result.content).toEqual([{ type: 'text', text: 'Tool executed successfully!' }]); - - // Cleanup - taskStore.cleanup(); - }); - - test('server without TaskStore should reject task-based requests', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {} - } - // No taskStore configured - } - ); - - server.setRequestHandler('tools/call', async request => { - if (request.params.name === 'test-tool') { - return { - content: [{ type: 'text', text: 'Success!' }] - }; - } - throw new Error('Unknown tool'); - }); - - server.setRequestHandler('tools/list', async () => ({ - tools: [ - { - name: 'test-tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: {} - } - } - ] - })); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Try to get a task when server doesn't have TaskStore - // The server will return a "Method not found" error - await expect(client.experimental.tasks.getTask('non-existent')).rejects.toThrow('Method not found'); - }); - - test('should automatically attach related-task metadata to nested requests during tool execution', async () => { - const taskStore = new InMemoryTaskStore(); - - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - // Track the elicitation request to verify related-task metadata - let capturedElicitRequest: z.infer | null = null; - - // Set up client elicitation handler - client.setRequestHandler('elicitation/create', async (request, ctx) => { - let taskId: string | undefined; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const createdTask = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - taskId = createdTask.taskId; - } - - // Capture the request to verify metadata later - capturedElicitRequest = request; - - return { - action: 'accept', - content: { - username: 'test-user' - } - }; - }); - - // Register a tool using registerToolTask that makes a nested elicitation request - server.experimental.tasks.registerToolTask( - 'collect-info', - { - description: 'Collects user info via elicitation', - inputSchema: z.object({}) - }, - { - async createTask(_args, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - // Perform async work that makes a nested request - (async () => { - // During tool execution, make a nested request to the client using ctx.mcpReq.send - const elicitResult = await ctx.mcpReq.send({ - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Please provide your username', - requestedSchema: { - type: 'object', - properties: { - username: { type: 'string' } - }, - required: ['username'] - } - } - }); - - const result = { - content: [ - { - type: 'text', - text: `Collected username: ${elicitResult.action === 'accept' && elicitResult.content ? (elicitResult.content as Record).username : 'none'}` - } - ] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Call tool WITH task creation using callToolStream to capture task ID - let taskId: string | undefined; - const stream = client.experimental.tasks.callToolStream( - { name: 'collect-info', arguments: {} }, - { - task: { - ttl: 60_000 - } - } - ); - - for await (const message of stream) { - if (message.type === 'taskCreated') { - taskId = message.task.taskId; - } - } - - expect(taskId).toBeDefined(); - - // Wait for completion - await new Promise(resolve => setTimeout(resolve, 50)); - - // Verify the nested elicitation request was made (related-task metadata is no longer automatically attached) - expect(capturedElicitRequest).toBeDefined(); - - // Verify tool result was correct - const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); - expect(result.content).toEqual([ - { - type: 'text', - text: 'Collected username: test-user' - } - ]); - - // Cleanup - taskStore.cleanup(); - }); - - describe('Server calling client via elicitation', () => { - let clientTaskStore: InMemoryTaskStore; - - beforeEach(() => { - clientTaskStore = new InMemoryTaskStore(); - }); - - afterEach(() => { - clientTaskStore?.cleanup(); - }); - - test('should create task on client via elicitation', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'server-test-user', confirmed: true } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Server creates task on client via elicitation - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Please provide your username', - requestedSchema: { - type: 'object', - properties: { - username: { type: 'string' }, - confirmed: { type: 'boolean' } - }, - required: ['username'] - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - // Verify task was created - const task = await server.experimental.tasks.getTask(taskId); - expect(task.status).toBe('completed'); - }); - - test('should query task from client using getTask', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { create: {} } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create task - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Provide info', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - // Query task - const task = await server.experimental.tasks.getTask(taskId); - expect(task).toBeDefined(); - expect(task.taskId).toBe(taskId); - expect(task.status).toBe('completed'); - }); - - test('should query task result from client using getTaskResult', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'result-user', confirmed: true } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { create: {} } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create task - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Provide info', - requestedSchema: { - type: 'object', - properties: { - username: { type: 'string' }, - confirmed: { type: 'boolean' } - } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - // Query result - const result = await server.experimental.tasks.getTaskResult(taskId, ElicitResultSchema); - expect(result.action).toBe('accept'); - expect(result.content).toEqual({ username: 'result-user', confirmed: true }); - }); - - test('should query task list from client using listTasks', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'list-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create multiple tasks - const createdTaskIds: string[] = []; - for (let i = 0; i < 2; i++) { - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Provide info', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure and capture taskId - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - createdTaskIds.push(createTaskResult.task.taskId); - } - - // Query task list - const taskList = await server.experimental.tasks.listTasks(); - expect(taskList.tasks.length).toBeGreaterThanOrEqual(2); - for (const taskId of createdTaskIds) { - expect(taskList.tasks).toContainEqual( - expect.objectContaining({ - taskId, - status: 'completed' - }) - ); - } - }); - }); - - test('should handle multiple concurrent task-based tool calls', async () => { - const taskStore = new InMemoryTaskStore(); - - const server = new McpServer( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - // Register a tool using registerToolTask with variable delay - server.experimental.tasks.registerToolTask( - 'async-tool', - { - description: 'An async test tool', - inputSchema: z.object({ - delay: z.number().optional().default(10), - taskNum: z.number().optional() - }) - }, - { - async createTask({ delay, taskNum }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - - // Simulate async work - (async () => { - await new Promise(resolve => setTimeout(resolve, delay)); - const result = { - content: [{ type: 'text', text: `Completed task ${taskNum || 'unknown'}` }] - }; - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Create multiple tasks concurrently - const pendingRequests = Array.from({ length: 4 }, (_, index) => - client.callTool( - { name: 'async-tool', arguments: { delay: 10 + index * 5, taskNum: index + 1 } }, - { - task: { ttl: 60_000 } - } - ) - ); - - // Wait for all tasks to complete - await Promise.all(pendingRequests); - - // Wait a bit more to ensure all tasks are completed - await new Promise(resolve => setTimeout(resolve, 50)); - - // Get all task IDs from the task list - const taskList = await client.experimental.tasks.listTasks(); - expect(taskList.tasks.length).toBeGreaterThanOrEqual(4); - const taskIds = taskList.tasks.map(t => t.taskId); - - // Verify all tasks completed successfully - for (const [i, taskId] of taskIds.entries()) { - const task = await client.experimental.tasks.getTask(taskId!); - expect(task.status).toBe('completed'); - expect(task.taskId).toBe(taskId!); - - const result = await client.experimental.tasks.getTaskResult(taskId!, CallToolResultSchema); - expect(result.content).toEqual([{ type: 'text', text: `Completed task ${i + 1}` }]); - } - - // Verify listTasks returns all tasks - const finalTaskList = await client.experimental.tasks.listTasks(); - for (const taskId of taskIds) { - expect(finalTaskList.tasks).toContainEqual(expect.objectContaining({ taskId })); - } - - // Cleanup - taskStore.cleanup(); - }); - - describe('Error scenarios', () => { - let taskStore: InMemoryTaskStore; - let clientTaskStore: InMemoryTaskStore; - - beforeEach(() => { - taskStore = new InMemoryTaskStore(); - clientTaskStore = new InMemoryTaskStore(); - }); - - afterEach(() => { - taskStore?.cleanup(); - clientTaskStore?.cleanup(); - }); - - test('should throw error when client queries non-existent task from server', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Try to query a task that doesn't exist - await expect(client.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); - }); - - test('should throw error when server queries non-existent task from client', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async () => ({ - action: 'accept', - content: { username: 'test' } - })); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Try to query a task that doesn't exist on client - await expect(server.experimental.tasks.getTask('non-existent-task')).rejects.toThrow(); - }); - }); -}); - -test('should respect client task capabilities', async () => { - const clientTaskStore = new InMemoryTaskStore(); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - sampling: {}, - elicitation: {}, - tasks: { - requests: { - elicitation: { - create: {} - } - }, - - taskStore: clientTaskStore - } - } - } - ); - - client.setRequestHandler('elicitation/create', async (request, ctx) => { - const result = { - action: 'accept', - content: { username: 'test-user' } - }; - - // Check if task creation is requested - if (request.params.task && ctx.task?.store) { - const task = await ctx.task.store.createTask({ - ttl: ctx.task.requestedTtl - }); - await ctx.task.store.storeTaskResult(task.taskId, 'completed', result); - // Return CreateTaskResult when task creation is requested - return { task }; - } - - // Return ElicitResult for non-task requests - return result; - }); - - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: { - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - }, - enforceStrictCapabilities: true - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Client supports task creation for elicitation/create and task methods - expect(server.getClientCapabilities()).toEqual({ - sampling: {}, - elicitation: { - form: {} - }, - tasks: { - requests: { - elicitation: { - create: {} - } - } - } - }); - - // These should work because client supports tasks - const createTaskResult = await server.request( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'Test', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } } - } - } - }, - { task: { ttl: 60_000 } } - ); - - // Verify CreateTaskResult structure - expect(createTaskResult.task).toBeDefined(); - expect(createTaskResult.task.taskId).toBeDefined(); - const taskId = createTaskResult.task.taskId; - - await expect(server.experimental.tasks.listTasks()).resolves.not.toThrow(); - await expect(server.experimental.tasks.getTask(taskId)).resolves.not.toThrow(); - - // This should throw because client doesn't support task creation for sampling/createMessage - await expect( - server.request( - { - method: 'sampling/createMessage', - params: { - messages: [], - maxTokens: 10 - } - }, - { task: { taskId: 'test-task-2', keepAlive: 60_000 } } - ) - ).rejects.toThrow('Client does not support task creation for sampling/createMessage'); - - clientTaskStore.cleanup(); -}); - -describe('elicitInputStream', () => { - let server: Server; - let client: Client; - let clientTransport: ReturnType[0]; - let serverTransport: ReturnType[1]; - - beforeEach(async () => { - server = new Server( - { name: 'test server', version: '1.0' }, - { - capabilities: { - tasks: { - taskStore: new InMemoryTaskStore() - } - } - } - ); - - client = new Client( - { name: 'test client', version: '1.0' }, - { - capabilities: { - elicitation: { - form: {}, - url: {} - } - } - } - ); - - [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - }); - - afterEach(async () => { - await server.close().catch(() => {}); - await client.close().catch(() => {}); - }); - - test('should throw when client does not support form elicitation', async () => { - // Create client without form elicitation capability - const noFormClient = new Client( - { name: 'test client', version: '1.0' }, - { - capabilities: { - elicitation: { - url: {} - } - } - } - ); - - const [noFormClientTransport, noFormServerTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([noFormClient.connect(noFormClientTransport), server.connect(noFormServerTransport)]); - - expect(() => { - server.experimental.tasks.elicitInputStream({ - mode: 'form', - message: 'Enter data', - requestedSchema: { type: 'object', properties: {} } - }); - }).toThrow('Client does not support form elicitation.'); - - await noFormClient.close().catch(() => {}); - }); - - test('should throw when client does not support url elicitation', async () => { - // Create client without url elicitation capability - const noUrlClient = new Client( - { name: 'test client', version: '1.0' }, - { - capabilities: { - elicitation: { - form: {} - } - } - } - ); - - const [noUrlClientTransport, noUrlServerTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([noUrlClient.connect(noUrlClientTransport), server.connect(noUrlServerTransport)]); - - expect(() => { - server.experimental.tasks.elicitInputStream({ - mode: 'url', - message: 'Open URL', - elicitationId: 'test-123', - url: 'https://example.com/auth' - }); - }).toThrow('Client does not support url elicitation.'); - - await noUrlClient.close().catch(() => {}); - }); - - test('should default to form mode when mode is not specified', async () => { - const requestStreamSpy = vi.spyOn(server.experimental.tasks, 'requestStream'); - - client.setRequestHandler('elicitation/create', () => ({ - action: 'accept', - content: { value: 'test' } - })); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Call without explicit mode - const params = { - message: 'Enter value', - requestedSchema: { - type: 'object' as const, - properties: { value: { type: 'string' as const } } - } - }; - - const stream = server.experimental.tasks.elicitInputStream( - params as Parameters[0] - ); - await toArrayAsync(stream); - - // Verify mode was normalized to 'form' - expect(requestStreamSpy).toHaveBeenCalledWith( - expect.objectContaining({ - method: 'elicitation/create', - params: expect.objectContaining({ mode: 'form' }) - }), - undefined - ); - }); - - test('should yield error as terminal message when client returns error', async () => { - client.setRequestHandler('elicitation/create', () => { - throw new Error('Simulated client error'); - }); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - const stream = server.experimental.tasks.elicitInputStream({ - mode: 'form', - message: 'Enter data', - requestedSchema: { - type: 'object', - properties: { value: { type: 'string' } } - } - }); - - const allMessages = await toArrayAsync(stream); - - expect(allMessages.length).toBe(1); - expect(allMessages[0].type).toBe('error'); - }); - - // For any streaming elicitation request, the AsyncGenerator yields exactly one terminal - // message (either 'result' or 'error') as its final message. - describe('terminal message guarantees', () => { - test.each([ - { action: 'accept' as const, content: { data: 'test-value' } }, - { action: 'decline' as const, content: undefined }, - { action: 'cancel' as const, content: undefined } - ])('should yield exactly one terminal message for action: $action', async ({ action, content }) => { - client.setRequestHandler('elicitation/create', () => ({ - action, - content - })); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - const stream = server.experimental.tasks.elicitInputStream({ - mode: 'form', - message: 'Test message', - requestedSchema: { - type: 'object', - properties: { data: { type: 'string' } } - } - }); - - const messages = await toArrayAsync(stream); - - // Count terminal messages (result or error) - const terminalMessages = messages.filter(m => m.type === 'result' || m.type === 'error'); - - expect(terminalMessages.length).toBe(1); - - // Verify terminal message is the last message - const lastMessage = messages.at(-1); - expect(lastMessage.type === 'result' || lastMessage.type === 'error').toBe(true); - - // Verify result content matches expected action - if (lastMessage.type === 'result') { - expect((lastMessage.result as ElicitResult).action).toBe(action); - } - }); - }); - - // For any non-task elicitation request, the generator yields exactly one 'result' message - // (or 'error' if the request fails), with no 'taskCreated' or 'taskStatus' messages. - describe('non-task request minimality', () => { - test.each([ - { action: 'accept' as const, content: { value: 'test' } }, - { action: 'decline' as const, content: undefined }, - { action: 'cancel' as const, content: undefined } - ])('should yield only result message for non-task request with action: $action', async ({ action, content }) => { - client.setRequestHandler('elicitation/create', () => ({ - action, - content - })); - - await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); - - // Non-task request (no task option) - const stream = server.experimental.tasks.elicitInputStream({ - mode: 'form', - message: 'Non-task request', - requestedSchema: { - type: 'object', - properties: { value: { type: 'string' } } - } - }); - - const messages = await toArrayAsync(stream); - - // Verify no taskCreated or taskStatus messages - const taskMessages = messages.filter(m => m.type === 'taskCreated' || m.type === 'taskStatus'); - expect(taskMessages.length).toBe(0); - - // Verify exactly one result message - const resultMessages = messages.filter(m => m.type === 'result'); - expect(resultMessages.length).toBe(1); - - // Verify total message count is 1 - expect(messages.length).toBe(1); - }); - }); - - // For any task-augmented elicitation request, the generator should yield at least one - // 'taskCreated' message followed by 'taskStatus' messages before yielding the final - // result or error. - describe('task-augmented request handling', () => { - test('should yield taskCreated and result for task-augmented request', async () => { - const clientTaskStore = new InMemoryTaskStore(); - const taskClient = new Client( - { name: 'test client', version: '1.0' }, - { - capabilities: { - elicitation: { form: {} }, - tasks: { - taskStore: clientTaskStore, - requests: { - elicitation: { create: {} } - } - } - } - } - ); - - taskClient.setRequestHandler('elicitation/create', async (request, extra) => { - const result = { - action: 'accept' as const, - content: { username: 'task-user' } - }; - - if (request.params.task && extra.task?.store) { - const task = await extra.task.store.createTask({ ttl: extra.task.requestedTtl }); - await extra.task.store.storeTaskResult(task.taskId, 'completed', result); - return { task }; - } - return result; - }); - - const [taskClientTransport, taskServerTransport] = InMemoryTransport.createLinkedPair(); - await Promise.all([taskClient.connect(taskClientTransport), server.connect(taskServerTransport)]); - - const stream = server.experimental.tasks.elicitInputStream( - { - mode: 'form', - message: 'Task-augmented request', - requestedSchema: { - type: 'object', - properties: { username: { type: 'string' } }, - required: ['username'] - } - }, - { task: { ttl: 60_000 } } - ); - - const messages = await toArrayAsync(stream); - - // Should have taskCreated and result - expect(messages.length).toBeGreaterThanOrEqual(2); - - // First message should be taskCreated - expect(messages[0].type).toBe('taskCreated'); - const taskCreated = messages[0] as { type: 'taskCreated'; task: Task }; - expect(taskCreated.task.taskId).toBeDefined(); - - // Last message should be result - const lastMessage = messages.at(-1); - expect(lastMessage.type).toBe('result'); - if (lastMessage.type === 'result') { - expect((lastMessage.result as ElicitResult).action).toBe('accept'); - expect((lastMessage.result as ElicitResult).content).toEqual({ username: 'task-user' }); - } - - clientTaskStore.cleanup(); - await taskClient.close().catch(() => {}); - }); - }); -}); - describe('Server registerCapabilities with logging', () => { test('registerCapabilities should register logging/setLevel handler', async () => { const server = new Server({ name: 'test-server', version: '1.0.0' }); diff --git a/test/integration/test/server/mcp.test.ts b/test/integration/test/server/mcp.test.ts index 92af09744c..d66b0648c4 100644 --- a/test/integration/test/server/mcp.test.ts +++ b/test/integration/test/server/mcp.test.ts @@ -1,33 +1,10 @@ import { Client } from '@modelcontextprotocol/client'; -import type { CallToolResult, Notification, TextContent } from '@modelcontextprotocol/core'; -import { - getDisplayName, - InMemoryTaskStore, - InMemoryTransport, - ProtocolErrorCode, - UriTemplate, - UrlElicitationRequiredError -} from '@modelcontextprotocol/core'; +import type { Notification, TextContent } from '@modelcontextprotocol/core'; +import { getDisplayName, InMemoryTransport, ProtocolErrorCode, UriTemplate, UrlElicitationRequiredError } from '@modelcontextprotocol/core'; import { completable, McpServer, ResourceTemplate } from '@modelcontextprotocol/server'; import { afterEach, beforeEach, describe, expect, test } from 'vitest'; import * as z from 'zod/v4'; -function createLatch() { - let latch = false; - const waitForLatch = async () => { - while (!latch) { - await new Promise(resolve => setTimeout(resolve, 0)); - } - }; - - return { - releaseLatch: () => { - latch = true; - }, - waitForLatch - }; -} - describe('Zod v4', () => { describe('McpServer', () => { /*** @@ -2019,146 +1996,6 @@ describe('Zod v4', () => { expect(result.tools[0]!._meta).toBeUndefined(); }); - test('should include execution field in listTools response when tool has execution settings', async () => { - const taskStore = new InMemoryTaskStore(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client({ - name: 'test client', - version: '1.0' - }); - - // Register a tool with execution.taskSupport - mcpServer.experimental.tasks.registerToolTask( - 'task-tool', - { - description: 'A tool with task support', - inputSchema: z.object({ input: z.string() }), - execution: { - taskSupport: 'required' - } - }, - { - createTask: async (_args, ctx) => { - const task = await ctx.task.store.createTask({ ttl: 60_000 }); - return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) throw new Error('Task not found'); - return task; - }, - getTaskResult: async (_args, ctx) => { - return (await ctx.task.store.getTaskResult(ctx.task.id)) as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - const result = await client.request({ method: 'tools/list' }); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0]!.name).toBe('task-tool'); - expect(result.tools[0]!.execution).toEqual({ - taskSupport: 'required' - }); - - taskStore.cleanup(); - }); - - test('should include execution field with taskSupport optional in listTools response', async () => { - const taskStore = new InMemoryTaskStore(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client({ - name: 'test client', - version: '1.0' - }); - - // Register a tool with execution.taskSupport optional - mcpServer.experimental.tasks.registerToolTask( - 'optional-task-tool', - { - description: 'A tool with optional task support', - inputSchema: z.object({ input: z.string() }), - execution: { - taskSupport: 'optional' - } - }, - { - createTask: async (_args, ctx) => { - const task = await ctx.task.store.createTask({ ttl: 60_000 }); - return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) throw new Error('Task not found'); - return task; - }, - getTaskResult: async (_args, ctx) => { - return (await ctx.task.store.getTaskResult(ctx.task.id)) as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - const result = await client.request({ method: 'tools/list' }); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0]!.name).toBe('optional-task-tool'); - expect(result.tools[0]!.execution).toEqual({ - taskSupport: 'optional' - }); - - taskStore.cleanup(); - }); - test('should validate tool names according to SEP specification', () => { // Create a new server instance for this test const testServer = new McpServer({ @@ -6444,599 +6281,4 @@ describe('Zod v4', () => { ); }); }); - - describe('Tool-level task hints with automatic polling wrapper', () => { - test('should return error for tool with taskSupport "required" called without task augmentation', async () => { - const taskStore = new InMemoryTaskStore(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test client', - version: '1.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - // Register a task-based tool with taskSupport "required" - mcpServer.experimental.tasks.registerToolTask( - 'long-running-task', - { - description: 'A long running task', - inputSchema: z.object({ - input: z.string() - }), - execution: { - taskSupport: 'required' - } - }, - { - createTask: async ({ input }, ctx) => { - const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); - - // Capture taskStore for use in setTimeout - const store = ctx.task.store; - - // Simulate async work - setTimeout(async () => { - await store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text' as const, text: `Processed: ${input}` }] - }); - }, 200); - - return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_input, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool WITHOUT task augmentation - should return error - const result = await client.callTool({ - name: 'long-running-task', - arguments: { input: 'test data' } - }); - - // Should receive error result - expect(result.isError).toBe(true); - const content = result.content as TextContent[]; - expect(content[0]!.text).toContain('requires task augmentation'); - - taskStore.cleanup(); - }); - - test('should automatically poll and return CallToolResult for tool with taskSupport "optional" called without task augmentation', async () => { - const taskStore = new InMemoryTaskStore(); - const { releaseLatch, waitForLatch } = createLatch(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test client', - version: '1.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - // Register a task-based tool with taskSupport "optional" - mcpServer.experimental.tasks.registerToolTask( - 'optional-task', - { - description: 'An optional task', - inputSchema: z.object({ - value: z.number() - }), - execution: { - taskSupport: 'optional' - } - }, - { - createTask: async ({ value }, ctx) => { - const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); - - // Capture taskStore for use in setTimeout - const store = ctx.task.store; - - // Simulate async work - setTimeout(async () => { - await store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text' as const, text: `Result: ${value * 2}` }] - }); - releaseLatch(); - }, 150); - - return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_value, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool WITHOUT task augmentation - const result = await client.callTool({ - name: 'optional-task', - arguments: { value: 21 } - }); - - // Should receive CallToolResult directly, not CreateTaskResult - expect(result).toHaveProperty('content'); - expect(result.content).toEqual([{ type: 'text' as const, text: 'Result: 42' }]); - expect(result).not.toHaveProperty('task'); - - // Wait for async operations to complete - await waitForLatch(); - taskStore.cleanup(); - }); - - test('should return CreateTaskResult when tool with taskSupport "required" is called WITH task augmentation', async () => { - const taskStore = new InMemoryTaskStore(); - const { releaseLatch, waitForLatch } = createLatch(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test client', - version: '1.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - // Register a task-based tool with taskSupport "required" - mcpServer.experimental.tasks.registerToolTask( - 'task-tool', - { - description: 'A task tool', - inputSchema: z.object({ - data: z.string() - }), - execution: { - taskSupport: 'required' - } - }, - { - createTask: async ({ data }, ctx) => { - const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); - - // Capture taskStore for use in setTimeout - const store = ctx.task.store; - - // Simulate async work - setTimeout(async () => { - await store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text' as const, text: `Completed: ${data}` }] - }); - releaseLatch(); - }, 200); - - return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_data, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool WITH task augmentation - const result = await client.request( - { - method: 'tools/call', - params: { - name: 'task-tool', - arguments: { data: 'test' }, - task: { ttl: 60_000 } - } - }, - z.object({ - task: z.object({ - taskId: z.string(), - status: z.string(), - ttl: z.union([z.number(), z.null()]), - createdAt: z.string(), - pollInterval: z.number().optional() - }) - }) - ); - - // Should receive CreateTaskResult with task field - expect(result).toHaveProperty('task'); - expect(result.task).toHaveProperty('taskId'); - expect(result.task.status).toBe('working'); - - // Wait for async operations to complete - await waitForLatch(); - taskStore.cleanup(); - }); - - test('should handle task failures during automatic polling', async () => { - const taskStore = new InMemoryTaskStore(); - const { releaseLatch, waitForLatch } = createLatch(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test client', - version: '1.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - // Register a task-based tool that fails - mcpServer.experimental.tasks.registerToolTask( - 'failing-task', - { - description: 'A failing task', - execution: { - taskSupport: 'optional' - } - }, - { - createTask: async ctx => { - const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); - - // Capture taskStore for use in setTimeout - const store = ctx.task.store; - - // Simulate async failure - setTimeout(async () => { - await store.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text' as const, text: 'Error occurred' }], - isError: true - }); - releaseLatch(); - }, 150); - - return { task }; - }, - getTask: async ctx => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async ctx => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool WITHOUT task augmentation - const result = await client.callTool({ - name: 'failing-task', - arguments: {} - }); - - // Should receive the error result - expect(result).toHaveProperty('content'); - expect(result.content).toEqual([{ type: 'text' as const, text: 'Error occurred' }]); - expect(result.isError).toBe(true); - - // Wait for async operations to complete - await waitForLatch(); - taskStore.cleanup(); - }); - - test('should handle task cancellation during automatic polling', async () => { - const taskStore = new InMemoryTaskStore(); - const { releaseLatch, waitForLatch } = createLatch(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - const client = new Client( - { - name: 'test client', - version: '1.0' - }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - } - } - } - } - ); - - // Register a task-based tool that gets cancelled - mcpServer.experimental.tasks.registerToolTask( - 'cancelled-task', - { - description: 'A task that gets cancelled', - execution: { - taskSupport: 'optional' - } - }, - { - createTask: async ctx => { - const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); - - // Capture taskStore for use in setTimeout - const store = ctx.task.store; - - // Simulate async cancellation - setTimeout(async () => { - await store.updateTaskStatus(task.taskId, 'cancelled', 'Task was cancelled'); - releaseLatch(); - }, 150); - - return { task }; - }, - getTask: async ctx => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async ctx => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); - - // Call the tool WITHOUT task augmentation - const result = await client.callTool({ - name: 'cancelled-task', - arguments: {} - }); - - // Should receive an error since cancelled tasks don't have results - expect(result).toHaveProperty('content'); - expect(result.content).toEqual([{ type: 'text' as const, text: expect.stringContaining('has no result stored') }]); - - // Wait for async operations to complete - await waitForLatch(); - taskStore.cleanup(); - }); - - test('should raise error when registerToolTask is called with taskSupport "forbidden"', () => { - const taskStore = new InMemoryTaskStore(); - - const mcpServer = new McpServer( - { - name: 'test server', - version: '1.0' - }, - { - capabilities: { - tools: {}, - tasks: { - requests: { - tools: { - call: {} - } - }, - - taskStore - } - } - } - ); - - // Attempt to register a task-based tool with taskSupport "forbidden" (cast to bypass type checking) - expect(() => { - mcpServer.experimental.tasks.registerToolTask( - 'invalid-task', - { - description: 'A task with forbidden support', - inputSchema: z.object({ - input: z.string() - }), - execution: { - taskSupport: 'forbidden' as unknown as 'required' - } - }, - { - createTask: async (_args, ctx) => { - const task = await ctx.task.store.createTask({ ttl: 60_000, pollInterval: 100 }); - return { task }; - }, - getTask: async (_args, ctx) => { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error('Task not found'); - } - return task; - }, - getTaskResult: async (_args, ctx) => { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as CallToolResult; - } - } - ); - }).toThrow(); - - taskStore.cleanup(); - }); - }); }); diff --git a/test/integration/test/stateManagementStreamableHttp.test.ts b/test/integration/test/stateManagementStreamableHttp.test.ts index 3f32c64a34..2a475d9690 100644 --- a/test/integration/test/stateManagementStreamableHttp.test.ts +++ b/test/integration/test/stateManagementStreamableHttp.test.ts @@ -212,7 +212,7 @@ describe('Zod v4', () => { version: '1.0.0' }); - const transport = new StreamableHTTPClientTransport(baseUrl); + const transport = new StreamableHTTPClientTransport(baseUrl, { forceLegacy: true }); // Verify protocol version is not set before connecting expect(transport.protocolVersion).toBeUndefined(); @@ -255,7 +255,7 @@ describe('Zod v4', () => { version: '1.0.0' }); - const transport = new StreamableHTTPClientTransport(baseUrl); + const transport = new StreamableHTTPClientTransport(baseUrl, { forceLegacy: true }); await client.connect(transport); // Verify that a session ID was set diff --git a/test/integration/test/taskLifecycle.test.ts b/test/integration/test/taskLifecycle.test.ts deleted file mode 100644 index 1a540df0fd..0000000000 --- a/test/integration/test/taskLifecycle.test.ts +++ /dev/null @@ -1,1625 +0,0 @@ -import { randomUUID } from 'node:crypto'; -import type { Server } from 'node:http'; -import { createServer } from 'node:http'; - -import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; -import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import type { TaskRequestOptions } from '@modelcontextprotocol/server'; -import { - InMemoryTaskMessageQueue, - InMemoryTaskStore, - McpServer, - ProtocolError, - ProtocolErrorCode, - RELATED_TASK_META_KEY -} from '@modelcontextprotocol/server'; -import { listenOnRandomPort, waitForTaskStatus } from '@modelcontextprotocol/test-helpers'; -import * as z from 'zod/v4'; - -describe('Task Lifecycle Integration Tests', () => { - let server: Server; - let mcpServer: McpServer; - let serverTransport: NodeStreamableHTTPServerTransport; - let baseUrl: URL; - let taskStore: InMemoryTaskStore; - - beforeEach(async () => { - // Create task store - taskStore = new InMemoryTaskStore(); - - // Create MCP server with task support - mcpServer = new McpServer( - { name: 'test-server', version: '1.0.0' }, - { - capabilities: { - tasks: { - requests: { - tools: { - call: {} - } - }, - list: {}, - cancel: {}, - taskStore, - taskMessageQueue: new InMemoryTaskMessageQueue() - } - } - } - ); - - // Register a long-running tool using registerToolTask - mcpServer.experimental.tasks.registerToolTask( - 'long-task', - { - title: 'Long Running Task', - description: 'A tool that takes time to complete', - inputSchema: z.object({ - duration: z.number().describe('Duration in milliseconds').default(1000), - shouldFail: z.boolean().describe('Whether the task should fail').default(false) - }) - }, - { - async createTask({ duration, shouldFail }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: 60_000, - pollInterval: 100 - }); - - // Simulate async work - (async () => { - await new Promise(resolve => setTimeout(resolve, duration)); - - try { - await (shouldFail - ? ctx.task.store.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text', text: 'Task failed as requested' }], - isError: true - }) - : ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Completed after ${duration}ms` }] - })); - } catch { - // Task may have been cleaned up if test ended - } - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - // Register a tool that requires input via elicitation - mcpServer.experimental.tasks.registerToolTask( - 'input-task', - { - title: 'Input Required Task', - description: 'A tool that requires user input', - inputSchema: z.object({ - userName: z.string().describe('User name').optional() - }) - }, - { - async createTask({ userName }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: 60_000, - pollInterval: 100 - }); - - // Perform async work that requires elicitation - (async () => { - await new Promise(resolve => setTimeout(resolve, 100)); - - // If userName not provided, request it via elicitation - if (userName) { - // Complete immediately if userName was provided - try { - await ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Hello, ${userName}!` }] - }); - } catch { - // Task may have been cleaned up if test ended - } - } else { - const elicitationResult = await ctx.mcpReq.send( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: 'What is your name?', - requestedSchema: { - type: 'object', - properties: { - userName: { type: 'string' } - }, - required: ['userName'] - } - } - }, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ); - - // Complete with the elicited name - const name = - elicitationResult.action === 'accept' && elicitationResult.content - ? elicitationResult.content.userName - : 'Unknown'; - try { - await ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Hello, ${name}!` }] - }); - } catch { - // Task may have been cleaned up if test ended - } - } - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - // Create transport - serverTransport = new NodeStreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID() - }); - - await mcpServer.connect(serverTransport); - - // Create HTTP server - server = createServer(async (req, res) => { - await serverTransport.handleRequest(req, res); - }); - - // Start server - baseUrl = await listenOnRandomPort(server); - }); - - afterEach(async () => { - taskStore.cleanup(); - await mcpServer.close().catch(() => {}); - await serverTransport.close().catch(() => {}); - server.close(); - }); - - describe('Task Creation and Completion', () => { - it('should create a task and return CreateTaskResult', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 500, - shouldFail: false - }, - task: { - ttl: 60_000 - } - } - }); - - // Verify CreateTaskResult structure - expect(createResult).toHaveProperty('task'); - expect(createResult.task).toHaveProperty('taskId'); - expect(createResult.task.status).toBe('working'); - expect(createResult.task.ttl).toBe(60_000); - expect(createResult.task.createdAt).toBeDefined(); - expect(createResult.task.pollInterval).toBe(100); - - // Verify task is stored in taskStore - const taskId = createResult.task.taskId; - const storedTask = await taskStore.getTask(taskId); - expect(storedTask).toBeDefined(); - expect(storedTask?.taskId).toBe(taskId); - expect(storedTask?.status).toBe('working'); - - // Wait for completion - const completedTask = await waitForTaskStatus(id => taskStore.getTask(id), taskId, 'completed'); - - // Verify task completed - expect(completedTask.status).toBe('completed'); - - // Verify result is stored - const result = await taskStore.getTaskResult(taskId); - expect(result).toBeDefined(); - expect(result.content).toEqual([{ type: 'text', text: 'Completed after 500ms' }]); - - await transport.close(); - }); - - it('should handle task failure correctly', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task that will fail - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 300, - shouldFail: true - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Wait for failure - const task = await waitForTaskStatus(id => taskStore.getTask(id), taskId, 'failed'); - - // Verify task failed - expect(task.status).toBe('failed'); - - // Verify error result is stored - const result = await taskStore.getTaskResult(taskId); - expect(result.content).toEqual([{ type: 'text', text: 'Task failed as requested' }]); - expect(result.isError).toBe(true); - - await transport.close(); - }); - }); - - describe('Task Cancellation', () => { - it('should cancel a working task and return the cancelled task', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: {} } - } - ); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a long-running task - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 5000 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Verify task is working - let task = await taskStore.getTask(taskId); - expect(task?.status).toBe('working'); - - // Cancel the task via client.experimental.tasks.cancelTask - per spec, returns Result & Task - const cancelResult = await client.experimental.tasks.cancelTask(taskId); - - // Verify the cancel response includes the cancelled task (per MCP spec CancelTaskResult is Result & Task) - expect(cancelResult.taskId).toBe(taskId); - expect(cancelResult.status).toBe('cancelled'); - expect(cancelResult.createdAt).toBeDefined(); - expect(cancelResult.lastUpdatedAt).toBeDefined(); - expect(cancelResult.ttl).toBeDefined(); - - // Verify task is cancelled in store as well - task = await taskStore.getTask(taskId); - expect(task?.status).toBe('cancelled'); - - await transport.close(); - }); - - it('should reject cancellation of completed task with error code -32602', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: {} } - } - ); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a quick task - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 100 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Wait for completion - const task = await waitForTaskStatus(id => taskStore.getTask(id), taskId, 'completed'); - - // Verify task is completed - expect(task.status).toBe('completed'); - - // Try to cancel via tasks/cancel request (should fail with -32602) - await expect(client.experimental.tasks.cancelTask(taskId)).rejects.toSatisfy((error: ProtocolError) => { - expect(error).toBeInstanceOf(ProtocolError); - expect(error.code).toBe(ProtocolErrorCode.InvalidParams); - expect(error.message).toContain('Cannot cancel task in terminal status'); - return true; - }); - - await transport.close(); - }); - }); - - describe('Multiple Queued Messages', () => { - it('should deliver multiple queued messages in order', async () => { - // Register a tool that sends multiple server requests during execution - mcpServer.experimental.tasks.registerToolTask( - 'multi-request-task', - { - title: 'Multi Request Task', - description: 'A tool that sends multiple server requests', - inputSchema: z.object({ - requestCount: z.number().describe('Number of requests to send').default(3) - }) - }, - { - async createTask({ requestCount }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: 60_000, - pollInterval: 100 - }); - - // Perform async work that sends multiple requests - (async () => { - await new Promise(resolve => setTimeout(resolve, 100)); - - const responses: string[] = []; - - // Send multiple elicitation requests - for (let i = 0; i < requestCount; i++) { - const elicitationResult = await ctx.mcpReq.send( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Request ${i + 1} of ${requestCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] - } - } - }, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ); - - if (elicitationResult.action === 'accept' && elicitationResult.content) { - responses.push(elicitationResult.content.response as string); - } - } - - // Complete with all responses - try { - await ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Received responses: ${responses.join(', ')}` }] - }); - } catch { - // Task may have been cleaned up if test ended - } - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {} - } - } - ); - - const receivedMessages: Array<{ method: string; message: string }> = []; - - // Set up elicitation handler on client to track message order - client.setRequestHandler('elicitation/create', async request => { - // Track the message - receivedMessages.push({ - method: request.method, - message: request.params.message - }); - - // Extract the request number from the message - const match = request.params.message.match(/Request (\d+) of (\d+)/); - const requestNum = match ? match[1] : 'unknown'; - - // Respond with the request number - return { - action: 'accept' as const, - content: { - response: `Response ${requestNum}` - } - }; - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task that will send 3 requests - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'multi-request-task', - arguments: { - requestCount: 3 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Wait for messages to be queued - await new Promise(resolve => setTimeout(resolve, 200)); - - // Call tasks/result to receive all queued messages - // This should deliver all 3 elicitation requests in order - const result = await client.request({ - method: 'tasks/result', - params: { taskId } - }); - - // Verify all messages were delivered in order - expect(receivedMessages.length).toBe(3); - expect(receivedMessages[0]!.message).toBe('Request 1 of 3'); - expect(receivedMessages[1]!.message).toBe('Request 2 of 3'); - expect(receivedMessages[2]!.message).toBe('Request 3 of 3'); - - // Verify final result includes all responses - expect(result.content).toEqual([{ type: 'text', text: 'Received responses: Response 1, Response 2, Response 3' }]); - - // Verify task is completed - const task = await client.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(task.status).toBe('completed'); - - await transport.close(); - }, 10_000); - }); - - describe('Input Required Flow', () => { - it('should handle elicitation during tool execution', async () => { - // Complete flow phases: - // 1. Client creates task - // 2. Server queues elicitation request and sets status to input_required - // 3. Client polls tasks/get, sees input_required status - // 4. Client calls tasks/result to dequeue elicitation request - // 5. Client responds to elicitation - // 6. Server receives response, completes task - // 7. Client receives final result - - const elicitClient = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {} - } - } - ); - - // Track elicitation request receipt - let elicitationReceived = false; - let elicitationRequestMeta: Record | undefined; - - // Set up elicitation handler on client - elicitClient.setRequestHandler('elicitation/create', async request => { - elicitationReceived = true; - elicitationRequestMeta = request.params._meta; - - return { - action: 'accept' as const, - content: { - userName: 'TestUser' - } - }; - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await elicitClient.connect(transport); - - // Phase 1: Create task - const createResult = await elicitClient.request({ - method: 'tools/call', - params: { - name: 'input-task', - arguments: {}, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - expect(createResult.task.status).toBe('working'); - - // Phase 2: Wait for server to queue elicitation and update status - const task = await waitForTaskStatus( - id => - elicitClient.request({ - method: 'tasks/get', - params: { taskId: id } - }), - taskId, - 'input_required', - { - intervalMs: createResult.task.pollInterval ?? 100 - } - ); - - // Verify we saw input_required status (not completed or failed) - expect(task.status).toBe('input_required'); - - // Phase 3: Call tasks/result to dequeue messages and get final result - // This should: - // - Deliver the queued elicitation request via SSE - // - Client handler responds - // - Server receives response, completes task - // - Return final result - const result = await elicitClient.request({ - method: 'tasks/result', - params: { taskId } - }); - - // Verify elicitation was received and processed - expect(elicitationReceived).toBe(true); - - // Verify the elicitation request had related-task metadata - expect(elicitationRequestMeta).toBeDefined(); - expect(elicitationRequestMeta?.[RELATED_TASK_META_KEY]).toEqual({ taskId }); - - // Verify final result - expect(result.content).toEqual([{ type: 'text', text: 'Hello, TestUser!' }]); - - // Verify task is now completed - const finalTask = await elicitClient.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(finalTask.status).toBe('completed'); - - await transport.close(); - }, 15_000); - }); - - describe('Task Listing and Pagination', () => { - it('should list tasks', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create multiple tasks - const taskIds: string[] = []; - for (let i = 0; i < 3; i++) { - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 1000 - }, - task: { - ttl: 60_000 - } - } - }); - taskIds.push(createResult.task.taskId); - } - - // List tasks using taskStore - const listResult = await taskStore.listTasks(); - - expect(listResult.tasks.length).toBeGreaterThanOrEqual(3); - expect(listResult.tasks.some(t => taskIds.includes(t.taskId))).toBe(true); - - await transport.close(); - }); - - it('should handle pagination with large datasets', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create 15 tasks (more than page size of 10) - for (let i = 0; i < 15; i++) { - await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 5000 - }, - task: { - ttl: 60_000 - } - } - }); - } - - // Get first page using taskStore - const page1 = await taskStore.listTasks(); - - expect(page1.tasks.length).toBe(10); - expect(page1.nextCursor).toBeDefined(); - - // Get second page - const page2 = await taskStore.listTasks(page1.nextCursor); - - expect(page2.tasks.length).toBeGreaterThanOrEqual(5); - - await transport.close(); - }); - }); - - describe('Error Handling', () => { - it('should return error code -32602 for non-existent task in tasks/get', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: {} } - } - ); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Try to get non-existent task via tasks/get request - await expect(client.experimental.tasks.getTask('non-existent-task-id')).rejects.toSatisfy((error: ProtocolError) => { - expect(error).toBeInstanceOf(ProtocolError); - expect(error.code).toBe(ProtocolErrorCode.InvalidParams); - expect(error.message).toContain('Task not found'); - return true; - }); - - await transport.close(); - }); - - it('should return error code -32602 for non-existent task in tasks/cancel', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: {} } - } - ); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Try to cancel non-existent task via tasks/cancel request - await expect(client.experimental.tasks.cancelTask('non-existent-task-id')).rejects.toSatisfy((error: ProtocolError) => { - expect(error).toBeInstanceOf(ProtocolError); - expect(error.code).toBe(ProtocolErrorCode.InvalidParams); - expect(error.message).toContain('Task not found'); - return true; - }); - - await transport.close(); - }); - - it('should return error code -32602 for non-existent task in tasks/result', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Try to get result of non-existent task via tasks/result request - await expect( - client.request({ - method: 'tasks/result', - params: { taskId: 'non-existent-task-id' } - }) - ).rejects.toSatisfy((error: ProtocolError) => { - expect(error).toBeInstanceOf(ProtocolError); - expect(error.code).toBe(ProtocolErrorCode.InvalidParams); - expect(error.message).toContain('Task not found'); - return true; - }); - - await transport.close(); - }); - }); - - describe('TTL and Cleanup', () => { - it('should respect TTL in task creation', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task with specific TTL - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 100 - }, - task: { - ttl: 5000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Verify TTL is set correctly - expect(createResult.task.ttl).toBe(60_000); // The task store uses 60000 as default - - // Task should exist - const task = await client.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(task).toBeDefined(); - expect(task.ttl).toBe(60_000); - - await transport.close(); - }); - }); - - describe('Task Cancellation with Queued Messages', () => { - it('should clear queue and deliver no messages when task is cancelled before tasks/result', async () => { - // Register a tool that queues messages but doesn't complete immediately - mcpServer.experimental.tasks.registerToolTask( - 'cancellable-task', - { - title: 'Cancellable Task', - description: 'A tool that queues messages and can be cancelled', - inputSchema: z.object({ - messageCount: z.number().describe('Number of messages to queue').default(2) - }) - }, - { - async createTask({ messageCount }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: 60_000, - pollInterval: 100 - }); - - // Perform async work that queues messages - (async () => { - try { - await new Promise(resolve => setTimeout(resolve, 100)); - - // Queue multiple elicitation requests - for (let i = 0; i < messageCount; i++) { - // Send request but don't await - let it queue - ctx.mcpReq - .send( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Message ${i + 1} of ${messageCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] - } - } - }, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ) - .catch(() => { - // Ignore errors from cancelled requests - }); - } - - // Don't complete - let the task be cancelled - // Wait indefinitely (or until cancelled) - await new Promise(() => {}); - } catch { - // Ignore errors - task was cancelled - } - })().catch(() => { - // Catch any unhandled errors from the async execution - }); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {} - } - } - ); - - let elicitationCallCount = 0; - - // Set up elicitation handler to track if any messages are delivered - client.setRequestHandler('elicitation/create', async () => { - elicitationCallCount++; - return { - action: 'accept' as const, - content: { - response: 'Should not be called' - } - }; - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task that will queue messages - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'cancellable-task', - arguments: { - messageCount: 2 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Wait for messages to be queued - await new Promise(resolve => setTimeout(resolve, 200)); - - // Verify task is in input_required state and messages are queued - let task = await client.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(task.status).toBe('input_required'); - - // Cancel the task before calling tasks/result using the proper tasks/cancel request - // This will trigger queue cleanup via _clearTaskQueue in the handler - await client.request({ - method: 'tasks/cancel', - params: { taskId } - }); - - // Verify task is cancelled - task = await client.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(task.status).toBe('cancelled'); - - // Attempt to call tasks/result - // When a task is cancelled, the system needs to clear the message queue - // and reject any pending message delivery promises, meaning no further - // messages should be delivered for a cancelled task. - try { - await client.request({ - method: 'tasks/result', - params: { taskId } - }); - } catch { - // tasks/result might throw an error for cancelled tasks without a result - // This is acceptable behavior - } - - // Verify no elicitation messages were delivered, as the queue should be cleared immediately on cancellation - expect(elicitationCallCount).toBe(0); - - // Verify queue remains cleared on subsequent calls - try { - await client.request({ - method: 'tasks/result', - params: { taskId } - }); - } catch { - // Expected - task is cancelled - } - - // Still no messages should have been delivered - expect(elicitationCallCount).toBe(0); - - await transport.close(); - }, 10_000); - }); - - describe('Continuous Message Delivery', () => { - it('should deliver messages immediately while tasks/result is blocking', async () => { - // Register a tool that queues messages over time - mcpServer.experimental.tasks.registerToolTask( - 'streaming-task', - { - title: 'Streaming Task', - description: 'A tool that sends messages over time', - inputSchema: z.object({ - messageCount: z.number().describe('Number of messages to send').default(3), - delayBetweenMessages: z.number().describe('Delay between messages in ms').default(200) - }) - }, - { - async createTask({ messageCount, delayBetweenMessages }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: 60_000, - pollInterval: 100 - }); - - // Perform async work that sends messages over time - (async () => { - try { - // Wait a bit before starting to send messages - await new Promise(resolve => setTimeout(resolve, 100)); - - const responses: string[] = []; - - // Send messages with delays between them - for (let i = 0; i < messageCount; i++) { - const elicitationResult = await ctx.mcpReq.send( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Streaming message ${i + 1} of ${messageCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] - } - } - }, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ); - - if (elicitationResult.action === 'accept' && elicitationResult.content) { - responses.push(elicitationResult.content.response as string); - } - - // Wait before sending next message (if not the last one) - if (i < messageCount - 1) { - await new Promise(resolve => setTimeout(resolve, delayBetweenMessages)); - } - } - - // Complete with all responses - try { - await ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: `Received all responses: ${responses.join(', ')}` }] - }); - } catch { - // Task may have been cleaned up if test ended - } - } catch (error) { - // Handle errors - try { - await ctx.task.store.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text', text: `Error: ${error}` }], - isError: true - }); - } catch { - // Task may have been cleaned up if test ended - } - } - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {} - } - } - ); - - const receivedMessages: Array<{ message: string; timestamp: number }> = []; - let tasksResultStartTime = 0; - - // Set up elicitation handler to track when messages arrive - client.setRequestHandler('elicitation/create', async request => { - const timestamp = Date.now(); - receivedMessages.push({ - message: request.params.message, - timestamp - }); - - // Extract the message number - const match = request.params.message.match(/Streaming message (\d+) of (\d+)/); - const messageNum = match ? match[1] : 'unknown'; - - // Respond immediately - return { - action: 'accept' as const, - content: { - response: `Response ${messageNum}` - } - }; - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task that will send messages over time - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'streaming-task', - arguments: { - messageCount: 3, - delayBetweenMessages: 300 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Verify task is in working status - let task = await client.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(task.status).toBe('working'); - - // Call tasks/result immediately (before messages are queued) - // This should block and deliver messages as they arrive - tasksResultStartTime = Date.now(); - const resultPromise = client.request({ - method: 'tasks/result', - params: { taskId } - }); - - // Wait for the task to complete and get the result - const result = await resultPromise; - - // Verify all 3 messages were delivered - expect(receivedMessages.length).toBe(3); - expect(receivedMessages[0]!.message).toBe('Streaming message 1 of 3'); - expect(receivedMessages[1]!.message).toBe('Streaming message 2 of 3'); - expect(receivedMessages[2]!.message).toBe('Streaming message 3 of 3'); - - // Verify messages were delivered over time (not all at once) - // The delay between messages should be approximately 300ms - const timeBetweenFirstAndSecond = receivedMessages[1]!.timestamp - receivedMessages[0]!.timestamp; - const timeBetweenSecondAndThird = receivedMessages[2]!.timestamp - receivedMessages[1]!.timestamp; - - // Allow some tolerance for timing (messages should be at least 200ms apart) - expect(timeBetweenFirstAndSecond).toBeGreaterThan(200); - expect(timeBetweenSecondAndThird).toBeGreaterThan(200); - - // Verify messages were delivered while tasks/result was blocking - // (all messages should arrive after tasks/result was called) - for (const msg of receivedMessages) { - expect(msg.timestamp).toBeGreaterThanOrEqual(tasksResultStartTime); - } - - // Verify final result is correct - expect(result.content).toEqual([{ type: 'text', text: 'Received all responses: Response 1, Response 2, Response 3' }]); - - // Verify task is now completed - task = await client.request({ - method: 'tasks/get', - params: { taskId } - }); - expect(task.status).toBe('completed'); - - await transport.close(); - }, 15_000); // Increase timeout to 15 seconds to allow for message delays - }); - - describe('Terminal Task with Queued Messages', () => { - it('should deliver queued messages followed by final result for terminal task', async () => { - // Register a tool that completes quickly and queues messages before completion - mcpServer.experimental.tasks.registerToolTask( - 'quick-complete-task', - { - title: 'Quick Complete Task', - description: 'A tool that queues messages and completes quickly', - inputSchema: z.object({ - messageCount: z.number().describe('Number of messages to queue').default(2) - }) - }, - { - async createTask({ messageCount }, ctx) { - const task = await ctx.task.store.createTask({ - ttl: 60_000, - pollInterval: 100 - }); - - // Perform async work that queues messages and completes quickly - (async () => { - try { - // Queue messages - these will be queued before the task completes - // We await each one starting to ensure they're queued before completing - for (let i = 0; i < messageCount; i++) { - // Start the request but don't wait for response - // The request gets queued when sendRequest is called - ctx.mcpReq - .send( - { - method: 'elicitation/create', - params: { - mode: 'form', - message: `Quick message ${i + 1} of ${messageCount}`, - requestedSchema: { - type: 'object', - properties: { - response: { type: 'string' } - }, - required: ['response'] - } - } - }, - { relatedTask: { taskId: task.taskId } } as unknown as TaskRequestOptions - ) - .catch(() => {}); - // Small delay to ensure message is queued before next iteration - await new Promise(resolve => setTimeout(resolve, 10)); - } - - // Complete the task after all messages are queued - try { - await ctx.task.store.storeTaskResult(task.taskId, 'completed', { - content: [{ type: 'text', text: 'Task completed quickly' }] - }); - } catch { - // Task may have been cleaned up if test ended - } - } catch (error) { - // Handle errors - try { - await ctx.task.store.storeTaskResult(task.taskId, 'failed', { - content: [{ type: 'text', text: `Error: ${error}` }], - isError: true - }); - } catch { - // Task may have been cleaned up if test ended - } - } - })(); - - return { task }; - }, - async getTask(_args, ctx) { - const task = await ctx.task.store.getTask(ctx.task.id); - if (!task) { - throw new Error(`Task ${ctx.task.id} not found`); - } - return task; - }, - async getTaskResult(_args, ctx) { - const result = await ctx.task.store.getTaskResult(ctx.task.id); - return result as { content: Array<{ type: 'text'; text: string }> }; - } - } - ); - - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {} - } - } - ); - - const receivedMessages: Array<{ type: string; message?: string; content?: unknown }> = []; - - // Set up elicitation handler to track message order - client.setRequestHandler('elicitation/create', async request => { - receivedMessages.push({ - type: 'elicitation', - message: request.params.message - }); - - // Extract the message number - const match = request.params.message.match(/Quick message (\d+) of (\d+)/); - const messageNum = match ? match[1] : 'unknown'; - - return { - action: 'accept' as const, - content: { - response: `Response ${messageNum}` - } - }; - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task that will complete quickly with queued messages - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'quick-complete-task', - arguments: { - messageCount: 2 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Wait for task to complete and messages to be queued - const task = await waitForTaskStatus(id => taskStore.getTask(id), taskId, 'completed'); - - // Verify task is in terminal status (completed) - expect(task.status).toBe('completed'); - - // Call tasks/result - should deliver queued messages followed by final result - const result = await client.request({ - method: 'tasks/result', - params: { taskId } - }); - - // Verify all queued messages were delivered before the final result - expect(receivedMessages.length).toBe(2); - expect(receivedMessages[0]!.message).toBe('Quick message 1 of 2'); - expect(receivedMessages[1]!.message).toBe('Quick message 2 of 2'); - - // Verify final result is correct - expect(result.content).toEqual([{ type: 'text', text: 'Task completed quickly' }]); - - // Verify queue is cleaned up - calling tasks/result again should only return the result - receivedMessages.length = 0; // Clear the array - - const result2 = await client.request({ - method: 'tasks/result', - params: { taskId } - }); - - // No messages should be delivered on second call (queue was cleaned up) - expect(receivedMessages.length).toBe(0); - expect(result2.content).toEqual([{ type: 'text', text: 'Task completed quickly' }]); - - await transport.close(); - }, 10_000); - }); - - describe('Concurrent Operations', () => { - it('should handle multiple concurrent task creations', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create multiple tasks concurrently - const promises = Array.from({ length: 5 }, () => - client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 500 - }, - task: { - ttl: 60_000 - } - } - }) - ); - - const results = await Promise.all(promises); - - // Verify all tasks were created with unique IDs - const taskIds = results.map(r => r.task.taskId); - expect(new Set(taskIds).size).toBe(5); - - // Verify all tasks are in working status - for (const result of results) { - expect(result.task.status).toBe('working'); - } - - await transport.close(); - }); - - it('should handle concurrent operations on same task', async () => { - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Create a task - const createResult = await client.request({ - method: 'tools/call', - params: { - name: 'long-task', - arguments: { - duration: 2000 - }, - task: { - ttl: 60_000 - } - } - }); - - const taskId = createResult.task.taskId; - - // Perform multiple concurrent gets - const getPromises = Array.from({ length: 5 }, () => - client.request({ - method: 'tasks/get', - params: { taskId } - }) - ); - - const tasks = await Promise.all(getPromises); - - // All should return the same task - for (const task of tasks) { - expect(task.taskId).toBe(taskId); - expect(task.status).toBe('working'); - } - - await transport.close(); - }); - }); - - describe('callToolStream with failed task', () => { - it('should yield stored result (isError: true) when task fails, not a generic ProtocolError', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { tasks: {} } - } - ); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Use callToolStream with shouldFail: true so the tool stores a failed result - const stream = client.experimental.tasks.callToolStream( - { name: 'long-task', arguments: { duration: 100, shouldFail: true } }, - { task: { ttl: 60_000 } } - ); - - // Collect all stream messages - const messages: Array<{ type: string; task?: unknown; result?: unknown; error?: unknown }> = []; - for await (const message of stream) { - messages.push(message); - } - - // First message should be taskCreated - expect(messages[0]!.type).toBe('taskCreated'); - - // Last message must be 'result' (carrying the stored isError content), - // NOT 'error' (which would mean the generic hardcoded ProtocolError was returned) - const lastMessage = messages.at(-1)!; - expect(lastMessage.type).toBe('result'); - - // The stored result should contain isError: true and the real failure content - const result = lastMessage.result as { content: Array<{ type: string; text: string }>; isError: boolean }; - expect(result.isError).toBe(true); - expect(result.content).toEqual([{ type: 'text', text: 'Task failed as requested' }]); - - await transport.close(); - }, 15_000); - }); - - describe('callToolStream with elicitation', () => { - it('should deliver elicitation via callToolStream and complete task', async () => { - const client = new Client( - { - name: 'test-client', - version: '1.0.0' - }, - { - capabilities: { - elicitation: {}, - tasks: {} - } - } - ); - - // Track elicitation request receipt - let elicitationReceived = false; - let elicitationMessage = ''; - - // Set up elicitation handler on client - client.setRequestHandler('elicitation/create', async request => { - elicitationReceived = true; - elicitationMessage = request.params.message; - - return { - action: 'accept' as const, - content: { - userName: 'StreamUser' - } - }; - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Use callToolStream instead of raw request() - const stream = client.experimental.tasks.callToolStream( - { name: 'input-task', arguments: {} }, - { - task: { ttl: 60_000 } - } - ); - - // Collect all stream messages - const messages: Array<{ type: string; task?: unknown; result?: unknown; error?: unknown }> = []; - for await (const message of stream) { - messages.push(message); - } - - // Verify stream yielded expected message types - expect(messages.length).toBeGreaterThanOrEqual(2); - - // First message should be taskCreated - expect(messages[0]!.type).toBe('taskCreated'); - expect(messages[0]!.task).toBeDefined(); - - // Should have a taskStatus message - const statusMessages = messages.filter(m => m.type === 'taskStatus'); - expect(statusMessages.length).toBeGreaterThanOrEqual(1); - - // Last message should be result - const lastMessage = messages.at(-1)!; - expect(lastMessage.type).toBe('result'); - expect(lastMessage.result).toBeDefined(); - - // Verify elicitation was received and processed - expect(elicitationReceived).toBe(true); - expect(elicitationMessage).toContain('What is your name?'); - - // Verify result content - const result = lastMessage.result as { content: Array<{ type: string; text: string }> }; - expect(result.content).toEqual([{ type: 'text', text: 'Hello, StreamUser!' }]); - - await transport.close(); - }, 15_000); - }); -}); diff --git a/test/integration/test/taskResumability.test.ts b/test/integration/test/taskResumability.test.ts deleted file mode 100644 index f7b4174d18..0000000000 --- a/test/integration/test/taskResumability.test.ts +++ /dev/null @@ -1,300 +0,0 @@ -import { randomUUID } from 'node:crypto'; -import type { Server } from 'node:http'; -import { createServer } from 'node:http'; - -import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; -import { NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/node'; -import type { EventStore, JSONRPCMessage } from '@modelcontextprotocol/server'; -import { McpServer } from '@modelcontextprotocol/server'; -import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; -import * as z from 'zod/v4'; - -/** - * Simple in-memory EventStore for testing resumability. - */ -class InMemoryEventStore implements EventStore { - private events = new Map(); - - async storeEvent(streamId: string, message: JSONRPCMessage): Promise { - const eventId = `${streamId}_${Date.now()}_${Math.random().toString(36).slice(2, 10)}`; - this.events.set(eventId, { streamId, message }); - return eventId; - } - - async replayEventsAfter( - lastEventId: string, - { send }: { send: (eventId: string, message: JSONRPCMessage) => Promise } - ): Promise { - if (!lastEventId || !this.events.has(lastEventId)) return ''; - const streamId = lastEventId.split('_')[0] ?? ''; - if (!streamId) return ''; - - let found = false; - const sorted = [...this.events.entries()].toSorted((a, b) => a[0].localeCompare(b[0])); - for (const [eventId, { streamId: sid, message }] of sorted) { - if (sid !== streamId) continue; - if (eventId === lastEventId) { - found = true; - continue; - } - if (found) await send(eventId, message); - } - return streamId; - } -} - -describe('Zod v4', () => { - describe('Transport resumability', () => { - let server: Server; - let mcpServer: McpServer; - let serverTransport: NodeStreamableHTTPServerTransport; - let baseUrl: URL; - let eventStore: InMemoryEventStore; - - beforeEach(async () => { - // Create event store for resumability - eventStore = new InMemoryEventStore(); - - // Create a simple MCP server - mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); - - // Add a simple notification tool that completes quickly - mcpServer.registerTool( - 'send-notification', - { - description: 'Sends a single notification', - inputSchema: z.object({ - message: z.string().describe('Message to send').default('Test notification') - }) - }, - async ({ message }, ctx) => { - // Send notification immediately - await ctx.mcpReq.notify({ - method: 'notifications/message', - params: { - level: 'info', - data: message - } - }); - - return { - content: [{ type: 'text', text: 'Notification sent' }] - }; - } - ); - - // Add a long-running tool that sends multiple notifications - mcpServer.registerTool( - 'run-notifications', - { - description: 'Sends multiple notifications over time', - inputSchema: z.object({ - count: z.number().describe('Number of notifications to send').default(10), - interval: z.number().describe('Interval between notifications in ms').default(50) - }) - }, - async ({ count, interval }, ctx) => { - // Send notifications at specified intervals - for (let i = 0; i < count; i++) { - await ctx.mcpReq.notify({ - method: 'notifications/message', - params: { - level: 'info', - data: `Notification ${i + 1} of ${count}` - } - }); - - // Wait for the specified interval before sending next notification - if (i < count - 1) { - await new Promise(resolve => setTimeout(resolve, interval)); - } - } - - return { - content: [{ type: 'text', text: `Sent ${count} notifications` }] - }; - } - ); - - // Create a transport with the event store - serverTransport = new NodeStreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - eventStore - }); - - // Connect the transport to the MCP server - await mcpServer.connect(serverTransport); - - // Create and start an HTTP server - server = createServer(async (req, res) => { - await serverTransport.handleRequest(req, res); - }); - - // Start the server on a random port - baseUrl = await listenOnRandomPort(server); - }); - - afterEach(async () => { - // Clean up resources - await mcpServer.close().catch(() => {}); - await serverTransport.close().catch(() => {}); - server.close(); - }); - - it('should store session ID when client connects', async () => { - // Create and connect a client - const client = new Client({ - name: 'test-client', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - await client.connect(transport); - - // Verify session ID was generated - expect(transport.sessionId).toBeDefined(); - - // Clean up - await transport.close(); - }); - - it('should have session ID functionality', async () => { - // The ability to store a session ID when connecting - const client = new Client({ - name: 'test-client-reconnection', - version: '1.0.0' - }); - - const transport = new StreamableHTTPClientTransport(baseUrl); - - // Make sure the client can connect and get a session ID - await client.connect(transport); - expect(transport.sessionId).toBeDefined(); - - // Clean up - await transport.close(); - }); - - // This test demonstrates the capability to resume long-running tools - // across client disconnection/reconnection - it('should resume long-running notifications with lastEventId', async () => { - // Create unique client ID for this test - const clientTitle = 'test-client-long-running'; - const notifications = []; - let lastEventId: string | undefined; - - // Create first client - const client1 = new Client({ - title: clientTitle, - name: 'test-client', - version: '1.0.0' - }); - - // Set up notification handler for first client - client1.setNotificationHandler('notifications/message', notification => { - if (notification.method === 'notifications/message') { - notifications.push(notification.params); - } - }); - - // Connect first client - const transport1 = new StreamableHTTPClientTransport(baseUrl); - await client1.connect(transport1); - const sessionId = transport1.sessionId; - expect(sessionId).toBeDefined(); - - // Start a long-running notification stream with tracking of lastEventId - const onLastEventIdUpdate = vi.fn((eventId: string) => { - lastEventId = eventId; - }); - expect(lastEventId).toBeUndefined(); - // Start the notification tool with event tracking using request - const toolPromise = client1.request( - { - method: 'tools/call', - params: { - name: 'run-notifications', - arguments: { - count: 3, - interval: 10 - } - } - }, - { - resumptionToken: lastEventId, - onresumptiontoken: onLastEventIdUpdate - } - ); - - // Fix for node 18 test failures, allow some time for notifications to arrive - const maxWaitTime = 2000; // 2 seconds max wait - const pollInterval = 10; // Check every 10ms - const startTime = Date.now(); - while (notifications.length === 0 && Date.now() - startTime < maxWaitTime) { - // Wait for some notifications to arrive (not all) - shorter wait time - await new Promise(resolve => setTimeout(resolve, pollInterval)); - } - - // Verify we received some notifications and lastEventId was updated - expect(notifications.length).toBeGreaterThan(0); - expect(notifications.length).toBeLessThan(4); - expect(onLastEventIdUpdate).toHaveBeenCalled(); - expect(lastEventId).toBeDefined(); - - // Disconnect first client without waiting for completion - // When we close the connection, it will cause a ConnectionClosed error for - // any in-progress requests, which is expected behavior - await transport1.close(); - // Save the promise so we can catch it after closing - const catchPromise = toolPromise.catch(error => { - // This error is expected - the connection was intentionally closed - if (error?.code !== -32_000) { - // ConnectionClosed error code - console.error('Unexpected error type during transport close:', error); - } - }); - - // Add a short delay to ensure clean disconnect before reconnecting - await new Promise(resolve => setTimeout(resolve, 10)); - - // Wait for the rejection to be handled - await catchPromise; - - // Create second client with same client ID - const client2 = new Client({ - title: clientTitle, - name: 'test-client', - version: '1.0.0' - }); - - // Track replayed notifications separately - const replayedNotifications: unknown[] = []; - client2.setNotificationHandler('notifications/message', notification => { - if (notification.method === 'notifications/message') { - replayedNotifications.push(notification.params); - } - }); - - // Connect second client with same session ID - const transport2 = new StreamableHTTPClientTransport(baseUrl, { - sessionId - }); - await client2.connect(transport2); - - // Resume GET SSE stream with Last-Event-ID to replay missed events - // Per spec, resumption uses GET with Last-Event-ID header - await transport2.resumeStream(lastEventId!, { onresumptiontoken: onLastEventIdUpdate }); - - // Wait for replayed events to arrive via SSE - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify the test infrastructure worked - we received notifications in first session - // and captured the lastEventId for potential replay - expect(notifications.length).toBeGreaterThan(0); - expect(lastEventId).toBeDefined(); - - // Clean up - await transport2.close(); - }); - }); -});