diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md new file mode 100644 index 000000000..f5c064e7f --- /dev/null +++ b/.changeset/token-provider-composable-auth.md @@ -0,0 +1,16 @@ +--- +'@modelcontextprotocol/client': minor +--- + +Add `AuthProvider` for composable bearer-token auth; transports adapt `OAuthClientProvider` automatically + +- New `AuthProvider` interface: `{ token(): Promise; onUnauthorized?(ctx): Promise }`. Transports call `token()` before every request and `onUnauthorized()` on 401 (then retry once). +- Transport `authProvider` option now accepts `AuthProvider | OAuthClientProvider`. OAuth providers are adapted internally via `adaptOAuthProvider()` — no changes needed to existing `OAuthClientProvider` implementations. +- For simple bearer tokens (API keys, gateway-managed tokens, service accounts): `{ authProvider: { token: async () => myKey } }` — one-line object literal, no class. +- New `adaptOAuthProvider(provider)` export for explicit adaptation. +- New `handleOAuthUnauthorized(provider, ctx)` helper — the standard OAuth `onUnauthorized` behavior. +- New `isOAuthClientProvider()` type guard. +- New `UnauthorizedContext` type. +- Exported previously-internal auth helpers for building custom flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. + +Transports are simplified internally — ~50 lines of inline OAuth orchestration (auth() calls, WWW-Authenticate parsing, circuit-breaker state) moved into the adapter's `onUnauthorized()` implementation. `OAuthClientProvider` itself is unchanged. diff --git a/docs/client.md b/docs/client.md index 782ab885b..b5086f531 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ A client connects to a server, discovers what it offers — tools, resources, pr The examples below use these imports. Adjust based on which features and transport you need: ```ts source="../examples/client/src/clientGuide.examples.ts#imports" -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -113,7 +113,19 @@ console.log(systemPrompt); ## Authentication -MCP servers can require OAuth 2.0 authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an `authProvider` to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport} to enable this — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an {@linkcode @modelcontextprotocol/client!client/auth.AuthProvider | AuthProvider} to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. The transport calls `token()` before every request and `onUnauthorized()` (if provided) on 401, then retries once. + +### Bearer tokens + +For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials — implement only `token()`. With no `onUnauthorized()`, a 401 throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} immediately: + +```ts source="../examples/client/src/clientGuide.examples.ts#auth_tokenProvider" +const authProvider: AuthProvider = { token: async () => getStoredToken() }; + +const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); +``` + +See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. ### Client credentials diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 38a8c372b..e5de064da 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -47,15 +47,15 @@ Replace all `@modelcontextprotocol/sdk/...` imports using this table. ### Server imports -| v1 import path | v2 package | -| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | +| v1 import path | v2 package | +| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | | `@modelcontextprotocol/sdk/server/streamableHttp.js` | `@modelcontextprotocol/node` (class renamed to `NodeStreamableHTTPServerTransport`) OR `@modelcontextprotocol/server` (web-standard `WebStandardStreamableHTTPServerTransport` for Cloudflare Workers, Deno, etc.) | -| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | -| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | -| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | +| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | +| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | +| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | ### Types / shared imports @@ -107,19 +107,19 @@ Two error classes now exist: - **`ProtocolError`** (renamed from `McpError`): Protocol errors that cross the wire as JSON-RPC responses - **`SdkError`** (new): Local SDK errors that never cross the wire -| Error scenario | v1 type | v2 type | -| -------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | -| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | -| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | -| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | -| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | -| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | -| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | -| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | -| 401 after auth flow | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | -| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | -| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | -| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | +| Error scenario | v1 type | v2 type | +| --------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | +| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | +| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | +| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | +| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | +| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | +| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | +| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | +| 401 after re-auth (circuit break) | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | +| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | +| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | +| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | New `SdkErrorCode` enum values: @@ -203,7 +203,8 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) { ... } ``` -**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). +**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all +Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). ## 6. McpServer API Changes @@ -283,8 +284,8 @@ Note: the third argument (`metadata`) is required — pass `{}` if no metadata. |----------------|-----------------| | `{ name: z.string() }` | `z.object({ name: z.string() })` | | `{ count: z.number().optional() }` | `z.object({ count: z.number().optional() })` | -| `{}` (empty) | `z.object({})` | -| `undefined` (no schema) | `undefined` or omit the field | +| `{}` (empty) | `z.object({})` | +| `undefined` (no schema) | `undefined` or omit the field | ### Removed core exports @@ -379,31 +380,31 @@ Request/notification params remain fully typed. Remove unused schema imports aft `RequestHandlerExtra` → structured context types with nested groups. Rename `extra` → `ctx` in all handler callbacks. -| v1 | v2 | -|----|-----| -| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | -| `extra` (param name) | `ctx` | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| v1 | v2 | +| -------------------------------- | -------------------------------------------------------------------------- | +| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | +| `extra` (param name) | `ctx` | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | `ServerContext` convenience methods (new in v2, no v1 equivalent): -| Method | Description | Replaces | -|--------|-------------|----------| -| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | -| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | -| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | +| Method | Description | Replaces | +| ---------------------------------------------- | ------------------------------------------------------ | ---------------------------------------------------- | +| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | +| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | +| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | ## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` @@ -422,14 +423,14 @@ const elicit = await ctx.mcpReq.send({ method: 'elicitation/create', params: { . const tool = await client.callTool({ name: 'my-tool', arguments: {} }); ``` -| v1 call | v2 call | -|---------|---------| -| `client.request(req, ResultSchema)` | `client.request(req)` | -| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | -| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | -| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | -| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | -| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | +| v1 call | v2 call | +| ------------------------------------------------------------ | ---------------------------------- | +| `client.request(req, ResultSchema)` | `client.request(req)` | +| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | +| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | +| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | +| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | +| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc., when they were only used in `request()`/`send()`/`callTool()` calls. @@ -440,6 +441,7 @@ Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResu ## 13. Runtime-Specific JSON Schema Validators (Enhancement) The SDK now auto-selects the appropriate JSON Schema validator based on runtime: + - Node.js → `AjvJsonSchemaValidator` (no change from v1) - Cloudflare Workers (workerd) → `CfWorkerJsonSchemaValidator` (previously required manual config) @@ -447,9 +449,12 @@ The SDK now auto-selects the appropriate JSON Schema validator based on runtime: ```typescript // v1 (Cloudflare Workers): Required explicit validator -new McpServer({ name: 'server', version: '1.0.0' }, { - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() -}); +new McpServer( + { name: 'server', version: '1.0.0' }, + { + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() + } +); // v2 (Cloudflare Workers): Auto-selected, explicit config optional new McpServer({ name: 'server', version: '1.0.0' }, {}); diff --git a/docs/migration.md b/docs/migration.md index d2a9db947..64540bc92 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -264,6 +264,7 @@ server.registerTool('ping', { ``` This applies to: + - `inputSchema` in `registerTool()` - `outputSchema` in `registerTool()` - `argsSchema` in `registerPrompt()` @@ -360,25 +361,21 @@ Common method string replacements: ### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter -The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas like `CallToolResultSchema` or `ElicitResultSchema` when making requests. +The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas +like `CallToolResultSchema` or `ElicitResultSchema` when making requests. **`client.request()` — Before (v1):** ```typescript import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, - CallToolResultSchema -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, CallToolResultSchema); ``` **After (v2):** ```typescript -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } } -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }); ``` **`ctx.mcpReq.send()` — Before (v1):** @@ -411,10 +408,7 @@ server.setRequestHandler('tools/call', async (request, ctx) => { ```typescript import { CompatibilityCallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.callTool( - { name: 'my-tool', arguments: {} }, - CompatibilityCallToolResultSchema -); +const result = await client.callTool({ name: 'my-tool', arguments: {} }, CompatibilityCallToolResultSchema); ``` **After (v2):** @@ -473,32 +467,32 @@ import { JSONRPCErrorResponse, ResourceTemplateReference, isJSONRPCErrorResponse The `RequestHandlerExtra` type has been replaced with a structured context type hierarchy using nested groups: -| v1 | v2 | -|----|-----| +| v1 | v2 | +| ---------------------------------------- | ---------------------------------------------------------------------- | | `RequestHandlerExtra` (flat, all fields) | `ServerContext` (server handlers) or `ClientContext` (client handlers) | -| `extra` parameter name | `ctx` parameter name | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| `extra` parameter name | `ctx` parameter name | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | **Before (v1):** ```typescript server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - const headers = extra.requestInfo?.headers; - const taskStore = extra.taskStore; - await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = extra.requestInfo?.headers; + const taskStore = extra.taskStore; + await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -506,10 +500,10 @@ server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - const headers = ctx.http?.req?.headers; - const taskStore = ctx.task?.store; - await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = ctx.http?.req?.headers; + const taskStore = ctx.task?.store; + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -525,22 +519,22 @@ Context fields are organized into 4 groups: ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - // Send a log message (respects client's log level filter) - await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); - - // Request client to sample an LLM - const samplingResult = await ctx.mcpReq.requestSampling({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100, - }); - - // Elicit user input via a form - const elicitResult = await ctx.mcpReq.elicitInput({ - message: 'Please provide details', - requestedSchema: { type: 'object', properties: { name: { type: 'string' } } }, - }); - - return { content: [{ type: 'text', text: 'done' }] }; + // Send a log message (respects client's log level filter) + await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); + + // Request client to sample an LLM + const samplingResult = await ctx.mcpReq.requestSampling({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); + + // Elicit user input via a form + const elicitResult = await ctx.mcpReq.elicitInput({ + message: 'Please provide details', + requestedSchema: { type: 'object', properties: { name: { type: 'string' } } } + }); + + return { content: [{ type: 'text', text: 'done' }] }; }); ``` @@ -602,21 +596,21 @@ try { The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: -| Code | Description | -| ------------------------------------------------- | ------------------------------------------ | -| `SdkErrorCode.NotConnected` | Transport is not connected | -| `SdkErrorCode.AlreadyConnected` | Transport is already connected | -| `SdkErrorCode.NotInitialized` | Protocol is not initialized | -| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | -| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | -| `SdkErrorCode.ConnectionClosed` | Connection was closed | -| `SdkErrorCode.SendFailed` | Failed to send message | -| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | -| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after successful auth | -| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | -| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | -| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | -| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | +| Code | Description | +| ------------------------------------------------- | ------------------------------------------- | +| `SdkErrorCode.NotConnected` | Transport is not connected | +| `SdkErrorCode.AlreadyConnected` | Transport is already connected | +| `SdkErrorCode.NotInitialized` | Protocol is not initialized | +| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | +| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | +| `SdkErrorCode.ConnectionClosed` | Connection was closed | +| `SdkErrorCode.SendFailed` | Failed to send message | +| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | +| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after re-authentication | +| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | +| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | +| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | +| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | #### `StreamableHTTPError` removed @@ -647,7 +641,7 @@ try { if (error instanceof SdkError) { switch (error.code) { case SdkErrorCode.ClientHttpAuthentication: - console.log('Auth failed after completing auth flow'); + console.log('Auth failed — server rejected token after re-auth'); break; case SdkErrorCode.ClientHttpForbidden: console.log('Forbidden after upscoping attempt'); @@ -667,7 +661,8 @@ try { #### Why this change? -Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was semantically inconsistent. +Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was +semantically inconsistent. The new design: @@ -764,11 +759,11 @@ This means Cloudflare Workers users no longer need to explicitly pass the valida import { McpServer, CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { - capabilities: { tools: {} }, - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 - } + { name: 'my-server', version: '1.0.0' }, + { + capabilities: { tools: {} }, + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 + } ); ``` @@ -778,9 +773,9 @@ const server = new McpServer( import { McpServer } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { capabilities: { tools: {} } } - // Validator auto-selected based on runtime + { name: 'my-server', version: '1.0.0' }, + { capabilities: { tools: {} } } + // Validator auto-selected based on runtime ); ``` diff --git a/examples/client/src/clientGuide.examples.ts b/examples/client/src/clientGuide.examples.ts index 389059024..f07d272db 100644 --- a/examples/client/src/clientGuide.examples.ts +++ b/examples/client/src/clientGuide.examples.ts @@ -8,7 +8,7 @@ */ //#region imports -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -107,6 +107,16 @@ async function serverInstructions_basic(client: Client) { // Authentication // --------------------------------------------------------------------------- +/** Example: Minimal AuthProvider for bearer auth with externally-managed tokens. */ +async function auth_tokenProvider(getStoredToken: () => Promise) { + //#region auth_tokenProvider + const authProvider: AuthProvider = { token: async () => getStoredToken() }; + + const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); + //#endregion auth_tokenProvider + return transport; +} + /** Example: Client credentials auth for service-to-service communication. */ async function auth_clientCredentials() { //#region auth_clientCredentials @@ -540,6 +550,7 @@ void connect_stdio; void connect_sseFallback; void disconnect_streamableHttp; void serverInstructions_basic; +void auth_tokenProvider; void auth_clientCredentials; void auth_privateKeyJwt; void auth_crossAppAccess; diff --git a/examples/client/src/dualModeAuth.ts b/examples/client/src/dualModeAuth.ts new file mode 100644 index 000000000..4dd1eaded --- /dev/null +++ b/examples/client/src/dualModeAuth.ts @@ -0,0 +1,114 @@ +#!/usr/bin/env node + +/** + * Two auth patterns through the same `authProvider` option. + * + * The transport accepts either a minimal `AuthProvider` (just `token()` + + * optional `onUnauthorized()`) or a full `OAuthClientProvider`, adapting + * the latter automatically. This means your connect/call code is identical + * regardless of which pattern fits your deployment. + * + * HOST-MANAGED — token lives in an enclosing app + * The app fetches and stores tokens; the MCP client just reads them. + * On 401, there is nothing to refresh — signal the UI and throw so the + * user can re-authenticate through the host's flow. + * + * USER-CONFIGURED — OAuth credentials supplied directly + * Pass a built-in or custom OAuthClientProvider. The transport handles + * the full OAuth flow: token refresh on 401, or redirect for interactive + * authorization. + */ + +import type { AuthProvider } from '@modelcontextprotocol/client'; +import { Client, ClientCredentialsProvider, StreamableHTTPClientTransport, UnauthorizedError } from '@modelcontextprotocol/client'; + +// --- Stubs for host-app integration points --------------------------------- + +/** Whatever the host app uses to store session state (e.g., cookies, keychain, in-memory). */ +interface HostSessionStore { + getMcpToken(): string | undefined; +} + +/** Whatever the host app uses to surface UI prompts. */ +interface HostUi { + showReauthPrompt(message: string): void; +} + +// --- MODE A: Host-managed auth --------------------------------------------- + +function createHostManagedTransport(serverUrl: URL, session: HostSessionStore, ui: HostUi): StreamableHTTPClientTransport { + const authProvider: AuthProvider = { + // Called before every request — just read whatever the host has. + token: async () => session.getMcpToken(), + + // Called on 401 — don't refresh (the host owns the token), signal the UI and bail. + // The transport will retry once after this returns, so we throw to stop it: + // the user needs to act before a retry makes sense. + onUnauthorized: async () => { + ui.showReauthPrompt('MCP connection lost — click to reconnect'); + throw new UnauthorizedError('Host token rejected — user action required'); + } + }; + + return new StreamableHTTPClientTransport(serverUrl, { authProvider }); +} + +// --- MODE B: User-configured OAuth ----------------------------------------- + +function createUserConfiguredTransport(serverUrl: URL, clientId: string, clientSecret: string): StreamableHTTPClientTransport { + // Built-in OAuth provider — the transport adapts it to AuthProvider internally. + // On 401, adaptOAuthProvider synthesizes onUnauthorized → handleOAuthUnauthorized, + // which runs token refresh (or redirect for interactive flows). + const authProvider = new ClientCredentialsProvider({ clientId, clientSecret }); + + return new StreamableHTTPClientTransport(serverUrl, { authProvider }); +} + +// --- Same caller code for both modes --------------------------------------- + +async function connectAndList(transport: StreamableHTTPClientTransport): Promise { + const client = new Client({ name: 'dual-mode-example', version: '1.0.0' }, { capabilities: {} }); + await client.connect(transport); + + const tools = await client.listTools(); + console.log('Tools:', tools.tools.map(t => t.name).join(', ') || '(none)'); + + await transport.close(); +} + +// --- Driver ---------------------------------------------------------------- + +async function main() { + const serverUrl = new URL(process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'); + const mode = process.argv[2] || 'host'; + + let transport: StreamableHTTPClientTransport; + + if (mode === 'host') { + // Simulate a host app with a session-stored token and a UI hook. + const session: HostSessionStore = { getMcpToken: () => process.env.MCP_TOKEN }; + const ui: HostUi = { showReauthPrompt: msg => console.error(`[UI] ${msg}`) }; + transport = createHostManagedTransport(serverUrl, session, ui); + } else if (mode === 'oauth') { + const clientId = process.env.OAUTH_CLIENT_ID; + const clientSecret = process.env.OAUTH_CLIENT_SECRET; + if (!clientId || !clientSecret) { + console.error('OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET required for oauth mode'); + process.exit(1); + } + transport = createUserConfiguredTransport(serverUrl, clientId, clientSecret); + } else { + console.error(`Unknown mode: ${mode}. Use 'host' or 'oauth'.`); + process.exit(1); + } + + // Same connect/list code regardless of mode — the transport abstracts the difference. + await connectAndList(transport); +} + +try { + await main(); +} catch (error) { + console.error('Error:', error); + process.exitCode = 1; +} diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts new file mode 100644 index 000000000..ce68fde5a --- /dev/null +++ b/examples/client/src/simpleTokenProvider.ts @@ -0,0 +1,55 @@ +#!/usr/bin/env node + +/** + * Example demonstrating the minimal AuthProvider for bearer token authentication. + * + * AuthProvider is the base interface for all client auth. For simple cases where + * tokens are managed externally — pre-configured API tokens, gateway/proxy patterns, + * or tokens obtained through a separate auth flow — implement only `token()`. + * + * For OAuth flows (client_credentials, private_key_jwt, etc.), use the built-in + * providers which implement both `token()` and `onUnauthorized()`. + * + * Environment variables: + * MCP_SERVER_URL - Server URL (default: http://localhost:3000/mcp) + * MCP_TOKEN - Bearer token to use for authentication (required) + */ + +import type { AuthProvider } from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; + +const DEFAULT_SERVER_URL = process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'; + +async function main() { + const token = process.env.MCP_TOKEN; + if (!token) { + console.error('MCP_TOKEN environment variable is required'); + process.exit(1); + } + + // AuthProvider with just token() — the simplest possible auth. + // token() is called before every request, so it can handle refresh internally. + // With no onUnauthorized(), a 401 throws UnauthorizedError immediately. + const authProvider: AuthProvider = { + token: async () => token + }; + + const client = new Client({ name: 'auth-provider-example', version: '1.0.0' }, { capabilities: {} }); + + const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { authProvider }); + + await client.connect(transport); + console.log('Connected successfully.'); + + const tools = await client.listTools(); + console.log('Available tools:', tools.tools.map(t => t.name).join(', ') || '(none)'); + + await transport.close(); +} + +try { + await main(); +} catch (error) { + console.error('Error running client:', error); + process.exitCode = 1; +} diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index bedfd8743..1a021be18 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -35,12 +35,110 @@ export type AddClientAuthentication = ( metadata?: AuthorizationServerMetadata ) => void | Promise; +/** + * Context passed to {@linkcode AuthProvider.onUnauthorized} when the server + * responds with 401. Provides everything needed to refresh credentials. + */ +export interface UnauthorizedContext { + /** The 401 response — inspect `WWW-Authenticate` for resource metadata, scope, etc. */ + response: Response; + /** The MCP server URL, for passing to {@linkcode auth} or discovery helpers. */ + serverUrl: URL; + /** Fetch function configured with the transport's `requestInit`, for making auth requests. */ + fetchFn: FetchLike; +} + +/** + * Minimal interface for authenticating MCP client transports with bearer tokens. + * + * Transports call {@linkcode AuthProvider.token | token()} before every request + * to obtain the current token, and {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * (if provided) when the server responds with 401, giving the provider a chance + * to refresh credentials before the transport retries once. + * + * For simple cases (API keys, gateway-managed tokens), implement only `token()`: + * ```typescript + * const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; + * ``` + * + * For OAuth flows, pass an {@linkcode OAuthClientProvider} directly — transports + * accept either shape and adapt OAuth providers automatically via {@linkcode adaptOAuthProvider}. + */ +export interface AuthProvider { + /** + * Returns the current bearer token, or `undefined` if no token is available. + * Called before every request. + */ + token(): Promise; + + /** + * Called when the server responds with 401. If provided, the transport will + * await this, then retry the request once. If the retry also gets 401, or if + * this method is not provided, the transport throws {@linkcode UnauthorizedError}. + * + * Implementations should refresh tokens, re-authenticate, etc. — whatever is + * needed so the next `token()` call returns a valid token. + */ + onUnauthorized?(ctx: UnauthorizedContext): Promise; +} + +/** + * Type guard distinguishing `OAuthClientProvider` from a minimal `AuthProvider`. + * Transports use this at construction time to classify the `authProvider` option. + * + * Checks for `tokens()` + `clientInformation()` — two required `OAuthClientProvider` + * methods that a minimal `AuthProvider` `{ token: ... }` would never have. + */ +export function isOAuthClientProvider(provider: AuthProvider | OAuthClientProvider | undefined): provider is OAuthClientProvider { + if (provider == null) return false; + const p = provider as OAuthClientProvider; + return typeof p.tokens === 'function' && typeof p.clientInformation === 'function'; +} + +/** + * Standard `onUnauthorized` behavior for OAuth providers: extracts + * `WWW-Authenticate` parameters from the 401 response and runs {@linkcode auth}. + * Used by {@linkcode adaptOAuthProvider} to bridge `OAuthClientProvider` to `AuthProvider`. + */ +export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx: UnauthorizedContext): Promise { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(ctx.response); + const result = await auth(provider, { + serverUrl: ctx.serverUrl, + resourceMetadataUrl, + scope, + fetchFn: ctx.fetchFn + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } +} + +/** + * Adapts an `OAuthClientProvider` to the minimal `AuthProvider` interface that + * transports consume. Called once at transport construction — the transport stores + * the adapted provider for `_commonHeaders()` and 401 handling, while keeping the + * original `OAuthClientProvider` for OAuth-specific paths (`finishAuth()`, 403 upscoping). + */ +export function adaptOAuthProvider(provider: OAuthClientProvider): AuthProvider { + return { + token: async () => { + const tokens = await provider.tokens(); + return tokens?.access_token; + }, + onUnauthorized: async ctx => handleOAuthUnauthorized(provider, ctx) + }; +} + /** * Implements an end-to-end OAuth client to be used with one MCP server. * * This client relies upon a concept of an authorized "session," the exact * meaning of which is application-defined. Tokens, authorization codes, and * code verifiers should not cross different sessions. + * + * Transports accept `OAuthClientProvider` directly via the `authProvider` option — + * they adapt it to {@linkcode AuthProvider} internally via {@linkcode adaptOAuthProvider}. + * No changes are needed to existing implementations. */ export interface OAuthClientProvider { /** @@ -382,7 +480,7 @@ export function applyClientAuthentication( /** * Applies HTTP Basic authentication (RFC 6749 Section 2.3.1) */ -function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { +export function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { if (!clientSecret) { throw new Error('client_secret_basic authentication requires a client_secret'); } @@ -394,7 +492,7 @@ function applyBasicAuth(clientId: string, clientSecret: string | undefined, head /** * Applies POST body authentication (RFC 6749 Section 2.3.1) */ -function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { +export function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { params.set('client_id', clientId); if (clientSecret) { params.set('client_secret', clientSecret); @@ -404,7 +502,7 @@ function applyPostAuth(clientId: string, clientSecret: string | undefined, param /** * Applies public client authentication (RFC 6749 Section 2.1) */ -function applyPublicAuth(clientId: string, params: URLSearchParams): void { +export function applyPublicAuth(clientId: string, params: URLSearchParams): void { params.set('client_id', clientId); } @@ -1304,7 +1402,7 @@ export function prepareAuthorizationCodeRequest( * Internal helper to execute a token request with the given parameters. * Used by {@linkcode exchangeAuthorization}, {@linkcode refreshAuthorization}, and {@linkcode fetchToken}. */ -async function executeTokenRequest( +export async function executeTokenRequest( authorizationServerUrl: string | URL, { metadata, diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 133aa0004..f441e9cdb 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -3,8 +3,8 @@ import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, SdkError, import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { constructor( @@ -23,18 +23,19 @@ export type SSEClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the SSE connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode SSEClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode SSEClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation. + * Interactive flows: after {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode SSEClientTransport.finishAuth | finishAuth} with the authorization code before reconnecting. */ - authProvider?: OAuthClientProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes the initial SSE request to the server (the request that begins the stream). @@ -71,7 +72,8 @@ export class SSEClientTransport implements Transport { private _scope?: string; private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; + private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -86,43 +88,23 @@ export class SSEClientTransport implements Transport { this._scope = undefined; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuth(); - } + private _last401Response?: Response; private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._protocolVersion) { headers['mcp-protocol-version'] = this._protocolVersion; @@ -149,10 +131,13 @@ export class SSEClientTransport implements Transport { headers }); - if (response.status === 401 && response.headers.has('www-authenticate')) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + this._last401Response = response; + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } } return response; @@ -162,7 +147,24 @@ export class SSEClientTransport implements Transport { this._eventSource.onerror = event => { if (event.code === 401 && this._authProvider) { - this._authThenStart().then(resolve, reject); + if (this._authProvider.onUnauthorized && this._last401Response) { + const response = this._last401Response; + this._last401Response = undefined; + this._eventSource?.close(); + this._authProvider.onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }).then( + // onUnauthorized succeeded → retry fresh. Its onerror handles its own onerror?.() + reject. + () => this._startOrAuth().then(resolve, reject), + // onUnauthorized failed → not yet reported. + error => { + this.onerror?.(error); + reject(error); + } + ); + return; + } + const error = new UnauthorizedError(); + reject(error); + this.onerror?.(error); return; } @@ -221,11 +223,11 @@ export class SSEClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!this._oauthProvider) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, @@ -244,6 +246,10 @@ export class SSEClientTransport implements Transport { } async send(message: JSONRPCMessage): Promise { + return this._send(message, false); + } + + private async _send(message: JSONRPCMessage, isAuthRetry: boolean): Promise { if (!this._endpoint) { throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); } @@ -261,27 +267,33 @@ export class SSEClientTransport implements Transport { const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { - const text = await response.text?.().catch(() => null); - if (response.status === 401 && this._authProvider) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; - - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; } - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + if (this._authProvider.onUnauthorized && !isAuthRetry) { + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this._send(message, true); + } + await response.text?.().catch(() => {}); + if (isAuthRetry) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } + throw new UnauthorizedError(); } + const text = await response.text?.().catch(() => null); throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index a39a6c1d7..3d45b60e9 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -13,8 +13,8 @@ import { } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -85,18 +85,21 @@ export type StreamableHTTPClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode StreamableHTTPClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode StreamableHTTPClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation + * directly — the transport adapts it to `AuthProvider` internally. Interactive flows: after + * {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization code before + * reconnecting. */ - authProvider?: OAuthClientProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes HTTP requests to the server. @@ -138,13 +141,13 @@ export class StreamableHTTPClientTransport implements Transport { private _resourceMetadataUrl?: URL; private _scope?: string; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; + private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; - private _hasCompletedAuthFlow = false; // Circuit breaker: detect auth success followed by immediate 401 private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; @@ -158,7 +161,12 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = undefined; this._scope = undefined; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; @@ -166,38 +174,11 @@ export class StreamableHTTPClientTransport implements Transport { this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuthSse({ resumptionToken: undefined }); - } - private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._sessionId) { @@ -215,7 +196,7 @@ export class StreamableHTTPClientTransport implements Transport { }); } - private async _startOrAuthSse(options: StartSSEOptions): Promise { + private async _startOrAuthSse(options: StartSSEOptions, isAuthRetry = false): Promise { const { resumptionToken } = options; try { @@ -237,13 +218,34 @@ export class StreamableHTTPClientTransport implements Transport { }); if (!response.ok) { - await response.text?.().catch(() => {}); - if (response.status === 401 && this._authProvider) { - // Need to authenticate - return await this._authThenStart(); + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } + + if (this._authProvider.onUnauthorized && !isAuthRetry) { + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this._startOrAuthSse(options, true); + } + await response.text?.().catch(() => {}); + if (isAuthRetry) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } + throw new UnauthorizedError(); } + await response.text?.().catch(() => {}); + // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { @@ -439,11 +441,11 @@ export class StreamableHTTPClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!this._oauthProvider) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, @@ -467,6 +469,14 @@ export class StreamableHTTPClientTransport implements Transport { async send( message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } + ): Promise { + return this._send(message, options, false); + } + + private async _send( + message: JSONRPCMessage | JSONRPCMessage[], + options: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } | undefined, + isAuthRetry: boolean ): Promise { try { const { resumptionToken, onresumptiontoken } = options || {}; @@ -500,38 +510,36 @@ export class StreamableHTTPClientTransport implements Transport { } if (!response.ok) { - const text = await response.text?.().catch(() => null); - if (response.status === 401 && this._authProvider) { - // Prevent infinite recursion when server returns 401 after successful auth - if (this._hasCompletedAuthFlow) { - throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after successful authentication', { - status: 401, - text - }); + // Store WWW-Authenticate params for interactive finishAuth() path + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; } - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; - - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); + if (this._authProvider.onUnauthorized && !isAuthRetry) { + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this._send(message, options, true); } - - // Mark that we completed auth flow - this._hasCompletedAuthFlow = true; - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + await response.text?.().catch(() => {}); + if (isAuthRetry) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } + throw new UnauthorizedError(); } - if (response.status === 403 && this._authProvider) { + const text = await response.text?.().catch(() => null); + + if (response.status === 403 && this._oauthProvider) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); if (error === 'insufficient_scope') { @@ -555,18 +563,18 @@ export class StreamableHTTPClientTransport implements Transport { // Mark that upscoping was tried. this._lastUpscopingHeader = wwwAuthHeader ?? undefined; - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, scope: this._scope, - fetchFn: this._fetch + fetchFn: this._fetchWithInit }); if (result !== 'AUTHORIZED') { throw new UnauthorizedError(); } - return this.send(message); + return this._send(message, options, isAuthRetry); } } @@ -576,8 +584,6 @@ export class StreamableHTTPClientTransport implements Transport { }); } - // Reset auth loop flag on successful response - this._hasCompletedAuthFlow = false; this._lastUpscopingHeader = undefined; // If the response is 202 Accepted, there's no body to process diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 0b0aff67b..b0b9588f0 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -3,11 +3,11 @@ import { createServer } from 'node:http'; import type { AddressInfo } from 'node:net'; import type { JSONRPCMessage, OAuthTokens } from '@modelcontextprotocol/core'; -import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; +import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; -import type { OAuthClientProvider } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; import { SSEClientTransport } from '../../src/client/sse.js'; @@ -1528,4 +1528,172 @@ describe('SSEClientTransport', () => { expect(globalFetchSpy).not.toHaveBeenCalled(); }); }); + + describe('minimal AuthProvider (non-OAuth)', () => { + let postResponses: number[]; + let postCount: number; + + async function setupServer(): Promise { + await resourceServer.close(); + + postCount = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.method === 'GET') { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}post\n\n`); + return; + } + + if (req.method === 'POST') { + const status = postResponses[postCount] ?? 200; + postCount++; + res.writeHead(status).end(); + return; + } + }); + + resourceBaseUrl = await listenOnRandomPort(resourceServer); + } + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: '1' }; + + it('throws UnauthorizedError on POST 401 when onUnauthorized is not provided', async () => { + postResponses = [401]; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + }); + + it('enforces circuit breaker on double-401: onUnauthorized called once, then throws SdkError', async () => { + postResponses = [401, 401]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(postCount).toBe(2); + }); + + it('resets retry guard when onUnauthorized throws, allowing retry on next send', async () => { + postResponses = [401, 401, 200]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + // First send: 401 → onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + + // Second send: flag should be reset, so 401 → onUnauthorized (succeeds) → retry → 200 + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(postCount).toBe(3); + }); + + it('throws when finishAuth is called with a non-OAuth AuthProvider', async () => { + postResponses = []; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + + it('SSE connect 401 retry does not poison future 401s — onUnauthorized called on each attempt', async () => { + // Regression: _startOrAuth(true) baked isAuthRetry=true into the retry EventSource's + // onerror closure, so a subsequent 401 (token expiry on reconnect) would throw + // instead of refreshing. Fix: retry always calls _startOrAuth() fresh. + await resourceServer.close(); + + let getAttempt = 0; + resourceServer = createServer((req, res) => { + if (req.method !== 'GET') { + res.writeHead(404).end(); + return; + } + getAttempt++; + if (getAttempt < 3) { + res.writeHead(401).end(); + return; + } + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}post\n\n`); + }); + resourceBaseUrl = await listenOnRandomPort(resourceServer); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + + await transport.start(); // should resolve on attempt 3 + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(getAttempt).toBe(3); + }); + + it('retry failure during SSE connect fires onerror exactly once', async () => { + // Regression: when the retry EventSource rejected, its onerror fired inside, then + // the outer .then() rejection handler fired onerror AGAIN for the same error. + // Fix: inner retry chains to .then(resolve, reject) — no outer onerror call. + // onUnauthorized's own failure is handled separately and fires onerror once. + await resourceServer.close(); + + resourceServer = createServer((req, res) => { + if (req.method === 'GET') { + res.writeHead(401).end(); // always 401 + } + }); + resourceBaseUrl = await listenOnRandomPort(resourceServer); + + const onUnauthorized: AuthProvider['onUnauthorized'] = vi + .fn() + .mockResolvedValueOnce(undefined) // first call succeeds → triggers retry + .mockRejectedValueOnce(new Error('refresh failed')); // second call (in retry) throws + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + const onerror = vi.fn(); + transport.onerror = onerror; + + await expect(transport.start()).rejects.toThrow('refresh failed'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(onerror).toHaveBeenCalledTimes(1); + expect(onerror.mock.calls[0]![0].message).toBe('refresh failed'); + }); + }); }); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 3830fcd1d..55bf79a50 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -1705,7 +1705,9 @@ describe('StreamableHTTPClientTransport', () => { // Retry the original request - still 401 (broken server) .mockResolvedValueOnce(unauthedResponse); - await expect(transport.send(message)).rejects.toThrow('Server returned 401 after successful authentication'); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ access_token: 'new-access-token', token_type: 'Bearer', diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts new file mode 100644 index 000000000..d6ef35bde --- /dev/null +++ b/packages/client/test/client/tokenProvider.test.ts @@ -0,0 +1,315 @@ +import type { IncomingMessage, Server } from 'node:http'; +import { createServer } from 'node:http'; + +import type { JSONRPCMessage, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; +import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; +import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; +import type { Mock } from 'vitest'; + +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; +import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; + +describe('StreamableHTTPClientTransport with AuthProvider', () => { + let transport: StreamableHTTPClientTransport; + + afterEach(async () => { + await transport?.close().catch(() => {}); + vi.clearAllMocks(); + }); + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: 'test-id' }; + + it('should set Authorization header from AuthProvider.token()', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'my-bearer-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + expect(authProvider.token).toHaveBeenCalled(); + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.get('Authorization')).toBe('Bearer my-bearer-token'); + }); + + it('should not set Authorization header when token() returns undefined', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => undefined) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); + + it('should throw UnauthorizedError on 401 when onUnauthorized is not provided', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'rejected-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(authProvider.token).toHaveBeenCalledTimes(1); + }); + + it('should call onUnauthorized and retry once on 401', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); + + it('should throw SdkError(ClientHttpAuthentication) if retry after onUnauthorized also gets 401', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }); + + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + }); + + it('should reset retry guard when onUnauthorized throws, allowing retry on next send', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + // First send: onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + + // Second send: flag should be reset, so onUnauthorized gets a second chance + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + }); + + it('should work with no authProvider at all', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); + + it('should throw when finishAuth is called with a non-OAuth AuthProvider', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + + it('should throw UnauthorizedError on GET-SSE 401 with no onUnauthorized (via resumeStream)', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.resumeStream('last-event-id')).rejects.toThrow(UnauthorizedError); + }); + + it('should call onUnauthorized and retry on GET-SSE 401 (via resumeStream)', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + // First GET: 401. Second GET (retry): 405 (server doesn't offer SSE — clean exit) + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 405, headers: new Headers(), text: async () => '' }); + + await transport.resumeStream('last-event-id'); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); +}); + +describe('AuthProvider integration — both modes against a real server', () => { + let server: Server; + let serverUrl: URL; + let capturedRequests: IncomingMessage[]; + let transport: StreamableHTTPClientTransport; + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'ping', params: {}, id: '1' }; + + beforeEach(async () => { + capturedRequests = []; + server = createServer((req, res) => { + capturedRequests.push(req); + if (req.method === 'POST') { + // Consume body then respond 202 Accepted + req.on('data', () => {}); + req.on('end', () => res.writeHead(202).end()); + } else { + // GET SSE — reject so the transport skips it + res.writeHead(405).end(); + } + }); + serverUrl = await listenOnRandomPort(server); + }); + + afterEach(async () => { + await transport?.close().catch(() => {}); + await new Promise(resolve => server.close(() => resolve())); + }); + + it('MODE A: minimal AuthProvider { token } sends Authorization header', async () => { + const authProvider: AuthProvider = { token: async () => 'mode-a-token' }; + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); + + await transport.send(message); + + expect(capturedRequests).toHaveLength(1); + expect(capturedRequests[0]!.headers.authorization).toBe('Bearer mode-a-token'); + }); + + it('MODE A: onUnauthorized signals and throws — caller sees the error', async () => { + const uiSignal = vi.fn(); + const authProvider: AuthProvider = { + token: async () => 'rejected-token', + onUnauthorized: async () => { + uiSignal('show-reauth-prompt'); + throw new UnauthorizedError('user action required'); + } + }; + + // Server that rejects with 401 + await new Promise(resolve => server.close(() => resolve())); + server = createServer((req, res) => { + capturedRequests.push(req); + req.on('data', () => {}); + req.on('end', () => res.writeHead(401).end()); + }); + serverUrl = await listenOnRandomPort(server); + + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); + + await expect(transport.send(message)).rejects.toThrow('user action required'); + expect(uiSignal).toHaveBeenCalledWith('show-reauth-prompt'); + }); + + it('MODE B: OAuthClientProvider is adapted — tokens() becomes token() on the wire', async () => { + // Minimal OAuthClientProvider — the transport should adapt it via adaptOAuthProvider + const oauthProvider: OAuthClientProvider = { + get redirectUrl() { + return undefined; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: [], grant_types: ['client_credentials'] }; + }, + clientInformation(): OAuthClientInformation { + return { client_id: 'test-client' }; + }, + tokens(): OAuthTokens { + return { access_token: 'mode-b-oauth-token', token_type: 'bearer' }; + }, + saveTokens() {}, + redirectToAuthorization() { + throw new Error('not used'); + }, + saveCodeVerifier() {}, + codeVerifier() { + throw new Error('not used'); + } + }; + + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider: oauthProvider }); + + await transport.send(message); + + expect(capturedRequests).toHaveLength(1); + expect(capturedRequests[0]!.headers.authorization).toBe('Bearer mode-b-oauth-token'); + }); + + it('both modes use the same option slot and same send() call', async () => { + // Mode A + const transportA = new StreamableHTTPClientTransport(serverUrl, { + authProvider: { token: async () => 'a-token' } + }); + await transportA.send(message); + await transportA.close(); + + // Mode B — same constructor, same option name, different shape + const transportB = new StreamableHTTPClientTransport(serverUrl, { + authProvider: { + get redirectUrl() { + return undefined; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: [] }; + }, + clientInformation: () => ({ client_id: 'x' }), + tokens: () => ({ access_token: 'b-token', token_type: 'bearer' }), + saveTokens() {}, + redirectToAuthorization() {}, + saveCodeVerifier() {}, + codeVerifier: () => '' + } satisfies OAuthClientProvider + }); + await transportB.send(message); + await transportB.close(); + + expect(capturedRequests.map(r => r.headers.authorization)).toEqual(['Bearer a-token', 'Bearer b-token']); + }); +});