From 4bb3370c38315533ef187e36bebf3c908a5e46d0 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 16:01:11 +0000 Subject: [PATCH 01/13] feat: add TokenProvider for composable bearer-token auth (non-breaking) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a minimal `() => Promise` function type as a lightweight alternative to OAuthClientProvider, for scenarios where bearer tokens are managed externally (gateway/proxy patterns, service accounts, API keys). - New TokenProvider type + withBearerAuth(getToken, fetchFn?) helper - New tokenProvider option on StreamableHTTPClientTransport and SSEClientTransport, used as fallback after authProvider in _commonHeaders(). authProvider takes precedence when both set. - On 401 with tokenProvider (no authProvider), transports throw UnauthorizedError — no retry, since tokenProvider() is already called before every request and would likely return the same rejected token. Callers catch UnauthorizedError, invalidate external cache, reconnect. - Exported previously-internal auth helpers for building custom flows: applyBasicAuth, applyPostAuth, applyPublicAuth, executeTokenRequest. - Tests, example, docs, changeset. Zero breakage. Bughunter fleet review: 28 findings submitted, 2 confirmed, both addressed. --- .changeset/token-provider-composable-auth.md | 10 + docs/client.md | 16 +- examples/client/src/clientGuide.examples.ts | 13 +- examples/client/src/simpleTokenProvider.ts | 69 ++++++ packages/client/src/client/auth.ts | 8 +- packages/client/src/client/sse.ts | 67 ++++-- packages/client/src/client/streamableHttp.ts | 85 ++++--- packages/client/src/client/tokenProvider.ts | 53 +++++ packages/client/src/index.ts | 1 + .../client/test/client/tokenProvider.test.ts | 208 ++++++++++++++++++ 10 files changed, 478 insertions(+), 52 deletions(-) create mode 100644 .changeset/token-provider-composable-auth.md create mode 100644 examples/client/src/simpleTokenProvider.ts create mode 100644 packages/client/src/client/tokenProvider.ts create mode 100644 packages/client/test/client/tokenProvider.test.ts diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md new file mode 100644 index 000000000..50b296298 --- /dev/null +++ b/.changeset/token-provider-composable-auth.md @@ -0,0 +1,10 @@ +--- +'@modelcontextprotocol/client': minor +--- + +Add `TokenProvider` for simple bearer-token authentication and export composable auth primitives + +- New `TokenProvider` type — a minimal `() => Promise` function interface for supplying bearer tokens. Use this instead of `OAuthClientProvider` when tokens are managed externally (gateway/proxy patterns, service accounts, upfront API tokens, or any scenario where the full OAuth redirect flow is not needed). +- New `tokenProvider` option on `StreamableHTTPClientTransport` and `SSEClientTransport`. Called before every request to obtain a fresh token. If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. +- New `withBearerAuth(getToken, fetchFn?)` helper that wraps a fetch function to inject `Authorization: Bearer` headers — useful for composing with other fetch middleware. +- Exported previously-internal auth helpers for building custom auth flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. diff --git a/docs/client.md b/docs/client.md index 782ab885b..467df2789 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ A client connects to a server, discovers what it offers — tools, resources, pr The examples below use these imports. Adjust based on which features and transport you need: ```ts source="../examples/client/src/clientGuide.examples.ts#imports" -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -113,7 +113,19 @@ console.log(systemPrompt); ## Authentication -MCP servers can require OAuth 2.0 authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an `authProvider` to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport} to enable this — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). For servers that accept plain bearer tokens, pass a `tokenProvider` function to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. For servers that require OAuth 2.0, pass an `authProvider` — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. + +### Token provider + +For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials, or tokens obtained through a separate auth flow — pass a {@linkcode @modelcontextprotocol/client!client/tokenProvider.TokenProvider | TokenProvider} function. It is called before every request, so it can handle expiry and refresh internally. If the server rejects the token with 401, the transport throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} without retrying — catch it to invalidate any external cache and reconnect: + +```ts source="../examples/client/src/clientGuide.examples.ts#auth_tokenProvider" +const tokenProvider: TokenProvider = async () => getStoredToken(); + +const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); +``` + +See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. For finer control, {@linkcode @modelcontextprotocol/client!client/tokenProvider.withBearerAuth | withBearerAuth} wraps a fetch function directly. ### Client credentials diff --git a/examples/client/src/clientGuide.examples.ts b/examples/client/src/clientGuide.examples.ts index 389059024..c34a3a574 100644 --- a/examples/client/src/clientGuide.examples.ts +++ b/examples/client/src/clientGuide.examples.ts @@ -8,7 +8,7 @@ */ //#region imports -import type { Prompt, Resource, Tool } from '@modelcontextprotocol/client'; +import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -107,6 +107,16 @@ async function serverInstructions_basic(client: Client) { // Authentication // --------------------------------------------------------------------------- +/** Example: TokenProvider for bearer auth with externally-managed tokens. */ +async function auth_tokenProvider(getStoredToken: () => Promise) { + //#region auth_tokenProvider + const tokenProvider: TokenProvider = async () => getStoredToken(); + + const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); + //#endregion auth_tokenProvider + return transport; +} + /** Example: Client credentials auth for service-to-service communication. */ async function auth_clientCredentials() { //#region auth_clientCredentials @@ -540,6 +550,7 @@ void connect_stdio; void connect_sseFallback; void disconnect_streamableHttp; void serverInstructions_basic; +void auth_tokenProvider; void auth_clientCredentials; void auth_privateKeyJwt; void auth_crossAppAccess; diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts new file mode 100644 index 000000000..f6829f556 --- /dev/null +++ b/examples/client/src/simpleTokenProvider.ts @@ -0,0 +1,69 @@ +#!/usr/bin/env node + +/** + * Example demonstrating TokenProvider for simple bearer token authentication. + * + * TokenProvider is a lightweight alternative to OAuthClientProvider for cases + * where tokens are managed externally — e.g., pre-configured API tokens, + * gateway/proxy patterns, or tokens obtained through a separate auth flow. + * + * Environment variables: + * MCP_SERVER_URL - Server URL (default: http://localhost:3000/mcp) + * MCP_TOKEN - Bearer token to use for authentication (required) + * + * Two approaches are demonstrated: + * 1. Using `tokenProvider` option on the transport (simplest) + * 2. Using `withBearerAuth` to wrap a custom fetch function (more flexible) + */ + +import type { TokenProvider } from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport, withBearerAuth } from '@modelcontextprotocol/client'; + +const DEFAULT_SERVER_URL = process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'; + +async function main() { + const token = process.env.MCP_TOKEN; + if (!token) { + console.error('MCP_TOKEN environment variable is required'); + process.exit(1); + } + + // A TokenProvider is just an async function that returns a token string. + // It is called before every request, so it can handle refresh logic internally. + const tokenProvider: TokenProvider = async () => token; + + const client = new Client({ name: 'token-provider-example', version: '1.0.0' }, { capabilities: {} }); + + // Approach 1: Pass tokenProvider directly to the transport. + // This is the simplest way to add bearer auth. + const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { + tokenProvider + }); + + // Approach 2 (alternative): Use withBearerAuth to wrap fetch. + // This is useful when you need more control over the fetch behavior, + // or when composing with other fetch wrappers. + // + // const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { + // fetch: withBearerAuth(tokenProvider), + // }); + + await client.connect(transport); + console.log('Connected successfully.'); + + const tools = await client.listTools(); + console.log('Available tools:', tools.tools.map(t => t.name).join(', ') || '(none)'); + + await transport.close(); +} + +try { + await main(); +} catch (error) { + console.error('Error running client:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); +} + +// Referenced in the commented-out Approach 2 above; kept so uncommenting it type-checks. +void withBearerAuth; diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index 58ec23ddd..c47a57f27 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -381,7 +381,7 @@ export function applyClientAuthentication( /** * Applies HTTP Basic authentication (RFC 6749 Section 2.3.1) */ -function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { +export function applyBasicAuth(clientId: string, clientSecret: string | undefined, headers: Headers): void { if (!clientSecret) { throw new Error('client_secret_basic authentication requires a client_secret'); } @@ -393,7 +393,7 @@ function applyBasicAuth(clientId: string, clientSecret: string | undefined, head /** * Applies POST body authentication (RFC 6749 Section 2.3.1) */ -function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { +export function applyPostAuth(clientId: string, clientSecret: string | undefined, params: URLSearchParams): void { params.set('client_id', clientId); if (clientSecret) { params.set('client_secret', clientSecret); @@ -403,7 +403,7 @@ function applyPostAuth(clientId: string, clientSecret: string | undefined, param /** * Applies public client authentication (RFC 6749 Section 2.1) */ -function applyPublicAuth(clientId: string, params: URLSearchParams): void { +export function applyPublicAuth(clientId: string, params: URLSearchParams): void { params.set('client_id', clientId); } @@ -1265,7 +1265,7 @@ export function prepareAuthorizationCodeRequest( * Internal helper to execute a token request with the given parameters. * Used by {@linkcode exchangeAuthorization}, {@linkcode refreshAuthorization}, and {@linkcode fetchToken}. */ -async function executeTokenRequest( +export async function executeTokenRequest( authorizationServerUrl: string | URL, { metadata, diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 133aa0004..e5b04a258 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -5,6 +5,7 @@ import { EventSource } from 'eventsource'; import type { AuthResult, OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { TokenProvider } from './tokenProvider.js'; export class SseError extends Error { constructor( @@ -36,6 +37,16 @@ export type SSEClientTransportOptions = { */ authProvider?: OAuthClientProvider; + /** + * A simple token provider for bearer authentication. + * + * Use this instead of `authProvider` when tokens are managed externally + * (e.g., upfront auth, gateway/proxy patterns, service accounts). + * + * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. + */ + tokenProvider?: TokenProvider; + /** * Customizes the initial SSE request to the server (the request that begins the stream). * @@ -72,6 +83,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _tokenProvider?: TokenProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -87,6 +99,7 @@ export class SSEClientTransport implements Transport { this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } @@ -123,6 +136,11 @@ export class SSEClientTransport implements Transport { if (tokens) { headers['Authorization'] = `Bearer ${tokens.access_token}`; } + } else if (this._tokenProvider) { + const token = await this._tokenProvider(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } } if (this._protocolVersion) { headers['mcp-protocol-version'] = this._protocolVersion; @@ -161,9 +179,17 @@ export class SSEClientTransport implements Transport { this._abortController = new AbortController(); this._eventSource.onerror = event => { - if (event.code === 401 && this._authProvider) { - this._authThenStart().then(resolve, reject); - return; + if (event.code === 401) { + if (this._authProvider) { + this._authThenStart().then(resolve, reject); + return; + } + if (this._tokenProvider) { + const error = new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + reject(error); + this.onerror?.(error); + return; + } } const error = new SseError(event.code, event.message, event); @@ -263,23 +289,28 @@ export class SSEClientTransport implements Transport { if (!response.ok) { const text = await response.text?.().catch(() => null); - if (response.status === 401 && this._authProvider) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + if (this._authProvider) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); + const result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + scope: this._scope, + fetchFn: this._fetchWithInit + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + if (this._tokenProvider) { + throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); } - - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); } throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index a39a6c1d7..bbaa08ca9 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -15,6 +15,7 @@ import { EventSourceParserStream } from 'eventsource-parser/stream'; import type { AuthResult, OAuthClientProvider } from './auth.js'; import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; +import type { TokenProvider } from './tokenProvider.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -98,6 +99,16 @@ export type StreamableHTTPClientTransportOptions = { */ authProvider?: OAuthClientProvider; + /** + * A simple token provider for bearer authentication. + * + * Use this instead of `authProvider` when tokens are managed externally + * (e.g., upfront auth, gateway/proxy patterns, service accounts). + * + * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. + */ + tokenProvider?: TokenProvider; + /** * Customizes HTTP requests to the server. */ @@ -139,6 +150,7 @@ export class StreamableHTTPClientTransport implements Transport { private _scope?: string; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _tokenProvider?: TokenProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; @@ -159,6 +171,7 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = undefined; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; @@ -198,6 +211,11 @@ export class StreamableHTTPClientTransport implements Transport { if (tokens) { headers['Authorization'] = `Bearer ${tokens.access_token}`; } + } else if (this._tokenProvider) { + const token = await this._tokenProvider(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; + } } if (this._sessionId) { @@ -239,9 +257,13 @@ export class StreamableHTTPClientTransport implements Transport { if (!response.ok) { await response.text?.().catch(() => {}); - if (response.status === 401 && this._authProvider) { - // Need to authenticate - return await this._authThenStart(); + if (response.status === 401) { + if (this._authProvider) { + return await this._authThenStart(); + } + if (this._tokenProvider) { + throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + } } // 405 indicates that the server does not offer an SSE stream at GET endpoint @@ -502,33 +524,42 @@ export class StreamableHTTPClientTransport implements Transport { if (!response.ok) { const text = await response.text?.().catch(() => null); - if (response.status === 401 && this._authProvider) { - // Prevent infinite recursion when server returns 401 after successful auth - if (this._hasCompletedAuthFlow) { - throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after successful authentication', { - status: 401, - text - }); - } + if (response.status === 401) { + if (this._authProvider) { + // Prevent infinite recursion when server returns 401 after successful auth + if (this._hasCompletedAuthFlow) { + throw new SdkError( + SdkErrorCode.ClientHttpAuthentication, + 'Server returned 401 after successful authentication', + { + status: 401, + text + } + ); + } - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; - const result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } + const result = await auth(this._authProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + scope: this._scope, + fetchFn: this._fetchWithInit + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } - // Mark that we completed auth flow - this._hasCompletedAuthFlow = true; - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + // Mark that we completed auth flow + this._hasCompletedAuthFlow = true; + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } + if (this._tokenProvider) { + throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + } } if (response.status === 403 && this._authProvider) { diff --git a/packages/client/src/client/tokenProvider.ts b/packages/client/src/client/tokenProvider.ts new file mode 100644 index 000000000..ab8f2bbc9 --- /dev/null +++ b/packages/client/src/client/tokenProvider.ts @@ -0,0 +1,53 @@ +/** + * Minimal interface for providing bearer tokens to MCP transports. + * + * Unlike `OAuthClientProvider` which assumes interactive browser-redirect OAuth, + * `TokenProvider` is a simple function that returns a token string. + * Use this for upfront auth, gateway/proxy patterns, service accounts, + * or any scenario where tokens are managed externally. + * + * The provider is called before every request. If the server responds with 401, + * the transport throws `UnauthorizedError` without retrying — the provider is + * assumed to have already returned its freshest token. Catch `UnauthorizedError` + * to invalidate any external cache and reconnect. + * + * @example + * ```typescript + * // Static token + * const provider: TokenProvider = async () => "my-api-token"; + * + * // Token from secure storage with refresh + * const provider: TokenProvider = async () => { + * const token = await storage.getToken(); + * if (isExpiringSoon(token)) { + * return (await refreshToken(token)).accessToken; + * } + * return token.accessToken; + * }; + * ``` + */ +export type TokenProvider = () => Promise; + +/** + * Wraps a fetch function to automatically inject Bearer authentication headers. + * + * @example + * ```typescript + * const authedFetch = withBearerAuth(async () => getStoredToken()); + * const transport = new StreamableHTTPClientTransport(url, { fetch: authedFetch }); + * ``` + */ +export function withBearerAuth( + getToken: TokenProvider, + fetchFn: (url: string | URL, init?: RequestInit) => Promise = globalThis.fetch +): (url: string | URL, init?: RequestInit) => Promise { + return async (url, init) => { + const token = await getToken(); + if (token) { + const headers = new Headers(init?.headers); + headers.set('Authorization', `Bearer ${token}`); + return fetchFn(url, { ...init, headers }); + } + return fetchFn(url, init); + }; +} diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index c37d9fe28..b72b3e2d9 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -6,6 +6,7 @@ export * from './client/middleware.js'; export * from './client/sse.js'; export * from './client/stdio.js'; export * from './client/streamableHttp.js'; +export * from './client/tokenProvider.js'; export * from './client/websocket.js'; // experimental exports diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts new file mode 100644 index 000000000..111a7b6a5 --- /dev/null +++ b/packages/client/test/client/tokenProvider.test.ts @@ -0,0 +1,208 @@ +import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import type { Mock } from 'vitest'; + +import type { TokenProvider } from '../../src/client/tokenProvider.js'; +import { withBearerAuth } from '../../src/client/tokenProvider.js'; +import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; + +describe('withBearerAuth', () => { + it('should inject Authorization header when token is available', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + const getToken: TokenProvider = async () => 'test-token-123'; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/api', { method: 'POST' }); + + expect(mockFetch).toHaveBeenCalledOnce(); + const [url, init] = mockFetch.mock.calls[0]!; + expect(url).toBe('https://example.com/api'); + expect(new Headers(init.headers).get('Authorization')).toBe('Bearer test-token-123'); + }); + + it('should not inject Authorization header when token is undefined', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + const getToken: TokenProvider = async () => undefined; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/api', { method: 'POST' }); + + expect(mockFetch).toHaveBeenCalledOnce(); + const [, init] = mockFetch.mock.calls[0]!; + expect(new Headers(init?.headers).has('Authorization')).toBe(false); + }); + + it('should preserve existing headers', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + const getToken: TokenProvider = async () => 'my-token'; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/api', { + headers: { 'Content-Type': 'application/json', 'X-Custom': 'value' } + }); + + const [, init] = mockFetch.mock.calls[0]!; + const headers = new Headers(init.headers); + expect(headers.get('Authorization')).toBe('Bearer my-token'); + expect(headers.get('Content-Type')).toBe('application/json'); + expect(headers.get('X-Custom')).toBe('value'); + }); + + it('should call getToken on every request', async () => { + const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); + let callCount = 0; + const getToken: TokenProvider = async () => `token-${++callCount}`; + + const authedFetch = withBearerAuth(getToken, mockFetch); + await authedFetch('https://example.com/1'); + await authedFetch('https://example.com/2'); + + expect(new Headers(mockFetch.mock.calls[0]![1]!.headers).get('Authorization')).toBe('Bearer token-1'); + expect(new Headers(mockFetch.mock.calls[1]![1]!.headers).get('Authorization')).toBe('Bearer token-2'); + }); +}); + +describe('StreamableHTTPClientTransport with tokenProvider', () => { + let transport: StreamableHTTPClientTransport; + + afterEach(async () => { + await transport?.close().catch(() => {}); + vi.clearAllMocks(); + }); + + it('should set Authorization header from tokenProvider', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => 'my-bearer-token'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + expect(tokenProvider).toHaveBeenCalled(); + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.get('Authorization')).toBe('Bearer my-bearer-token'); + }); + + it('should not set Authorization header when tokenProvider returns undefined', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => undefined); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); + + it('should throw UnauthorizedError on 401 when using tokenProvider', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => 'rejected-token'); + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(tokenProvider).toHaveBeenCalledTimes(1); + }); + + it('should prefer authProvider over tokenProvider when both are set', async () => { + const tokenProvider: TokenProvider = vi.fn(async () => 'token-provider-value'); + const authProvider = { + get redirectUrl() { + return 'http://localhost/callback'; + }, + get clientMetadata() { + return { redirect_uris: ['http://localhost/callback'] }; + }, + clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-secret' })), + tokens: vi.fn(() => ({ access_token: 'auth-provider-value', token_type: 'bearer' })), + saveTokens: vi.fn(), + redirectToAuthorization: vi.fn(), + saveCodeVerifier: vi.fn(), + codeVerifier: vi.fn() + }; + + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider, tokenProvider }); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + // authProvider should be used, not tokenProvider + expect(tokenProvider).not.toHaveBeenCalled(); + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.get('Authorization')).toBe('Bearer auth-provider-value'); + }); + + it('should work with no auth at all', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); + vi.spyOn(globalThis, 'fetch'); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'test', + params: {}, + id: 'test-id' + }; + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers() + }); + + await transport.send(message); + + const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; + expect(init.headers.has('Authorization')).toBe(false); + }); +}); From 0ef50969e39905c73aae7e3b18e45f3d15590d87 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 16:01:11 +0000 Subject: [PATCH 02/13] BREAKING: unify client auth around minimal AuthProvider interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Transports now accept AuthProvider { token(), onUnauthorized() } instead of being typed as OAuthClientProvider. OAuthClientProvider extends AuthProvider, so built-in providers work unchanged — custom implementations add two methods (both TypeScript-enforced). Core changes: - New AuthProvider interface — transports only need token() + onUnauthorized(), not the full 21-member OAuth interface - OAuthClientProvider extends AuthProvider; onUnauthorized() is required (not optional) on OAuthClientProvider since OAuth providers that omit it lose all 401 recovery. The 4 built-in providers implement both methods, delegating to new handleOAuthUnauthorized helper. - Transports call authProvider.token() in _commonHeaders() — one code path, no precedence rules - Transports call authProvider.onUnauthorized() on 401, retry once — ~50 lines of inline OAuth orchestration removed per transport. Circuit breaker via _authRetryInFlight (reset in outer catch so transient onUnauthorized failures don't permanently disable retries). - Response body consumption deferred until after the onUnauthorized branch so custom implementations can read ctx.response.text() - WWW-Authenticate extraction guarded with headers.has() check (pre-existing inconsistency; the SSE connect path already did this) - finishAuth() and 403 upscoping gated on isOAuthClientProvider() - TokenProvider type + tokenProvider option deleted — subsumed by { token: async () => ... } as authProvider Simple case: { authProvider: { token: async () => apiKey } } — no class needed, TypeScript structural typing. auth() and authInternal() (227 LOC of OAuth orchestration) untouched. They still take OAuthClientProvider. Only the transport/provider boundary moved. See docs/migration.md and docs/migration-SKILL.md for before/after. --- .changeset/token-provider-composable-auth.md | 19 +- docs/client.md | 14 +- docs/migration-SKILL.md | 126 ++++++--- docs/migration.md | 155 +++++++---- examples/client/src/clientGuide.examples.ts | 8 +- .../client/src/simpleOAuthClientProvider.ts | 17 +- examples/client/src/simpleTokenProvider.ts | 47 ++-- packages/client/src/client/auth.examples.ts | 10 +- packages/client/src/client/auth.ts | 91 +++++- packages/client/src/client/authExtensions.ts | 35 ++- packages/client/src/client/sse.ts | 141 ++++------ packages/client/src/client/streamableHttp.ts | 155 ++++------- packages/client/src/client/tokenProvider.ts | 53 ---- packages/client/src/index.ts | 1 - packages/client/test/client/auth.test.ts | 10 + .../client/test/client/middleware.test.ts | 4 + packages/client/test/client/sse.test.ts | 124 ++++++++- .../client/test/client/streamableHttp.test.ts | 10 +- .../client/test/client/tokenProvider.test.ts | 263 ++++++++---------- 19 files changed, 733 insertions(+), 550 deletions(-) delete mode 100644 packages/client/src/client/tokenProvider.ts diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md index 50b296298..c4ea7f5e3 100644 --- a/.changeset/token-provider-composable-auth.md +++ b/.changeset/token-provider-composable-auth.md @@ -1,10 +1,17 @@ --- -'@modelcontextprotocol/client': minor +'@modelcontextprotocol/client': major --- -Add `TokenProvider` for simple bearer-token authentication and export composable auth primitives +Unify client auth around a minimal `AuthProvider` interface -- New `TokenProvider` type — a minimal `() => Promise` function interface for supplying bearer tokens. Use this instead of `OAuthClientProvider` when tokens are managed externally (gateway/proxy patterns, service accounts, upfront API tokens, or any scenario where the full OAuth redirect flow is not needed). -- New `tokenProvider` option on `StreamableHTTPClientTransport` and `SSEClientTransport`. Called before every request to obtain a fresh token. If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. -- New `withBearerAuth(getToken, fetchFn?)` helper that wraps a fetch function to inject `Authorization: Bearer` headers — useful for composing with other fetch middleware. -- Exported previously-internal auth helpers for building custom auth flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. +**Breaking:** Transport `authProvider` option now accepts the new minimal `AuthProvider` interface instead of being typed as `OAuthClientProvider`. `OAuthClientProvider` now extends `AuthProvider`, so most existing code continues to work — but custom implementations must add a `token()` method. + +- New `AuthProvider` interface: `{ token(): Promise; onUnauthorized?(ctx): Promise }`. Transports call `token()` before every request and `onUnauthorized()` on 401 (then retry once). +- `OAuthClientProvider` extends `AuthProvider`. Custom implementations must add `token()` (typically `return (await this.tokens())?.access_token`) and optionally `onUnauthorized()` (typically `return handleOAuthUnauthorized(this, ctx)`). +- Built-in providers (`ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, `CrossAppAccessProvider`) implement both methods — existing user code is unchanged. +- New `handleOAuthUnauthorized(provider, ctx)` helper runs the standard OAuth flow from `onUnauthorized`. +- New `isOAuthClientProvider()` type guard for gating OAuth-specific transport features like `finishAuth()`. +- Transports no longer inline OAuth orchestration — ~50 lines of `auth()` calls, WWW-Authenticate parsing, and circuit-breaker state moved into `onUnauthorized()` implementations. +- Exported previously-internal auth helpers for building custom flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. + +See `docs/migration.md` for before/after examples. diff --git a/docs/client.md b/docs/client.md index 467df2789..b5086f531 100644 --- a/docs/client.md +++ b/docs/client.md @@ -13,7 +13,7 @@ A client connects to a server, discovers what it offers — tools, resources, pr The examples below use these imports. Adjust based on which features and transport you need: ```ts source="../examples/client/src/clientGuide.examples.ts#imports" -import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -113,19 +113,19 @@ console.log(systemPrompt); ## Authentication -MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). For servers that accept plain bearer tokens, pass a `tokenProvider` function to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. For servers that require OAuth 2.0, pass an `authProvider` — the SDK provides built-in providers for common machine-to-machine flows, or you can implement the full {@linkcode @modelcontextprotocol/client!client/auth.OAuthClientProvider | OAuthClientProvider} interface for user-facing OAuth. +MCP servers can require authentication before accepting client connections (see [Authorization](https://modelcontextprotocol.io/specification/latest/basic/authorization) in the MCP specification). Pass an {@linkcode @modelcontextprotocol/client!client/auth.AuthProvider | AuthProvider} to {@linkcode @modelcontextprotocol/client!client/streamableHttp.StreamableHTTPClientTransport | StreamableHTTPClientTransport}. The transport calls `token()` before every request and `onUnauthorized()` (if provided) on 401, then retries once. -### Token provider +### Bearer tokens -For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials, or tokens obtained through a separate auth flow — pass a {@linkcode @modelcontextprotocol/client!client/tokenProvider.TokenProvider | TokenProvider} function. It is called before every request, so it can handle expiry and refresh internally. If the server rejects the token with 401, the transport throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} without retrying — catch it to invalidate any external cache and reconnect: +For servers that accept bearer tokens managed outside the SDK — API keys, tokens from a gateway or proxy, service-account credentials — implement only `token()`. With no `onUnauthorized()`, a 401 throws {@linkcode @modelcontextprotocol/client!client/auth.UnauthorizedError | UnauthorizedError} immediately: ```ts source="../examples/client/src/clientGuide.examples.ts#auth_tokenProvider" -const tokenProvider: TokenProvider = async () => getStoredToken(); +const authProvider: AuthProvider = { token: async () => getStoredToken() }; -const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); +const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); ``` -See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. For finer control, {@linkcode @modelcontextprotocol/client!client/tokenProvider.withBearerAuth | withBearerAuth} wraps a fetch function directly. +See [`simpleTokenProvider.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/client/src/simpleTokenProvider.ts) for a complete runnable example. ### Client credentials diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 38a8c372b..d942277eb 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -47,15 +47,15 @@ Replace all `@modelcontextprotocol/sdk/...` imports using this table. ### Server imports -| v1 import path | v2 package | -| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | -| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | +| v1 import path | v2 package | +| ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `@modelcontextprotocol/sdk/server/mcp.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/index.js` | `@modelcontextprotocol/server` | +| `@modelcontextprotocol/sdk/server/stdio.js` | `@modelcontextprotocol/server` | | `@modelcontextprotocol/sdk/server/streamableHttp.js` | `@modelcontextprotocol/node` (class renamed to `NodeStreamableHTTPServerTransport`) OR `@modelcontextprotocol/server` (web-standard `WebStandardStreamableHTTPServerTransport` for Cloudflare Workers, Deno, etc.) | -| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | -| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | -| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | +| `@modelcontextprotocol/sdk/server/sse.js` | REMOVED (migrate to Streamable HTTP) | +| `@modelcontextprotocol/sdk/server/auth/*` | REMOVED (use external auth library) | +| `@modelcontextprotocol/sdk/server/middleware.js` | `@modelcontextprotocol/express` (signature changed, see section 8) | ### Types / shared imports @@ -203,7 +203,41 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) { ... } ``` -**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). +### Client `OAuthClientProvider` now extends `AuthProvider` + +Transport `authProvider` options now accept the minimal `AuthProvider` interface. `OAuthClientProvider` extends it, so built-in providers work unchanged — custom implementations must add `token()`. + +| v1 pattern | v2 equivalent | +| ----------------------------------------------------- | --------------------------------------------------------------------------- | +| `authProvider?: OAuthClientProvider` (option type) | `authProvider?: AuthProvider` (accepts `OAuthClientProvider` via extension) | +| Transport reads `authProvider.tokens()?.access_token` | Transport calls `authProvider.token()` | +| Transport inlines `auth()` on 401 | Transport calls `authProvider.onUnauthorized()` then retries once | +| `_hasCompletedAuthFlow` circuit breaker | `_authRetryInFlight` circuit breaker | +| N/A | `handleOAuthUnauthorized(provider, ctx)` — standard `onUnauthorized` impl | +| N/A | `isOAuthClientProvider(provider)` — type guard | +| N/A | `UnauthorizedContext` — `{ response, serverUrl, fetchFn }` | + +**For custom `OAuthClientProvider` implementations**, add both methods (both required — TypeScript enforces this): + +```typescript +async token(): Promise { + return (await this.tokens())?.access_token; +} + +async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); +} +``` + +**For simple bearer tokens** (previously required stubbing 8 `OAuthClientProvider` members): + +```typescript +// v2: one-liner +const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; +``` + +**Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all +Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). ## 6. McpServer API Changes @@ -283,8 +317,8 @@ Note: the third argument (`metadata`) is required — pass `{}` if no metadata. |----------------|-----------------| | `{ name: z.string() }` | `z.object({ name: z.string() })` | | `{ count: z.number().optional() }` | `z.object({ count: z.number().optional() })` | -| `{}` (empty) | `z.object({})` | -| `undefined` (no schema) | `undefined` or omit the field | +| `{}` (empty) | `z.object({})` | +| `undefined` (no schema) | `undefined` or omit the field | ### Removed core exports @@ -379,31 +413,31 @@ Request/notification params remain fully typed. Remove unused schema imports aft `RequestHandlerExtra` → structured context types with nested groups. Rename `extra` → `ctx` in all handler callbacks. -| v1 | v2 | -|----|-----| -| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | -| `extra` (param name) | `ctx` | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| v1 | v2 | +| -------------------------------- | -------------------------------------------------------------------------- | +| `RequestHandlerExtra` | `ServerContext` (server) / `ClientContext` (client) / `BaseContext` (base) | +| `extra` (param name) | `ctx` | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.requestInfo` | `ctx.http?.req` (only `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only `ServerContext`) | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | `ServerContext` convenience methods (new in v2, no v1 equivalent): -| Method | Description | Replaces | -|--------|-------------|----------| -| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | -| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | -| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | +| Method | Description | Replaces | +| ---------------------------------------------- | ------------------------------------------------------ | ---------------------------------------------------- | +| `ctx.mcpReq.log(level, data, logger?)` | Send log notification (respects client's level filter) | `server.sendLoggingMessage(...)` from within handler | +| `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | +| `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | ## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` @@ -422,14 +456,14 @@ const elicit = await ctx.mcpReq.send({ method: 'elicitation/create', params: { . const tool = await client.callTool({ name: 'my-tool', arguments: {} }); ``` -| v1 call | v2 call | -|---------|---------| -| `client.request(req, ResultSchema)` | `client.request(req)` | -| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | -| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | -| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | -| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | -| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | +| v1 call | v2 call | +| ------------------------------------------------------------ | ---------------------------------- | +| `client.request(req, ResultSchema)` | `client.request(req)` | +| `client.request(req, ResultSchema, options)` | `client.request(req, options)` | +| `ctx.mcpReq.send(req, ResultSchema)` | `ctx.mcpReq.send(req)` | +| `ctx.mcpReq.send(req, ResultSchema, options)` | `ctx.mcpReq.send(req, options)` | +| `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | +| `client.callTool(params, schema, options)` | `client.callTool(params, options)` | Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc., when they were only used in `request()`/`send()`/`callTool()` calls. @@ -440,6 +474,7 @@ Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResu ## 13. Runtime-Specific JSON Schema Validators (Enhancement) The SDK now auto-selects the appropriate JSON Schema validator based on runtime: + - Node.js → `AjvJsonSchemaValidator` (no change from v1) - Cloudflare Workers (workerd) → `CfWorkerJsonSchemaValidator` (previously required manual config) @@ -447,9 +482,12 @@ The SDK now auto-selects the appropriate JSON Schema validator based on runtime: ```typescript // v1 (Cloudflare Workers): Required explicit validator -new McpServer({ name: 'server', version: '1.0.0' }, { - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() -}); +new McpServer( + { name: 'server', version: '1.0.0' }, + { + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() + } +); // v2 (Cloudflare Workers): Auto-selected, explicit config optional new McpServer({ name: 'server', version: '1.0.0' }, {}); diff --git a/docs/migration.md b/docs/migration.md index d2a9db947..6703482f1 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -264,6 +264,7 @@ server.registerTool('ping', { ``` This applies to: + - `inputSchema` in `registerTool()` - `outputSchema` in `registerTool()` - `argsSchema` in `registerPrompt()` @@ -360,25 +361,21 @@ Common method string replacements: ### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter -The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas like `CallToolResultSchema` or `ElicitResultSchema` when making requests. +The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas +like `CallToolResultSchema` or `ElicitResultSchema` when making requests. **`client.request()` — Before (v1):** ```typescript import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, - CallToolResultSchema -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }, CallToolResultSchema); ``` **After (v2):** ```typescript -const result = await client.request( - { method: 'tools/call', params: { name: 'my-tool', arguments: {} } } -); +const result = await client.request({ method: 'tools/call', params: { name: 'my-tool', arguments: {} } }); ``` **`ctx.mcpReq.send()` — Before (v1):** @@ -411,10 +408,7 @@ server.setRequestHandler('tools/call', async (request, ctx) => { ```typescript import { CompatibilityCallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; -const result = await client.callTool( - { name: 'my-tool', arguments: {} }, - CompatibilityCallToolResultSchema -); +const result = await client.callTool({ name: 'my-tool', arguments: {} }, CompatibilityCallToolResultSchema); ``` **After (v2):** @@ -473,32 +467,32 @@ import { JSONRPCErrorResponse, ResourceTemplateReference, isJSONRPCErrorResponse The `RequestHandlerExtra` type has been replaced with a structured context type hierarchy using nested groups: -| v1 | v2 | -|----|-----| +| v1 | v2 | +| ---------------------------------------- | ---------------------------------------------------------------------- | | `RequestHandlerExtra` (flat, all fields) | `ServerContext` (server handlers) or `ClientContext` (client handlers) | -| `extra` parameter name | `ctx` parameter name | -| `extra.signal` | `ctx.mcpReq.signal` | -| `extra.requestId` | `ctx.mcpReq.id` | -| `extra._meta` | `ctx.mcpReq._meta` | -| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | -| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | -| `extra.authInfo` | `ctx.http?.authInfo` | -| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | -| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | -| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | -| `extra.sessionId` | `ctx.sessionId` | -| `extra.taskStore` | `ctx.task?.store` | -| `extra.taskId` | `ctx.task?.id` | -| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | +| `extra` parameter name | `ctx` parameter name | +| `extra.signal` | `ctx.mcpReq.signal` | +| `extra.requestId` | `ctx.mcpReq.id` | +| `extra._meta` | `ctx.mcpReq._meta` | +| `extra.sendRequest(...)` | `ctx.mcpReq.send(...)` | +| `extra.sendNotification(...)` | `ctx.mcpReq.notify(...)` | +| `extra.authInfo` | `ctx.http?.authInfo` | +| `extra.requestInfo` | `ctx.http?.req` (only on `ServerContext`) | +| `extra.closeSSEStream` | `ctx.http?.closeSSE` (only on `ServerContext`) | +| `extra.closeStandaloneSSEStream` | `ctx.http?.closeStandaloneSSE` (only on `ServerContext`) | +| `extra.sessionId` | `ctx.sessionId` | +| `extra.taskStore` | `ctx.task?.store` | +| `extra.taskId` | `ctx.task?.id` | +| `extra.taskRequestedTtl` | `ctx.task?.requestedTtl` | **Before (v1):** ```typescript server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { - const headers = extra.requestInfo?.headers; - const taskStore = extra.taskStore; - await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = extra.requestInfo?.headers; + const taskStore = extra.taskStore; + await extra.sendNotification({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -506,10 +500,10 @@ server.setRequestHandler(CallToolRequestSchema, async (request, extra) => { ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - const headers = ctx.http?.req?.headers; - const taskStore = ctx.task?.store; - await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); - return { content: [{ type: 'text', text: 'result' }] }; + const headers = ctx.http?.req?.headers; + const taskStore = ctx.task?.store; + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 'abc', progress: 50, total: 100 } }); + return { content: [{ type: 'text', text: 'result' }] }; }); ``` @@ -525,22 +519,22 @@ Context fields are organized into 4 groups: ```typescript server.setRequestHandler('tools/call', async (request, ctx) => { - // Send a log message (respects client's log level filter) - await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); - - // Request client to sample an LLM - const samplingResult = await ctx.mcpReq.requestSampling({ - messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], - maxTokens: 100, - }); - - // Elicit user input via a form - const elicitResult = await ctx.mcpReq.elicitInput({ - message: 'Please provide details', - requestedSchema: { type: 'object', properties: { name: { type: 'string' } } }, - }); - - return { content: [{ type: 'text', text: 'done' }] }; + // Send a log message (respects client's log level filter) + await ctx.mcpReq.log('info', 'Processing tool call', 'my-logger'); + + // Request client to sample an LLM + const samplingResult = await ctx.mcpReq.requestSampling({ + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 100 + }); + + // Elicit user input via a form + const elicitResult = await ctx.mcpReq.elicitInput({ + message: 'Please provide details', + requestedSchema: { type: 'object', properties: { name: { type: 'string' } } } + }); + + return { content: [{ type: 'text', text: 'done' }] }; }); ``` @@ -667,13 +661,52 @@ try { #### Why this change? -Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was semantically inconsistent. +Previously, `ErrorCode.RequestTimeout` (-32001) and `ErrorCode.ConnectionClosed` (-32000) were used for local timeout/connection errors. However, these errors never cross the wire as JSON-RPC responses - they are rejected locally. Using protocol error codes for local errors was +semantically inconsistent. The new design: - `ProtocolError` with `ProtocolErrorCode`: For errors that are serialized and sent as JSON-RPC error responses - `SdkError` with `SdkErrorCode`: For local errors that are thrown/rejected locally and never leave the SDK +### Client `authProvider` unified around `AuthProvider` + +Transport `authProvider` options now accept the minimal `AuthProvider` interface rather than being typed as `OAuthClientProvider`. `OAuthClientProvider` extends `AuthProvider`, so built-in providers and most existing code continue to work unchanged — but custom +`OAuthClientProvider` implementations must add a `token()` method. + +**What changed:** transports now call `authProvider.token()` before every request (instead of `authProvider.tokens()?.access_token`), and call `authProvider.onUnauthorized()` on 401 (instead of inlining OAuth orchestration). One code path handles both simple bearer tokens and +full OAuth. + +**If you implement `OAuthClientProvider` directly** (the interactive browser-redirect pattern), add: + +```ts +class MyProvider implements OAuthClientProvider { + // ...existing 8 required members... + + // Required: return the current access token + async token(): Promise { + return (await this.tokens())?.access_token; + } + + // Required: runs the OAuth flow on 401 — without this, 401 throws with no recovery + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } +} +``` + +**If you use `ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, or `CrossAppAccessProvider`** — no change. These already implement both methods. + +**If you have simple bearer tokens** (API keys, gateway tokens, externally-managed tokens), you can now skip `OAuthClientProvider` entirely: + +```ts +// Before: had to implement 8 OAuthClientProvider members with no-op stubs +// After: +const transport = new StreamableHTTPClientTransport(url, { + authProvider: { token: async () => process.env.API_KEY } +}); +``` + ### OAuth error refactoring The OAuth error classes have been consolidated into a single `OAuthError` class with an `OAuthErrorCode` enum. @@ -764,11 +797,11 @@ This means Cloudflare Workers users no longer need to explicitly pass the valida import { McpServer, CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { - capabilities: { tools: {} }, - jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 - } + { name: 'my-server', version: '1.0.0' }, + { + capabilities: { tools: {} }, + jsonSchemaValidator: new CfWorkerJsonSchemaValidator() // Required in v1 + } ); ``` @@ -778,9 +811,9 @@ const server = new McpServer( import { McpServer } from '@modelcontextprotocol/server'; const server = new McpServer( - { name: 'my-server', version: '1.0.0' }, - { capabilities: { tools: {} } } - // Validator auto-selected based on runtime + { name: 'my-server', version: '1.0.0' }, + { capabilities: { tools: {} } } + // Validator auto-selected based on runtime ); ``` diff --git a/examples/client/src/clientGuide.examples.ts b/examples/client/src/clientGuide.examples.ts index c34a3a574..f07d272db 100644 --- a/examples/client/src/clientGuide.examples.ts +++ b/examples/client/src/clientGuide.examples.ts @@ -8,7 +8,7 @@ */ //#region imports -import type { Prompt, Resource, TokenProvider, Tool } from '@modelcontextprotocol/client'; +import type { AuthProvider, Prompt, Resource, Tool } from '@modelcontextprotocol/client'; import { applyMiddlewares, Client, @@ -107,12 +107,12 @@ async function serverInstructions_basic(client: Client) { // Authentication // --------------------------------------------------------------------------- -/** Example: TokenProvider for bearer auth with externally-managed tokens. */ +/** Example: Minimal AuthProvider for bearer auth with externally-managed tokens. */ async function auth_tokenProvider(getStoredToken: () => Promise) { //#region auth_tokenProvider - const tokenProvider: TokenProvider = async () => getStoredToken(); + const authProvider: AuthProvider = { token: async () => getStoredToken() }; - const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { tokenProvider }); + const transport = new StreamableHTTPClientTransport(new URL('http://localhost:3000/mcp'), { authProvider }); //#endregion auth_tokenProvider return transport; } diff --git a/examples/client/src/simpleOAuthClientProvider.ts b/examples/client/src/simpleOAuthClientProvider.ts index 96655c9f6..6248d1f90 100644 --- a/examples/client/src/simpleOAuthClientProvider.ts +++ b/examples/client/src/simpleOAuthClientProvider.ts @@ -1,4 +1,11 @@ -import type { OAuthClientInformationMixed, OAuthClientMetadata, OAuthClientProvider, OAuthTokens } from '@modelcontextprotocol/client'; +import type { + OAuthClientInformationMixed, + OAuthClientMetadata, + OAuthClientProvider, + OAuthTokens, + UnauthorizedContext +} from '@modelcontextprotocol/client'; +import { handleOAuthUnauthorized } from '@modelcontextprotocol/client'; /** * In-memory OAuth client provider for demonstration purposes @@ -24,6 +31,14 @@ export class InMemoryOAuthClientProvider implements OAuthClientProvider { private _onRedirect: (url: URL) => void; + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): string | URL { return this._redirectUrl; } diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts index f6829f556..7b5f1a4c1 100644 --- a/examples/client/src/simpleTokenProvider.ts +++ b/examples/client/src/simpleTokenProvider.ts @@ -1,23 +1,22 @@ #!/usr/bin/env node /** - * Example demonstrating TokenProvider for simple bearer token authentication. + * Example demonstrating the minimal AuthProvider for bearer token authentication. * - * TokenProvider is a lightweight alternative to OAuthClientProvider for cases - * where tokens are managed externally — e.g., pre-configured API tokens, - * gateway/proxy patterns, or tokens obtained through a separate auth flow. + * AuthProvider is the base interface for all client auth. For simple cases where + * tokens are managed externally — pre-configured API tokens, gateway/proxy patterns, + * or tokens obtained through a separate auth flow — implement only `token()`. + * + * For OAuth flows (client_credentials, private_key_jwt, etc.), use the built-in + * providers which implement both `token()` and `onUnauthorized()`. * * Environment variables: * MCP_SERVER_URL - Server URL (default: http://localhost:3000/mcp) * MCP_TOKEN - Bearer token to use for authentication (required) - * - * Two approaches are demonstrated: - * 1. Using `tokenProvider` option on the transport (simplest) - * 2. Using `withBearerAuth` to wrap a custom fetch function (more flexible) */ -import type { TokenProvider } from '@modelcontextprotocol/client'; -import { Client, StreamableHTTPClientTransport, withBearerAuth } from '@modelcontextprotocol/client'; +import type { AuthProvider } from '@modelcontextprotocol/client'; +import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; const DEFAULT_SERVER_URL = process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'; @@ -28,25 +27,16 @@ async function main() { process.exit(1); } - // A TokenProvider is just an async function that returns a token string. - // It is called before every request, so it can handle refresh logic internally. - const tokenProvider: TokenProvider = async () => token; + // AuthProvider with just token() — the simplest possible auth. + // token() is called before every request, so it can handle refresh internally. + // With no onUnauthorized(), a 401 throws UnauthorizedError immediately. + const authProvider: AuthProvider = { + token: async () => token + }; - const client = new Client({ name: 'token-provider-example', version: '1.0.0' }, { capabilities: {} }); + const client = new Client({ name: 'auth-provider-example', version: '1.0.0' }, { capabilities: {} }); - // Approach 1: Pass tokenProvider directly to the transport. - // This is the simplest way to add bearer auth. - const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { - tokenProvider - }); - - // Approach 2 (alternative): Use withBearerAuth to wrap fetch. - // This is useful when you need more control over the fetch behavior, - // or when composing with other fetch wrappers. - // - // const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { - // fetch: withBearerAuth(tokenProvider), - // }); + const transport = new StreamableHTTPClientTransport(new URL(DEFAULT_SERVER_URL), { authProvider }); await client.connect(transport); console.log('Connected successfully.'); @@ -64,6 +54,3 @@ try { // eslint-disable-next-line unicorn/no-process-exit process.exit(1); } - -// Referenced in the commented-out Approach 2 above; kept so uncommenting it type-checks. -void withBearerAuth; diff --git a/packages/client/src/client/auth.examples.ts b/packages/client/src/client/auth.examples.ts index 17c04e6a0..15b6487a7 100644 --- a/packages/client/src/client/auth.examples.ts +++ b/packages/client/src/client/auth.examples.ts @@ -9,8 +9,8 @@ import type { AuthorizationServerMetadata } from '@modelcontextprotocol/core'; -import type { OAuthClientProvider } from './auth.js'; -import { fetchToken } from './auth.js'; +import type { OAuthClientProvider, UnauthorizedContext } from './auth.js'; +import { fetchToken, handleOAuthUnauthorized } from './auth.js'; /** * Base class providing no-op implementations of required OAuthClientProvider methods. @@ -29,6 +29,12 @@ abstract class MyProviderBase implements OAuthClientProvider { tokens(): undefined { return; } + async token(): Promise { + return undefined; + } + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } saveTokens() { return Promise.resolve(); } diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index c47a57f27..bca45ad66 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -34,14 +34,103 @@ export type AddClientAuthentication = ( metadata?: AuthorizationServerMetadata ) => void | Promise; +/** + * Context passed to {@linkcode AuthProvider.onUnauthorized} when the server + * responds with 401. Provides everything needed to refresh credentials. + */ +export interface UnauthorizedContext { + /** The 401 response — inspect `WWW-Authenticate` for resource metadata, scope, etc. */ + response: Response; + /** The MCP server URL, for passing to {@linkcode auth} or discovery helpers. */ + serverUrl: URL; + /** Fetch function configured with the transport's `requestInit`, for making auth requests. */ + fetchFn: FetchLike; +} + +/** + * Minimal interface for authenticating MCP client transports with bearer tokens. + * + * Transports call {@linkcode AuthProvider.token | token()} before every request + * to obtain the current token, and {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * (if provided) when the server responds with 401, giving the provider a chance + * to refresh credentials before the transport retries once. + * + * For simple cases (API keys, gateway-managed tokens), implement only `token()`: + * ```typescript + * const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; + * ``` + * + * For OAuth flows, use {@linkcode OAuthClientProvider} which extends this interface, + * or one of the built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.). + */ +export interface AuthProvider { + /** + * Returns the current bearer token, or `undefined` if no token is available. + * Called before every request. + */ + token(): Promise; + + /** + * Called when the server responds with 401. If provided, the transport will + * await this, then retry the request once. If the retry also gets 401, or if + * this method is not provided, the transport throws {@linkcode UnauthorizedError}. + * + * Implementations should refresh tokens, re-authenticate, etc. — whatever is + * needed so the next `token()` call returns a valid token. + */ + onUnauthorized?(ctx: UnauthorizedContext): Promise; +} + +/** + * Type guard: checks whether an `AuthProvider` is a full `OAuthClientProvider`. + * Use this to gate OAuth-specific transport features like `finishAuth()` and + * 403 scope upscoping. + */ +export function isOAuthClientProvider(provider: AuthProvider | undefined): provider is OAuthClientProvider { + return provider !== undefined && 'tokens' in provider && 'clientMetadata' in provider; +} + +/** + * Default `onUnauthorized` implementation for OAuth providers: extracts + * `WWW-Authenticate` parameters from the 401 response and runs {@linkcode auth}. + * Built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.) + * delegate to this. Custom `OAuthClientProvider` implementations can do the same. + */ +export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx: UnauthorizedContext): Promise { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(ctx.response); + const result = await auth(provider, { + serverUrl: ctx.serverUrl, + resourceMetadataUrl, + scope, + fetchFn: ctx.fetchFn + }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } +} + /** * Implements an end-to-end OAuth client to be used with one MCP server. * * This client relies upon a concept of an authorized "session," the exact * meaning of which is application-defined. Tokens, authorization codes, and * code verifiers should not cross different sessions. + * + * Extends {@linkcode AuthProvider} — implementations must provide `token()` + * (typically `return (await this.tokens())?.access_token`) and `onUnauthorized()` + * (typically `return handleOAuthUnauthorized(this, ctx)`). Without `onUnauthorized()`, + * 401 responses throw immediately with no token refresh or reauth. */ -export interface OAuthClientProvider { +export interface OAuthClientProvider extends AuthProvider { + /** + * Runs the OAuth re-authentication flow on 401. Required on `OAuthClientProvider` + * (optional on the base `AuthProvider`) because OAuth providers that omit this lose + * all 401 recovery — no token refresh, no redirect to authorization. + * + * Most implementations should delegate: `return handleOAuthUnauthorized(this, ctx)`. + */ + onUnauthorized(ctx: UnauthorizedContext): Promise; + /** * The URL to redirect the user agent to after authorization. * Return `undefined` for non-interactive flows that don't require user interaction diff --git a/packages/client/src/client/authExtensions.ts b/packages/client/src/client/authExtensions.ts index ae614f7ba..7508298b7 100644 --- a/packages/client/src/client/authExtensions.ts +++ b/packages/client/src/client/authExtensions.ts @@ -8,7 +8,8 @@ import type { FetchLike, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; import type { CryptoKey, JWK } from 'jose'; -import type { AddClientAuthentication, OAuthClientProvider } from './auth.js'; +import type { AddClientAuthentication, OAuthClientProvider, UnauthorizedContext } from './auth.js'; +import { handleOAuthUnauthorized } from './auth.js'; /** * Helper to produce a `private_key_jwt` client authentication function. @@ -150,6 +151,14 @@ export class ClientCredentialsProvider implements OAuthClientProvider { }; } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } @@ -269,6 +278,14 @@ export class PrivateKeyJwtProvider implements OAuthClientProvider { }); } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } @@ -366,6 +383,14 @@ export class StaticPrivateKeyJwtProvider implements OAuthClientProvider { }; } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } @@ -564,6 +589,14 @@ export class CrossAppAccessProvider implements OAuthClientProvider { this._fetchFn = options.fetchFn ?? fetch; } + async token(): Promise { + return this._tokens?.access_token; + } + + async onUnauthorized(ctx: UnauthorizedContext): Promise { + await handleOAuthUnauthorized(this, ctx); + } + get redirectUrl(): undefined { return undefined; } diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index e5b04a258..025c785ea 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -3,9 +3,8 @@ import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, SdkError, import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; -import type { TokenProvider } from './tokenProvider.js'; +import type { AuthProvider } from './auth.js'; +import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { constructor( @@ -24,28 +23,19 @@ export type SSEClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the SSE connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode SSEClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode SSEClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected. - */ - authProvider?: OAuthClientProvider; - - /** - * A simple token provider for bearer authentication. - * - * Use this instead of `authProvider` when tokens are managed externally - * (e.g., upfront auth, gateway/proxy patterns, service accounts). - * - * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation. + * Interactive flows: after {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode SSEClientTransport.finishAuth | finishAuth} with the authorization code before reconnecting. */ - tokenProvider?: TokenProvider; + authProvider?: AuthProvider; /** * Customizes the initial SSE request to the server (the request that begins the stream). @@ -82,8 +72,7 @@ export class SSEClientTransport implements Transport { private _scope?: string; private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; - private _tokenProvider?: TokenProvider; + private _authProvider?: AuthProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -99,48 +88,18 @@ export class SSEClientTransport implements Transport { this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; - this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuth(); - } + private _authRetryInFlight = false; + private _last401Response?: Response; private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } - } else if (this._tokenProvider) { - const token = await this._tokenProvider(); - if (token) { - headers['Authorization'] = `Bearer ${token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._protocolVersion) { headers['mcp-protocol-version'] = this._protocolVersion; @@ -167,10 +126,13 @@ export class SSEClientTransport implements Transport { headers }); - if (response.status === 401 && response.headers.has('www-authenticate')) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; + if (response.status === 401) { + this._last401Response = response; + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; + } } return response; @@ -179,17 +141,23 @@ export class SSEClientTransport implements Transport { this._abortController = new AbortController(); this._eventSource.onerror = event => { - if (event.code === 401) { - if (this._authProvider) { - this._authThenStart().then(resolve, reject); - return; - } - if (this._tokenProvider) { - const error = new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); - reject(error); - this.onerror?.(error); + if (event.code === 401 && this._authProvider) { + if (this._authProvider.onUnauthorized && this._last401Response && !this._authRetryInFlight) { + this._authRetryInFlight = true; + const response = this._last401Response; + this._authProvider + .onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }) + .then(() => this._startOrAuth()) + .then(resolve, reject) + .finally(() => { + this._authRetryInFlight = false; + }); return; } + const error = new UnauthorizedError(); + reject(error); + this.onerror?.(error); + return; } const error = new SseError(event.code, event.message, event); @@ -247,8 +215,8 @@ export class SSEClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!isOAuthClientProvider(this._authProvider)) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } const result = await auth(this._authProvider, { @@ -287,38 +255,39 @@ export class SSEClientTransport implements Transport { const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { - const text = await response.text?.().catch(() => null); - if (response.status === 401) { - if (this._authProvider) { + if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + } - const result = await auth(this._authProvider, { + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, fetchFn: this._fetchWithInit }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } - if (this._tokenProvider) { - throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + if (this._authProvider) { + await response.text?.().catch(() => {}); + throw new UnauthorizedError(); } } + const text = await response.text?.().catch(() => null); throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); } + this._authRetryInFlight = false; + // Release connection - POST responses don't have content we need await response.text?.().catch(() => {}); } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index bbaa08ca9..e6443a90b 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -13,9 +13,8 @@ import { } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; -import type { AuthResult, OAuthClientProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js'; -import type { TokenProvider } from './tokenProvider.js'; +import type { AuthProvider } from './auth.js'; +import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -86,28 +85,20 @@ export type StreamableHTTPClientTransportOptions = { /** * An OAuth client provider to use for authentication. * - * When an `authProvider` is specified and the connection is started: - * 1. The connection is attempted with any existing access token from the `authProvider`. - * 2. If the access token has expired, the `authProvider` is used to refresh the token. - * 3. If token refresh fails or no access token exists, and auth is required, {@linkcode OAuthClientProvider.redirectToAuthorization} is called, and an {@linkcode UnauthorizedError} will be thrown from {@linkcode index.Protocol.connect | connect}/{@linkcode StreamableHTTPClientTransport.start | start}. + * {@linkcode AuthProvider.token | token()} is called before every request to obtain the + * bearer token. When the server responds with 401, {@linkcode AuthProvider.onUnauthorized | onUnauthorized()} + * is called (if provided) to refresh credentials, then the request is retried once. If + * the retry also gets 401, or `onUnauthorized` is not provided, {@linkcode UnauthorizedError} + * is thrown. * - * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call {@linkcode StreamableHTTPClientTransport.finishAuth} with the authorization code before retrying the connection. + * For simple bearer tokens: `{ token: async () => myApiKey }`. * - * If an `authProvider` is not provided, and auth is required, an {@linkcode UnauthorizedError} will be thrown. - * - * {@linkcode UnauthorizedError} might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. + * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation + * (which extends `AuthProvider`). Interactive flows: after {@linkcode UnauthorizedError}, redirect the + * user, then call {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization + * code before reconnecting. */ - authProvider?: OAuthClientProvider; - - /** - * A simple token provider for bearer authentication. - * - * Use this instead of `authProvider` when tokens are managed externally - * (e.g., upfront auth, gateway/proxy patterns, service accounts). - * - * If both `authProvider` and `tokenProvider` are set, `authProvider` takes precedence. - */ - tokenProvider?: TokenProvider; + authProvider?: AuthProvider; /** * Customizes HTTP requests to the server. @@ -149,14 +140,13 @@ export class StreamableHTTPClientTransport implements Transport { private _resourceMetadataUrl?: URL; private _scope?: string; private _requestInit?: RequestInit; - private _authProvider?: OAuthClientProvider; - private _tokenProvider?: TokenProvider; + private _authProvider?: AuthProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; - private _hasCompletedAuthFlow = false; // Circuit breaker: detect auth success followed by immediate 401 + private _authRetryInFlight = false; // Circuit breaker: single retry per operation on 401 private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; @@ -171,7 +161,6 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = undefined; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; - this._tokenProvider = opts?.tokenProvider; this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; @@ -179,43 +168,11 @@ export class StreamableHTTPClientTransport implements Transport { this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } - private async _authThenStart(): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); - } - - let result: AuthResult; - try { - result = await auth(this._authProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - } catch (error) { - this.onerror?.(error as Error); - throw error; - } - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return await this._startOrAuthSse({ resumptionToken: undefined }); - } - private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; - if (this._authProvider) { - const tokens = await this._authProvider.tokens(); - if (tokens) { - headers['Authorization'] = `Bearer ${tokens.access_token}`; - } - } else if (this._tokenProvider) { - const token = await this._tokenProvider(); - if (token) { - headers['Authorization'] = `Bearer ${token}`; - } + const token = await this._authProvider?.token(); + if (token) { + headers['Authorization'] = `Bearer ${token}`; } if (this._sessionId) { @@ -255,17 +212,34 @@ export class StreamableHTTPClientTransport implements Transport { }); if (!response.ok) { - await response.text?.().catch(() => {}); - if (response.status === 401) { - if (this._authProvider) { - return await this._authThenStart(); + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; } - if (this._tokenProvider) { - throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + try { + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + return await this._startOrAuthSse(options); + } finally { + this._authRetryInFlight = false; + } + } + if (this._authProvider) { + await response.text?.().catch(() => {}); + throw new UnauthorizedError(); } } + await response.text?.().catch(() => {}); + // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { @@ -461,8 +435,8 @@ export class StreamableHTTPClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!this._authProvider) { - throw new UnauthorizedError('No auth provider'); + if (!isOAuthClientProvider(this._authProvider)) { + throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } const result = await auth(this._authProvider, { @@ -522,47 +496,33 @@ export class StreamableHTTPClientTransport implements Transport { } if (!response.ok) { - const text = await response.text?.().catch(() => null); - if (response.status === 401) { - if (this._authProvider) { - // Prevent infinite recursion when server returns 401 after successful auth - if (this._hasCompletedAuthFlow) { - throw new SdkError( - SdkErrorCode.ClientHttpAuthentication, - 'Server returned 401 after successful authentication', - { - status: 401, - text - } - ); - } - + // Store WWW-Authenticate params for interactive finishAuth() path + if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; + } - const result = await auth(this._authProvider, { + if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + this._authRetryInFlight = true; + await this._authProvider.onUnauthorized({ + response, serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, fetchFn: this._fetchWithInit }); - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - // Mark that we completed auth flow - this._hasCompletedAuthFlow = true; // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } - if (this._tokenProvider) { - throw new UnauthorizedError('Server returned 401 — token from tokenProvider was rejected'); + if (this._authProvider) { + await response.text?.().catch(() => {}); + throw new UnauthorizedError(); } } - if (response.status === 403 && this._authProvider) { + const text = await response.text?.().catch(() => null); + + if (response.status === 403 && isOAuthClientProvider(this._authProvider)) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); if (error === 'insufficient_scope') { @@ -608,7 +568,7 @@ export class StreamableHTTPClientTransport implements Transport { } // Reset auth loop flag on successful response - this._hasCompletedAuthFlow = false; + this._authRetryInFlight = false; this._lastUpscopingHeader = undefined; // If the response is 202 Accepted, there's no body to process @@ -658,6 +618,7 @@ export class StreamableHTTPClientTransport implements Transport { await response.text?.().catch(() => {}); } } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/src/client/tokenProvider.ts b/packages/client/src/client/tokenProvider.ts deleted file mode 100644 index ab8f2bbc9..000000000 --- a/packages/client/src/client/tokenProvider.ts +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Minimal interface for providing bearer tokens to MCP transports. - * - * Unlike `OAuthClientProvider` which assumes interactive browser-redirect OAuth, - * `TokenProvider` is a simple function that returns a token string. - * Use this for upfront auth, gateway/proxy patterns, service accounts, - * or any scenario where tokens are managed externally. - * - * The provider is called before every request. If the server responds with 401, - * the transport throws `UnauthorizedError` without retrying — the provider is - * assumed to have already returned its freshest token. Catch `UnauthorizedError` - * to invalidate any external cache and reconnect. - * - * @example - * ```typescript - * // Static token - * const provider: TokenProvider = async () => "my-api-token"; - * - * // Token from secure storage with refresh - * const provider: TokenProvider = async () => { - * const token = await storage.getToken(); - * if (isExpiringSoon(token)) { - * return (await refreshToken(token)).accessToken; - * } - * return token.accessToken; - * }; - * ``` - */ -export type TokenProvider = () => Promise; - -/** - * Wraps a fetch function to automatically inject Bearer authentication headers. - * - * @example - * ```typescript - * const authedFetch = withBearerAuth(async () => getStoredToken()); - * const transport = new StreamableHTTPClientTransport(url, { fetch: authedFetch }); - * ``` - */ -export function withBearerAuth( - getToken: TokenProvider, - fetchFn: (url: string | URL, init?: RequestInit) => Promise = globalThis.fetch -): (url: string | URL, init?: RequestInit) => Promise { - return async (url, init) => { - const token = await getToken(); - if (token) { - const headers = new Headers(init?.headers); - headers.set('Authorization', `Bearer ${token}`); - return fetchFn(url, { ...init, headers }); - } - return fetchFn(url, init); - }; -} diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index b72b3e2d9..c37d9fe28 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -6,7 +6,6 @@ export * from './client/middleware.js'; export * from './client/sse.js'; export * from './client/stdio.js'; export * from './client/streamableHttp.js'; -export * from './client/tokenProvider.js'; export * from './client/websocket.js'; // experimental exports diff --git a/packages/client/test/client/auth.test.ts b/packages/client/test/client/auth.test.ts index 9d8f5cf6b..12d6793af 100644 --- a/packages/client/test/client/auth.test.ts +++ b/packages/client/test/client/auth.test.ts @@ -1038,6 +1038,8 @@ describe('OAuth Authorization', () => { client_secret: 'test-client-secret' }), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -1983,6 +1985,8 @@ describe('OAuth Authorization', () => { }, clientInformation: vi.fn(), tokens: vi.fn(), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2056,6 +2060,8 @@ describe('OAuth Authorization', () => { client_id: 'client-id' }), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2971,6 +2977,8 @@ describe('OAuth Authorization', () => { client_secret: 'secret123' }), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -3424,6 +3432,8 @@ describe('OAuth Authorization', () => { clientInformation: vi.fn().mockResolvedValue(undefined), saveClientInformation: vi.fn().mockResolvedValue(undefined), tokens: vi.fn().mockResolvedValue(undefined), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn().mockResolvedValue(undefined), saveCodeVerifier: vi.fn().mockResolvedValue(undefined), diff --git a/packages/client/test/client/middleware.test.ts b/packages/client/test/client/middleware.test.ts index 64bbfa673..d2084af99 100644 --- a/packages/client/test/client/middleware.test.ts +++ b/packages/client/test/client/middleware.test.ts @@ -33,6 +33,8 @@ describe('withOAuth', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), @@ -759,6 +761,8 @@ describe('Integration Tests', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), + token: vi.fn(async () => undefined), + onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 0b0aff67b..3e1a3f895 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -7,8 +7,8 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; -import type { OAuthClientProvider } from '../../src/client/auth.js'; -import { UnauthorizedError } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; +import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; import { SSEClientTransport } from '../../src/client/sse.js'; /** @@ -430,11 +430,15 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), + token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn() + invalidateCredentials: vi.fn(), + onUnauthorized: vi.fn(async ctx => { + await handleOAuthUnauthorized(mockAuthProvider, ctx); + }) }; }); @@ -443,6 +447,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -451,7 +456,7 @@ describe('SSEClientTransport', () => { await transport.start(); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); + expect(mockAuthProvider.token).toHaveBeenCalled(); }); it('attaches custom header from provider on initial SSE connection', async () => { @@ -459,6 +464,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' }; @@ -474,7 +480,7 @@ describe('SSEClientTransport', () => { expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); expect(lastServerRequest.headers['x-custom-header']).toBe('custom-value'); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); + expect(mockAuthProvider.token).toHaveBeenCalled(); }); it('attaches auth header from provider on POST requests', async () => { @@ -482,6 +488,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -499,7 +506,7 @@ describe('SSEClientTransport', () => { await transport.send(message); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.tokens).toHaveBeenCalled(); + expect(mockAuthProvider.token).toHaveBeenCalled(); }); it('attempts auth flow on 401 during SSE connection', async () => { @@ -631,6 +638,7 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); + mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' @@ -666,6 +674,7 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -795,6 +804,7 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -948,6 +958,7 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); + mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -1218,11 +1229,15 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn().mockResolvedValue(clientInfo), tokens: vi.fn().mockResolvedValue(tokens), + token: vi.fn(async () => tokens?.access_token), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn().mockResolvedValue('test-verifier'), - invalidateCredentials: vi.fn() + invalidateCredentials: vi.fn(), + onUnauthorized: vi.fn(async ctx => { + await handleOAuthUnauthorized(mockAuthProvider, ctx); + }) }; }; @@ -1528,4 +1543,99 @@ describe('SSEClientTransport', () => { expect(globalFetchSpy).not.toHaveBeenCalled(); }); }); + + describe('minimal AuthProvider (non-OAuth)', () => { + let postResponses: number[]; + let postCount: number; + + async function setupServer(): Promise { + await resourceServer.close(); + + postCount = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.method === 'GET') { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}post\n\n`); + return; + } + + if (req.method === 'POST') { + const status = postResponses[postCount] ?? 200; + postCount++; + res.writeHead(status).end(); + return; + } + }); + + resourceBaseUrl = await listenOnRandomPort(resourceServer); + } + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: '1' }; + + it('throws UnauthorizedError on POST 401 when onUnauthorized is not provided', async () => { + postResponses = [401]; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + }); + + it('enforces circuit breaker on double-401: onUnauthorized called once, then throws', async () => { + postResponses = [401, 401]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(postCount).toBe(2); + }); + + it('resets retry guard when onUnauthorized throws, allowing retry on next send', async () => { + postResponses = [401, 401, 200]; + await setupServer(); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + // First send: 401 → onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + + // Second send: flag should be reset, so 401 → onUnauthorized (succeeds) → retry → 200 + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(postCount).toBe(3); + }); + + it('throws when finishAuth is called with a non-OAuth AuthProvider', async () => { + postResponses = []; + await setupServer(); + + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + await transport.start(); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + }); }); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 3830fcd1d..4059bb981 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -3,7 +3,7 @@ import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontex import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; -import { UnauthorizedError } from '../../src/client/auth.js'; +import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -21,11 +21,15 @@ describe('StreamableHTTPClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), + token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn() + invalidateCredentials: vi.fn(), + onUnauthorized: vi.fn(async ctx => { + await handleOAuthUnauthorized(mockAuthProvider, ctx); + }) }; transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); vi.spyOn(globalThis, 'fetch'); @@ -1705,7 +1709,7 @@ describe('StreamableHTTPClientTransport', () => { // Retry the original request - still 401 (broken server) .mockResolvedValueOnce(unauthedResponse); - await expect(transport.send(message)).rejects.toThrow('Server returned 401 after successful authentication'); + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ access_token: 'new-access-token', token_type: 'Bearer', diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts index 111a7b6a5..c683a4012 100644 --- a/packages/client/test/client/tokenProvider.test.ts +++ b/packages/client/test/client/tokenProvider.test.ts @@ -1,68 +1,11 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; import type { Mock } from 'vitest'; -import type { TokenProvider } from '../../src/client/tokenProvider.js'; -import { withBearerAuth } from '../../src/client/tokenProvider.js'; -import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; +import type { AuthProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; +import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; -describe('withBearerAuth', () => { - it('should inject Authorization header when token is available', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - const getToken: TokenProvider = async () => 'test-token-123'; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/api', { method: 'POST' }); - - expect(mockFetch).toHaveBeenCalledOnce(); - const [url, init] = mockFetch.mock.calls[0]!; - expect(url).toBe('https://example.com/api'); - expect(new Headers(init.headers).get('Authorization')).toBe('Bearer test-token-123'); - }); - - it('should not inject Authorization header when token is undefined', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - const getToken: TokenProvider = async () => undefined; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/api', { method: 'POST' }); - - expect(mockFetch).toHaveBeenCalledOnce(); - const [, init] = mockFetch.mock.calls[0]!; - expect(new Headers(init?.headers).has('Authorization')).toBe(false); - }); - - it('should preserve existing headers', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - const getToken: TokenProvider = async () => 'my-token'; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/api', { - headers: { 'Content-Type': 'application/json', 'X-Custom': 'value' } - }); - - const [, init] = mockFetch.mock.calls[0]!; - const headers = new Headers(init.headers); - expect(headers.get('Authorization')).toBe('Bearer my-token'); - expect(headers.get('Content-Type')).toBe('application/json'); - expect(headers.get('X-Custom')).toBe('value'); - }); - - it('should call getToken on every request', async () => { - const mockFetch = vi.fn().mockResolvedValue(new Response('ok')); - let callCount = 0; - const getToken: TokenProvider = async () => `token-${++callCount}`; - - const authedFetch = withBearerAuth(getToken, mockFetch); - await authedFetch('https://example.com/1'); - await authedFetch('https://example.com/2'); - - expect(new Headers(mockFetch.mock.calls[0]![1]!.headers).get('Authorization')).toBe('Bearer token-1'); - expect(new Headers(mockFetch.mock.calls[1]![1]!.headers).get('Authorization')).toBe('Bearer token-2'); - }); -}); - -describe('StreamableHTTPClientTransport with tokenProvider', () => { +describe('StreamableHTTPClientTransport with AuthProvider', () => { let transport: StreamableHTTPClientTransport; afterEach(async () => { @@ -70,48 +13,28 @@ describe('StreamableHTTPClientTransport with tokenProvider', () => { vi.clearAllMocks(); }); - it('should set Authorization header from tokenProvider', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => 'my-bearer-token'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'test', params: {}, id: 'test-id' }; + + it('should set Authorization header from AuthProvider.token()', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'my-bearer-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send(message); - expect(tokenProvider).toHaveBeenCalled(); + expect(authProvider.token).toHaveBeenCalled(); const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; expect(init.headers.get('Authorization')).toBe('Bearer my-bearer-token'); }); - it('should not set Authorization header when tokenProvider returns undefined', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => undefined); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + it('should not set Authorization header when token() returns undefined', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => undefined) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send(message); @@ -119,18 +42,11 @@ describe('StreamableHTTPClientTransport with tokenProvider', () => { expect(init.headers.has('Authorization')).toBe(false); }); - it('should throw UnauthorizedError on 401 when using tokenProvider', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => 'rejected-token'); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { tokenProvider }); + it('should throw UnauthorizedError on 401 when onUnauthorized is not provided', async () => { + const authProvider: AuthProvider = { token: vi.fn(async () => 'rejected-token') }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: false, status: 401, @@ -139,70 +55,125 @@ describe('StreamableHTTPClientTransport with tokenProvider', () => { }); await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); - expect(tokenProvider).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(1); }); - it('should prefer authProvider over tokenProvider when both are set', async () => { - const tokenProvider: TokenProvider = vi.fn(async () => 'token-provider-value'); - const authProvider = { - get redirectUrl() { - return 'http://localhost/callback'; - }, - get clientMetadata() { - return { redirect_uris: ['http://localhost/callback'] }; - }, - clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-secret' })), - tokens: vi.fn(() => ({ access_token: 'auth-provider-value', token_type: 'bearer' })), - saveTokens: vi.fn(), - redirectToAuthorization: vi.fn(), - saveCodeVerifier: vi.fn(), - codeVerifier: vi.fn() + it('should call onUnauthorized and retry once on 401', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); + + await transport.send(message); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); - transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider, tokenProvider }); + it('should throw UnauthorizedError if retry after onUnauthorized also gets 401', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'still-bad'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + }); + + it('should reset retry guard when onUnauthorized throws, allowing retry on next send', async () => { + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn().mockRejectedValueOnce(new Error('transient network error')).mockResolvedValueOnce(undefined) }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); - await transport.send(message); + // First send: onUnauthorized throws transient error + await expect(transport.send(message)).rejects.toThrow('transient network error'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); - // authProvider should be used, not tokenProvider - expect(tokenProvider).not.toHaveBeenCalled(); - const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; - expect(init.headers.get('Authorization')).toBe('Bearer auth-provider-value'); + // Second send: flag should be reset, so onUnauthorized gets a second chance + await transport.send(message); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); }); - it('should work with no auth at all', async () => { + it('should work with no authProvider at all', async () => { transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp')); vi.spyOn(globalThis, 'fetch'); - const message: JSONRPCMessage = { - jsonrpc: '2.0', - method: 'test', - params: {}, - id: 'test-id' - }; - - (globalThis.fetch as Mock).mockResolvedValueOnce({ - ok: true, - status: 202, - headers: new Headers() - }); + (globalThis.fetch as Mock).mockResolvedValueOnce({ ok: true, status: 202, headers: new Headers() }); await transport.send(message); const [, init] = (globalThis.fetch as Mock).mock.calls[0]!; expect(init.headers.has('Authorization')).toBe(false); }); + + it('should throw when finishAuth is called with a non-OAuth AuthProvider', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + + await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); + }); + + it('should throw UnauthorizedError on GET-SSE 401 with no onUnauthorized (via resumeStream)', async () => { + const authProvider: AuthProvider = { token: async () => 'api-key' }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + (globalThis.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + headers: new Headers(), + text: async () => 'unauthorized' + }); + + await expect(transport.resumeStream('last-event-id')).rejects.toThrow(UnauthorizedError); + }); + + it('should call onUnauthorized and retry on GET-SSE 401 (via resumeStream)', async () => { + let currentToken = 'old-token'; + const authProvider: AuthProvider = { + token: vi.fn(async () => currentToken), + onUnauthorized: vi.fn(async () => { + currentToken = 'new-token'; + }) + }; + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider }); + vi.spyOn(globalThis, 'fetch'); + + // First GET: 401. Second GET (retry): 405 (server doesn't offer SSE — clean exit) + (globalThis.fetch as Mock) + .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) + .mockResolvedValueOnce({ ok: false, status: 405, headers: new Headers(), text: async () => '' }); + + await transport.resumeStream('last-event-id'); + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); + expect(authProvider.token).toHaveBeenCalledTimes(2); + const [, retryInit] = (globalThis.fetch as Mock).mock.calls[1]!; + expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); + }); }); From 91d03cafca4e110afee4a67f7a341f2a1d9fd46b Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 21:44:00 +0000 Subject: [PATCH 03/13] refactor: adapt OAuthClientProvider at transport boundary (non-breaking) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Alternative to the breaking 'extends AuthProvider' approach. Instead of requiring OAuthClientProvider implementations to add token() + onUnauthorized(), the transport constructor classifies the authProvider option once and adapts OAuth providers via adaptOAuthProvider(). - OAuthClientProvider interface is unchanged from v1 - Transport option: authProvider?: AuthProvider | OAuthClientProvider - Constructor: if OAuth, store both original (for finishAuth/403) and adapted (for _commonHeaders/401) — classification happens once, no runtime type guards in the hot path - 4 built-in providers no longer need token()/onUnauthorized() - migration.md/migration-SKILL.md entries removed — nothing to migrate - Changeset downgraded to minor Net -142 lines vs the breaking approach. Same transport simplification, zero migration burden. Duck-typing via isOAuthClientProvider() ('tokens' + 'clientMetadata' in provider) at construction only. --- .changeset/token-provider-composable-auth.md | 19 ++++--- docs/migration-SKILL.md | 33 ------------ docs/migration.md | 38 -------------- .../client/src/simpleOAuthClientProvider.ts | 17 +------ packages/client/src/client/auth.examples.ts | 10 +--- packages/client/src/client/auth.ts | 50 ++++++++++--------- packages/client/src/client/authExtensions.ts | 35 +------------ packages/client/src/client/sse.ts | 18 ++++--- packages/client/src/client/streamableHttp.ts | 29 +++++++---- packages/client/test/client/auth.test.ts | 10 ---- .../client/test/client/middleware.test.ts | 4 -- packages/client/test/client/sse.test.ts | 27 +++------- .../client/test/client/streamableHttp.test.ts | 8 +-- 13 files changed, 78 insertions(+), 220 deletions(-) diff --git a/.changeset/token-provider-composable-auth.md b/.changeset/token-provider-composable-auth.md index c4ea7f5e3..f5c064e7f 100644 --- a/.changeset/token-provider-composable-auth.md +++ b/.changeset/token-provider-composable-auth.md @@ -1,17 +1,16 @@ --- -'@modelcontextprotocol/client': major +'@modelcontextprotocol/client': minor --- -Unify client auth around a minimal `AuthProvider` interface - -**Breaking:** Transport `authProvider` option now accepts the new minimal `AuthProvider` interface instead of being typed as `OAuthClientProvider`. `OAuthClientProvider` now extends `AuthProvider`, so most existing code continues to work — but custom implementations must add a `token()` method. +Add `AuthProvider` for composable bearer-token auth; transports adapt `OAuthClientProvider` automatically - New `AuthProvider` interface: `{ token(): Promise; onUnauthorized?(ctx): Promise }`. Transports call `token()` before every request and `onUnauthorized()` on 401 (then retry once). -- `OAuthClientProvider` extends `AuthProvider`. Custom implementations must add `token()` (typically `return (await this.tokens())?.access_token`) and optionally `onUnauthorized()` (typically `return handleOAuthUnauthorized(this, ctx)`). -- Built-in providers (`ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, `CrossAppAccessProvider`) implement both methods — existing user code is unchanged. -- New `handleOAuthUnauthorized(provider, ctx)` helper runs the standard OAuth flow from `onUnauthorized`. -- New `isOAuthClientProvider()` type guard for gating OAuth-specific transport features like `finishAuth()`. -- Transports no longer inline OAuth orchestration — ~50 lines of `auth()` calls, WWW-Authenticate parsing, and circuit-breaker state moved into `onUnauthorized()` implementations. +- Transport `authProvider` option now accepts `AuthProvider | OAuthClientProvider`. OAuth providers are adapted internally via `adaptOAuthProvider()` — no changes needed to existing `OAuthClientProvider` implementations. +- For simple bearer tokens (API keys, gateway-managed tokens, service accounts): `{ authProvider: { token: async () => myKey } }` — one-line object literal, no class. +- New `adaptOAuthProvider(provider)` export for explicit adaptation. +- New `handleOAuthUnauthorized(provider, ctx)` helper — the standard OAuth `onUnauthorized` behavior. +- New `isOAuthClientProvider()` type guard. +- New `UnauthorizedContext` type. - Exported previously-internal auth helpers for building custom flows: `applyBasicAuth`, `applyPostAuth`, `applyPublicAuth`, `executeTokenRequest`. -See `docs/migration.md` for before/after examples. +Transports are simplified internally — ~50 lines of inline OAuth orchestration (auth() calls, WWW-Authenticate parsing, circuit-breaker state) moved into the adapter's `onUnauthorized()` implementation. `OAuthClientProvider` itself is unchanged. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index d942277eb..2832d3248 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -203,39 +203,6 @@ import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; if (error instanceof OAuthError && error.code === OAuthErrorCode.InvalidClient) { ... } ``` -### Client `OAuthClientProvider` now extends `AuthProvider` - -Transport `authProvider` options now accept the minimal `AuthProvider` interface. `OAuthClientProvider` extends it, so built-in providers work unchanged — custom implementations must add `token()`. - -| v1 pattern | v2 equivalent | -| ----------------------------------------------------- | --------------------------------------------------------------------------- | -| `authProvider?: OAuthClientProvider` (option type) | `authProvider?: AuthProvider` (accepts `OAuthClientProvider` via extension) | -| Transport reads `authProvider.tokens()?.access_token` | Transport calls `authProvider.token()` | -| Transport inlines `auth()` on 401 | Transport calls `authProvider.onUnauthorized()` then retries once | -| `_hasCompletedAuthFlow` circuit breaker | `_authRetryInFlight` circuit breaker | -| N/A | `handleOAuthUnauthorized(provider, ctx)` — standard `onUnauthorized` impl | -| N/A | `isOAuthClientProvider(provider)` — type guard | -| N/A | `UnauthorizedContext` — `{ response, serverUrl, fetchFn }` | - -**For custom `OAuthClientProvider` implementations**, add both methods (both required — TypeScript enforces this): - -```typescript -async token(): Promise { - return (await this.tokens())?.access_token; -} - -async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); -} -``` - -**For simple bearer tokens** (previously required stubbing 8 `OAuthClientProvider` members): - -```typescript -// v2: one-liner -const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; -``` - **Unchanged APIs** (only import paths changed): `Client` constructor and most methods, `McpServer` constructor, `server.connect()`, `server.close()`, all client transports (`StreamableHTTPClientTransport`, `SSEClientTransport`, `StdioClientTransport`), `StdioServerTransport`, all Zod schemas, all callback return types. Note: `callTool()` and `request()` signatures changed (schema parameter removed, see section 11). diff --git a/docs/migration.md b/docs/migration.md index 6703482f1..9927ab2ea 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -669,44 +669,6 @@ The new design: - `ProtocolError` with `ProtocolErrorCode`: For errors that are serialized and sent as JSON-RPC error responses - `SdkError` with `SdkErrorCode`: For local errors that are thrown/rejected locally and never leave the SDK -### Client `authProvider` unified around `AuthProvider` - -Transport `authProvider` options now accept the minimal `AuthProvider` interface rather than being typed as `OAuthClientProvider`. `OAuthClientProvider` extends `AuthProvider`, so built-in providers and most existing code continue to work unchanged — but custom -`OAuthClientProvider` implementations must add a `token()` method. - -**What changed:** transports now call `authProvider.token()` before every request (instead of `authProvider.tokens()?.access_token`), and call `authProvider.onUnauthorized()` on 401 (instead of inlining OAuth orchestration). One code path handles both simple bearer tokens and -full OAuth. - -**If you implement `OAuthClientProvider` directly** (the interactive browser-redirect pattern), add: - -```ts -class MyProvider implements OAuthClientProvider { - // ...existing 8 required members... - - // Required: return the current access token - async token(): Promise { - return (await this.tokens())?.access_token; - } - - // Required: runs the OAuth flow on 401 — without this, 401 throws with no recovery - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } -} -``` - -**If you use `ClientCredentialsProvider`, `PrivateKeyJwtProvider`, `StaticPrivateKeyJwtProvider`, or `CrossAppAccessProvider`** — no change. These already implement both methods. - -**If you have simple bearer tokens** (API keys, gateway tokens, externally-managed tokens), you can now skip `OAuthClientProvider` entirely: - -```ts -// Before: had to implement 8 OAuthClientProvider members with no-op stubs -// After: -const transport = new StreamableHTTPClientTransport(url, { - authProvider: { token: async () => process.env.API_KEY } -}); -``` - ### OAuth error refactoring The OAuth error classes have been consolidated into a single `OAuthError` class with an `OAuthErrorCode` enum. diff --git a/examples/client/src/simpleOAuthClientProvider.ts b/examples/client/src/simpleOAuthClientProvider.ts index 6248d1f90..96655c9f6 100644 --- a/examples/client/src/simpleOAuthClientProvider.ts +++ b/examples/client/src/simpleOAuthClientProvider.ts @@ -1,11 +1,4 @@ -import type { - OAuthClientInformationMixed, - OAuthClientMetadata, - OAuthClientProvider, - OAuthTokens, - UnauthorizedContext -} from '@modelcontextprotocol/client'; -import { handleOAuthUnauthorized } from '@modelcontextprotocol/client'; +import type { OAuthClientInformationMixed, OAuthClientMetadata, OAuthClientProvider, OAuthTokens } from '@modelcontextprotocol/client'; /** * In-memory OAuth client provider for demonstration purposes @@ -31,14 +24,6 @@ export class InMemoryOAuthClientProvider implements OAuthClientProvider { private _onRedirect: (url: URL) => void; - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): string | URL { return this._redirectUrl; } diff --git a/packages/client/src/client/auth.examples.ts b/packages/client/src/client/auth.examples.ts index 15b6487a7..17c04e6a0 100644 --- a/packages/client/src/client/auth.examples.ts +++ b/packages/client/src/client/auth.examples.ts @@ -9,8 +9,8 @@ import type { AuthorizationServerMetadata } from '@modelcontextprotocol/core'; -import type { OAuthClientProvider, UnauthorizedContext } from './auth.js'; -import { fetchToken, handleOAuthUnauthorized } from './auth.js'; +import type { OAuthClientProvider } from './auth.js'; +import { fetchToken } from './auth.js'; /** * Base class providing no-op implementations of required OAuthClientProvider methods. @@ -29,12 +29,6 @@ abstract class MyProviderBase implements OAuthClientProvider { tokens(): undefined { return; } - async token(): Promise { - return undefined; - } - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } saveTokens() { return Promise.resolve(); } diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index bca45ad66..d26bb5727 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -60,8 +60,8 @@ export interface UnauthorizedContext { * const authProvider: AuthProvider = { token: async () => process.env.API_KEY }; * ``` * - * For OAuth flows, use {@linkcode OAuthClientProvider} which extends this interface, - * or one of the built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.). + * For OAuth flows, pass an {@linkcode OAuthClientProvider} directly — transports + * accept either shape and adapt OAuth providers automatically via {@linkcode adaptOAuthProvider}. */ export interface AuthProvider { /** @@ -82,19 +82,17 @@ export interface AuthProvider { } /** - * Type guard: checks whether an `AuthProvider` is a full `OAuthClientProvider`. - * Use this to gate OAuth-specific transport features like `finishAuth()` and - * 403 scope upscoping. + * Type guard distinguishing `OAuthClientProvider` from a minimal `AuthProvider`. + * Transports use this at construction time to classify the `authProvider` option. */ -export function isOAuthClientProvider(provider: AuthProvider | undefined): provider is OAuthClientProvider { +export function isOAuthClientProvider(provider: AuthProvider | OAuthClientProvider | undefined): provider is OAuthClientProvider { return provider !== undefined && 'tokens' in provider && 'clientMetadata' in provider; } /** - * Default `onUnauthorized` implementation for OAuth providers: extracts + * Standard `onUnauthorized` behavior for OAuth providers: extracts * `WWW-Authenticate` parameters from the 401 response and runs {@linkcode auth}. - * Built-in providers ({@linkcode index.ClientCredentialsProvider | ClientCredentialsProvider} etc.) - * delegate to this. Custom `OAuthClientProvider` implementations can do the same. + * Used by {@linkcode adaptOAuthProvider} to bridge `OAuthClientProvider` to `AuthProvider`. */ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx: UnauthorizedContext): Promise { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(ctx.response); @@ -109,6 +107,22 @@ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx } } +/** + * Adapts an `OAuthClientProvider` to the minimal `AuthProvider` interface that + * transports consume. Called once at transport construction — the transport stores + * the adapted provider for `_commonHeaders()` and 401 handling, while keeping the + * original `OAuthClientProvider` for OAuth-specific paths (`finishAuth()`, 403 upscoping). + */ +export function adaptOAuthProvider(provider: OAuthClientProvider): AuthProvider { + return { + token: async () => { + const tokens = await provider.tokens(); + return tokens?.access_token; + }, + onUnauthorized: async ctx => handleOAuthUnauthorized(provider, ctx) + }; +} + /** * Implements an end-to-end OAuth client to be used with one MCP server. * @@ -116,21 +130,11 @@ export async function handleOAuthUnauthorized(provider: OAuthClientProvider, ctx * meaning of which is application-defined. Tokens, authorization codes, and * code verifiers should not cross different sessions. * - * Extends {@linkcode AuthProvider} — implementations must provide `token()` - * (typically `return (await this.tokens())?.access_token`) and `onUnauthorized()` - * (typically `return handleOAuthUnauthorized(this, ctx)`). Without `onUnauthorized()`, - * 401 responses throw immediately with no token refresh or reauth. + * Transports accept `OAuthClientProvider` directly via the `authProvider` option — + * they adapt it to {@linkcode AuthProvider} internally via {@linkcode adaptOAuthProvider}. + * No changes are needed to existing implementations. */ -export interface OAuthClientProvider extends AuthProvider { - /** - * Runs the OAuth re-authentication flow on 401. Required on `OAuthClientProvider` - * (optional on the base `AuthProvider`) because OAuth providers that omit this lose - * all 401 recovery — no token refresh, no redirect to authorization. - * - * Most implementations should delegate: `return handleOAuthUnauthorized(this, ctx)`. - */ - onUnauthorized(ctx: UnauthorizedContext): Promise; - +export interface OAuthClientProvider { /** * The URL to redirect the user agent to after authorization. * Return `undefined` for non-interactive flows that don't require user interaction diff --git a/packages/client/src/client/authExtensions.ts b/packages/client/src/client/authExtensions.ts index 7508298b7..ae614f7ba 100644 --- a/packages/client/src/client/authExtensions.ts +++ b/packages/client/src/client/authExtensions.ts @@ -8,8 +8,7 @@ import type { FetchLike, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; import type { CryptoKey, JWK } from 'jose'; -import type { AddClientAuthentication, OAuthClientProvider, UnauthorizedContext } from './auth.js'; -import { handleOAuthUnauthorized } from './auth.js'; +import type { AddClientAuthentication, OAuthClientProvider } from './auth.js'; /** * Helper to produce a `private_key_jwt` client authentication function. @@ -151,14 +150,6 @@ export class ClientCredentialsProvider implements OAuthClientProvider { }; } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } @@ -278,14 +269,6 @@ export class PrivateKeyJwtProvider implements OAuthClientProvider { }); } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } @@ -383,14 +366,6 @@ export class StaticPrivateKeyJwtProvider implements OAuthClientProvider { }; } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } @@ -589,14 +564,6 @@ export class CrossAppAccessProvider implements OAuthClientProvider { this._fetchFn = options.fetchFn ?? fetch; } - async token(): Promise { - return this._tokens?.access_token; - } - - async onUnauthorized(ctx: UnauthorizedContext): Promise { - await handleOAuthUnauthorized(this, ctx); - } - get redirectUrl(): undefined { return undefined; } diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 025c785ea..c613bd2b1 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -3,8 +3,8 @@ import { createFetchWithInit, JSONRPCMessageSchema, normalizeHeaders, SdkError, import type { ErrorEvent, EventSourceInit } from 'eventsource'; import { EventSource } from 'eventsource'; -import type { AuthProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; export class SseError extends Error { constructor( @@ -35,7 +35,7 @@ export type SSEClientTransportOptions = { * Interactive flows: after {@linkcode UnauthorizedError}, redirect the user, then call * {@linkcode SSEClientTransport.finishAuth | finishAuth} with the authorization code before reconnecting. */ - authProvider?: AuthProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes the initial SSE request to the server (the request that begins the stream). @@ -73,6 +73,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _protocolVersion?: string; @@ -87,7 +88,12 @@ export class SSEClientTransport implements Transport { this._scope = undefined; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } @@ -215,11 +221,11 @@ export class SSEClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!isOAuthClientProvider(this._authProvider)) { + if (!this._oauthProvider) { throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index e6443a90b..a5fd0500f 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -13,8 +13,8 @@ import { } from '@modelcontextprotocol/core'; import { EventSourceParserStream } from 'eventsource-parser/stream'; -import type { AuthProvider } from './auth.js'; -import { auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; +import type { AuthProvider, OAuthClientProvider } from './auth.js'; +import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -94,11 +94,12 @@ export type StreamableHTTPClientTransportOptions = { * For simple bearer tokens: `{ token: async () => myApiKey }`. * * For OAuth flows, pass an {@linkcode index.OAuthClientProvider | OAuthClientProvider} implementation - * (which extends `AuthProvider`). Interactive flows: after {@linkcode UnauthorizedError}, redirect the - * user, then call {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization - * code before reconnecting. + * directly — the transport adapts it to `AuthProvider` internally. Interactive flows: after + * {@linkcode UnauthorizedError}, redirect the user, then call + * {@linkcode StreamableHTTPClientTransport.finishAuth | finishAuth} with the authorization code before + * reconnecting. */ - authProvider?: AuthProvider; + authProvider?: AuthProvider | OAuthClientProvider; /** * Customizes HTTP requests to the server. @@ -141,6 +142,7 @@ export class StreamableHTTPClientTransport implements Transport { private _scope?: string; private _requestInit?: RequestInit; private _authProvider?: AuthProvider; + private _oauthProvider?: OAuthClientProvider; private _fetch?: FetchLike; private _fetchWithInit: FetchLike; private _sessionId?: string; @@ -160,7 +162,12 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = undefined; this._scope = undefined; this._requestInit = opts?.requestInit; - this._authProvider = opts?.authProvider; + if (isOAuthClientProvider(opts?.authProvider)) { + this._oauthProvider = opts.authProvider; + this._authProvider = adaptOAuthProvider(opts.authProvider); + } else { + this._authProvider = opts?.authProvider; + } this._fetch = opts?.fetch; this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); this._sessionId = opts?.sessionId; @@ -435,11 +442,11 @@ export class StreamableHTTPClientTransport implements Transport { * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. */ async finishAuth(authorizationCode: string): Promise { - if (!isOAuthClientProvider(this._authProvider)) { + if (!this._oauthProvider) { throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, @@ -522,7 +529,7 @@ export class StreamableHTTPClientTransport implements Transport { const text = await response.text?.().catch(() => null); - if (response.status === 403 && isOAuthClientProvider(this._authProvider)) { + if (response.status === 403 && this._oauthProvider) { const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); if (error === 'insufficient_scope') { @@ -546,7 +553,7 @@ export class StreamableHTTPClientTransport implements Transport { // Mark that upscoping was tried. this._lastUpscopingHeader = wwwAuthHeader ?? undefined; - const result = await auth(this._authProvider, { + const result = await auth(this._oauthProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, scope: this._scope, diff --git a/packages/client/test/client/auth.test.ts b/packages/client/test/client/auth.test.ts index 12d6793af..9d8f5cf6b 100644 --- a/packages/client/test/client/auth.test.ts +++ b/packages/client/test/client/auth.test.ts @@ -1038,8 +1038,6 @@ describe('OAuth Authorization', () => { client_secret: 'test-client-secret' }), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -1985,8 +1983,6 @@ describe('OAuth Authorization', () => { }, clientInformation: vi.fn(), tokens: vi.fn(), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2060,8 +2056,6 @@ describe('OAuth Authorization', () => { client_id: 'client-id' }), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -2977,8 +2971,6 @@ describe('OAuth Authorization', () => { client_secret: 'secret123' }), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), @@ -3432,8 +3424,6 @@ describe('OAuth Authorization', () => { clientInformation: vi.fn().mockResolvedValue(undefined), saveClientInformation: vi.fn().mockResolvedValue(undefined), tokens: vi.fn().mockResolvedValue(undefined), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn().mockResolvedValue(undefined), redirectToAuthorization: vi.fn().mockResolvedValue(undefined), saveCodeVerifier: vi.fn().mockResolvedValue(undefined), diff --git a/packages/client/test/client/middleware.test.ts b/packages/client/test/client/middleware.test.ts index d2084af99..64bbfa673 100644 --- a/packages/client/test/client/middleware.test.ts +++ b/packages/client/test/client/middleware.test.ts @@ -33,8 +33,6 @@ describe('withOAuth', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), @@ -761,8 +759,6 @@ describe('Integration Tests', () => { return { redirect_uris: ['http://localhost/callback'] }; }, tokens: vi.fn(), - token: vi.fn(async () => undefined), - onUnauthorized: vi.fn(async () => {}), saveTokens: vi.fn(), clientInformation: vi.fn(), redirectToAuthorization: vi.fn(), diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 3e1a3f895..10fcd76bc 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -8,7 +8,7 @@ import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; -import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; import { SSEClientTransport } from '../../src/client/sse.js'; /** @@ -430,15 +430,11 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), - token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn(), - onUnauthorized: vi.fn(async ctx => { - await handleOAuthUnauthorized(mockAuthProvider, ctx); - }) + invalidateCredentials: vi.fn() }; }); @@ -447,7 +443,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -456,7 +451,7 @@ describe('SSEClientTransport', () => { await transport.start(); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.token).toHaveBeenCalled(); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); it('attaches custom header from provider on initial SSE connection', async () => { @@ -464,7 +459,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' }; @@ -480,7 +474,7 @@ describe('SSEClientTransport', () => { expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); expect(lastServerRequest.headers['x-custom-header']).toBe('custom-value'); - expect(mockAuthProvider.token).toHaveBeenCalled(); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); it('attaches auth header from provider on POST requests', async () => { @@ -488,7 +482,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); transport = new SSEClientTransport(resourceBaseUrl, { authProvider: mockAuthProvider @@ -506,7 +499,7 @@ describe('SSEClientTransport', () => { await transport.send(message); expect(lastServerRequest.headers.authorization).toBe('Bearer test-token'); - expect(mockAuthProvider.token).toHaveBeenCalled(); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); it('attempts auth flow on 401 during SSE connection', async () => { @@ -638,7 +631,6 @@ describe('SSEClientTransport', () => { access_token: 'test-token', token_type: 'Bearer' }); - mockAuthProvider.token.mockResolvedValue('test-token'); const customHeaders = { 'X-Custom-Header': 'custom-value' @@ -674,7 +666,6 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -804,7 +795,6 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -958,7 +948,6 @@ describe('SSEClientTransport', () => { refresh_token: 'refresh-token' }; mockAuthProvider.tokens.mockImplementation(() => currentTokens); - mockAuthProvider.token.mockImplementation(async () => currentTokens.access_token); mockAuthProvider.saveTokens.mockImplementation(tokens => { currentTokens = tokens; }); @@ -1229,15 +1218,11 @@ describe('SSEClientTransport', () => { }, clientInformation: vi.fn().mockResolvedValue(clientInfo), tokens: vi.fn().mockResolvedValue(tokens), - token: vi.fn(async () => tokens?.access_token), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn().mockResolvedValue('test-verifier'), - invalidateCredentials: vi.fn(), - onUnauthorized: vi.fn(async ctx => { - await handleOAuthUnauthorized(mockAuthProvider, ctx); - }) + invalidateCredentials: vi.fn() }; }; diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 4059bb981..4134bda80 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -3,7 +3,7 @@ import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontex import type { Mock, Mocked } from 'vitest'; import type { OAuthClientProvider } from '../../src/client/auth.js'; -import { handleOAuthUnauthorized, UnauthorizedError } from '../../src/client/auth.js'; +import { UnauthorizedError } from '../../src/client/auth.js'; import type { StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -21,15 +21,11 @@ describe('StreamableHTTPClientTransport', () => { }, clientInformation: vi.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), tokens: vi.fn(), - token: vi.fn(async () => undefined), saveTokens: vi.fn(), redirectToAuthorization: vi.fn(), saveCodeVerifier: vi.fn(), codeVerifier: vi.fn(), - invalidateCredentials: vi.fn(), - onUnauthorized: vi.fn(async ctx => { - await handleOAuthUnauthorized(mockAuthProvider, ctx); - }) + invalidateCredentials: vi.fn() }; transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { authProvider: mockAuthProvider }); vi.spyOn(globalThis, 'fetch'); From 98845df9123e0a52868a1d23c1e25b445b363b9a Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 22:05:49 +0000 Subject: [PATCH 04/13] fix: address round-4 review comments on 401 handling Four fixes from claude[bot] review on the AuthProvider approach: 1. Drain 401 response body after onUnauthorized() succeeds, before the retry. Unconsumed bodies block socket recycling in undici. All three 401 sites now drain before return. 2. _startOrAuthSse() 401 retry was return await, causing onerror to fire twice (recursive call's catch + outer catch both fire). Changed to return (not awaited) matching the send() pattern. Removed the try/finally, added flag reset to success path + outer catch instead. 3. Migration docs still referenced SdkErrorCode.ClientHttpAuthentication for the 401-after-auth case, but that throw site was replaced by _authRetryInFlight which throws UnauthorizedError. Updated both migration.md and migration-SKILL.md. 4. Pre-existing: 403 upscoping auth() call passed this._fetch instead of this._fetchWithInit, dropping custom requestInit options during token requests. All other auth() calls in this transport already used _fetchWithInit. --- docs/migration-SKILL.md | 3 +-- docs/migration.md | 9 ++++---- packages/client/src/client/sse.ts | 1 + packages/client/src/client/streamableHttp.ts | 23 ++++++++++---------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 2832d3248..545008d2a 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -116,7 +116,7 @@ Two error classes now exist: | Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | | HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | | Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | -| 401 after auth flow | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | +| 401 after auth flow | `StreamableHTTPError` | `UnauthorizedError` | | 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | | Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | | Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | @@ -131,7 +131,6 @@ New `SdkErrorCode` enum values: - `SdkErrorCode.ConnectionClosed` = `'CONNECTION_CLOSED'` - `SdkErrorCode.SendFailed` = `'SEND_FAILED'` - `SdkErrorCode.ClientHttpNotImplemented` = `'CLIENT_HTTP_NOT_IMPLEMENTED'` -- `SdkErrorCode.ClientHttpAuthentication` = `'CLIENT_HTTP_AUTHENTICATION'` - `SdkErrorCode.ClientHttpForbidden` = `'CLIENT_HTTP_FORBIDDEN'` - `SdkErrorCode.ClientHttpUnexpectedContent` = `'CLIENT_HTTP_UNEXPECTED_CONTENT'` - `SdkErrorCode.ClientHttpFailedToOpenStream` = `'CLIENT_HTTP_FAILED_TO_OPEN_STREAM'` diff --git a/docs/migration.md b/docs/migration.md index 9927ab2ea..546c89be7 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -606,7 +606,7 @@ The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: | `SdkErrorCode.ConnectionClosed` | Connection was closed | | `SdkErrorCode.SendFailed` | Failed to send message | | `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | -| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after successful auth | +| `UnauthorizedError` (thrown, not `SdkError`) | Server returned 401 after re-auth attempt | | `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | | `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | | `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | @@ -638,11 +638,10 @@ import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; try { await transport.send(message); } catch (error) { - if (error instanceof SdkError) { + if (error instanceof UnauthorizedError) { + console.log('Token rejected — reconnect with fresh credentials'); + } else if (error instanceof SdkError) { switch (error.code) { - case SdkErrorCode.ClientHttpAuthentication: - console.log('Auth failed after completing auth flow'); - break; case SdkErrorCode.ClientHttpForbidden: console.log('Forbidden after upscoping attempt'); break; diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index c613bd2b1..a8646adea 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -275,6 +275,7 @@ export class SSEClientTransport implements Transport { serverUrl: this._url, fetchFn: this._fetchWithInit }); + await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index a5fd0500f..21ffeb375 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -228,16 +228,14 @@ export class StreamableHTTPClientTransport implements Transport { if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { this._authRetryInFlight = true; - try { - await this._authProvider.onUnauthorized({ - response, - serverUrl: this._url, - fetchFn: this._fetchWithInit - }); - return await this._startOrAuthSse(options); - } finally { - this._authRetryInFlight = false; - } + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); + await response.text?.().catch(() => {}); + // Purposely _not_ awaited, so we don't call onerror twice + return this._startOrAuthSse(options); } if (this._authProvider) { await response.text?.().catch(() => {}); @@ -259,8 +257,10 @@ export class StreamableHTTPClientTransport implements Transport { }); } + this._authRetryInFlight = false; this._handleSseStream(response.body, options, true); } catch (error) { + this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } @@ -518,6 +518,7 @@ export class StreamableHTTPClientTransport implements Transport { serverUrl: this._url, fetchFn: this._fetchWithInit }); + await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } @@ -557,7 +558,7 @@ export class StreamableHTTPClientTransport implements Transport { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, scope: this._scope, - fetchFn: this._fetch + fetchFn: this._fetchWithInit }); if (result !== 'AUTHORIZED') { From b869d963d15c4dca7997a0155897694a544fefb8 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 19 Mar 2026 22:11:34 +0000 Subject: [PATCH 05/13] fix: restore SdkError(ClientHttpAuthentication) for circuit-breaker case MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 401-after-re-auth case (circuit breaker trips) should throw a distinct error from the normal 'token rejected' case: - First 401 with no onUnauthorized → UnauthorizedError — caller re-auths externally and reconnects - Second 401 after onUnauthorized succeeded → SdkError with ClientHttpAuthentication — server is misbehaving, don't blindly retry, escalate The previous commit collapsed these into UnauthorizedError, which risks callers catching it, re-authing, and looping. Restored the SdkError throw at all three 401 sites when _authRetryInFlight is already set. Reverted migration doc changes — ClientHttpAuthentication is not dead code. --- docs/migration-SKILL.md | 27 +++++++------- docs/migration.md | 37 ++++++++++--------- packages/client/src/client/sse.ts | 7 +++- packages/client/src/client/streamableHttp.ts | 14 ++++++- packages/client/test/client/sse.test.ts | 8 ++-- .../client/test/client/streamableHttp.test.ts | 4 +- .../client/test/client/tokenProvider.test.ts | 7 +++- 7 files changed, 64 insertions(+), 40 deletions(-) diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 545008d2a..e5de064da 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -107,19 +107,19 @@ Two error classes now exist: - **`ProtocolError`** (renamed from `McpError`): Protocol errors that cross the wire as JSON-RPC responses - **`SdkError`** (new): Local SDK errors that never cross the wire -| Error scenario | v1 type | v2 type | -| -------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | -| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | -| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | -| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | -| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | -| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | -| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | -| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | -| 401 after auth flow | `StreamableHTTPError` | `UnauthorizedError` | -| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | -| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | -| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | +| Error scenario | v1 type | v2 type | +| --------------------------------- | -------------------------------------------- | ----------------------------------------------------------------- | +| Request timeout | `McpError` with `ErrorCode.RequestTimeout` | `SdkError` with `SdkErrorCode.RequestTimeout` | +| Connection closed | `McpError` with `ErrorCode.ConnectionClosed` | `SdkError` with `SdkErrorCode.ConnectionClosed` | +| Capability not supported | `new Error(...)` | `SdkError` with `SdkErrorCode.CapabilityNotSupported` | +| Not connected | `new Error('Not connected')` | `SdkError` with `SdkErrorCode.NotConnected` | +| Invalid params (server response) | `McpError` with `ErrorCode.InvalidParams` | `ProtocolError` with `ProtocolErrorCode.InvalidParams` | +| HTTP transport error | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttp*` | +| Failed to open SSE stream | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToOpenStream` | +| 401 after re-auth (circuit break) | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpAuthentication` | +| 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | +| Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | +| Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | New `SdkErrorCode` enum values: @@ -131,6 +131,7 @@ New `SdkErrorCode` enum values: - `SdkErrorCode.ConnectionClosed` = `'CONNECTION_CLOSED'` - `SdkErrorCode.SendFailed` = `'SEND_FAILED'` - `SdkErrorCode.ClientHttpNotImplemented` = `'CLIENT_HTTP_NOT_IMPLEMENTED'` +- `SdkErrorCode.ClientHttpAuthentication` = `'CLIENT_HTTP_AUTHENTICATION'` - `SdkErrorCode.ClientHttpForbidden` = `'CLIENT_HTTP_FORBIDDEN'` - `SdkErrorCode.ClientHttpUnexpectedContent` = `'CLIENT_HTTP_UNEXPECTED_CONTENT'` - `SdkErrorCode.ClientHttpFailedToOpenStream` = `'CLIENT_HTTP_FAILED_TO_OPEN_STREAM'` diff --git a/docs/migration.md b/docs/migration.md index 546c89be7..64540bc92 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -596,21 +596,21 @@ try { The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: -| Code | Description | -| ------------------------------------------------- | ------------------------------------------ | -| `SdkErrorCode.NotConnected` | Transport is not connected | -| `SdkErrorCode.AlreadyConnected` | Transport is already connected | -| `SdkErrorCode.NotInitialized` | Protocol is not initialized | -| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | -| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | -| `SdkErrorCode.ConnectionClosed` | Connection was closed | -| `SdkErrorCode.SendFailed` | Failed to send message | -| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | -| `UnauthorizedError` (thrown, not `SdkError`) | Server returned 401 after re-auth attempt | -| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | -| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | -| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | -| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | +| Code | Description | +| ------------------------------------------------- | ------------------------------------------- | +| `SdkErrorCode.NotConnected` | Transport is not connected | +| `SdkErrorCode.AlreadyConnected` | Transport is already connected | +| `SdkErrorCode.NotInitialized` | Protocol is not initialized | +| `SdkErrorCode.CapabilityNotSupported` | Required capability is not supported | +| `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | +| `SdkErrorCode.ConnectionClosed` | Connection was closed | +| `SdkErrorCode.SendFailed` | Failed to send message | +| `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | +| `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after re-authentication | +| `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | +| `SdkErrorCode.ClientHttpUnexpectedContent` | Unexpected content type in HTTP response | +| `SdkErrorCode.ClientHttpFailedToOpenStream` | Failed to open SSE stream | +| `SdkErrorCode.ClientHttpFailedToTerminateSession` | Failed to terminate session | #### `StreamableHTTPError` removed @@ -638,10 +638,11 @@ import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; try { await transport.send(message); } catch (error) { - if (error instanceof UnauthorizedError) { - console.log('Token rejected — reconnect with fresh credentials'); - } else if (error instanceof SdkError) { + if (error instanceof SdkError) { switch (error.code) { + case SdkErrorCode.ClientHttpAuthentication: + console.log('Auth failed — server rejected token after re-auth'); + break; case SdkErrorCode.ClientHttpForbidden: console.log('Forbidden after upscoping attempt'); break; diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index a8646adea..e714ad4d2 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -279,8 +279,13 @@ export class SSEClientTransport implements Transport { // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } if (this._authProvider) { - await response.text?.().catch(() => {}); throw new UnauthorizedError(); } } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 21ffeb375..dea9d3882 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -237,8 +237,13 @@ export class StreamableHTTPClientTransport implements Transport { // Purposely _not_ awaited, so we don't call onerror twice return this._startOrAuthSse(options); } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } if (this._authProvider) { - await response.text?.().catch(() => {}); throw new UnauthorizedError(); } } @@ -522,8 +527,13 @@ export class StreamableHTTPClientTransport implements Transport { // Purposely _not_ awaited, so we don't call onerror twice return this.send(message); } + await response.text?.().catch(() => {}); + if (this._authRetryInFlight) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }); + } if (this._authProvider) { - await response.text?.().catch(() => {}); throw new UnauthorizedError(); } } diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 10fcd76bc..fd2b184cc 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -3,7 +3,7 @@ import { createServer } from 'node:http'; import type { AddressInfo } from 'node:net'; import type { JSONRPCMessage, OAuthTokens } from '@modelcontextprotocol/core'; -import { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/core'; +import { OAuthError, OAuthErrorCode, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock, Mocked, MockedFunction, MockInstance } from 'vitest'; @@ -1575,7 +1575,7 @@ describe('SSEClientTransport', () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); }); - it('enforces circuit breaker on double-401: onUnauthorized called once, then throws', async () => { + it('enforces circuit breaker on double-401: onUnauthorized called once, then throws SdkError', async () => { postResponses = [401, 401]; await setupServer(); @@ -1586,7 +1586,9 @@ describe('SSEClientTransport', () => { transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); await transport.start(); - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); expect(postCount).toBe(2); }); diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 4134bda80..55bf79a50 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -1705,7 +1705,9 @@ describe('StreamableHTTPClientTransport', () => { // Retry the original request - still 401 (broken server) .mockResolvedValueOnce(unauthedResponse); - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ access_token: 'new-access-token', token_type: 'Bearer', diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts index c683a4012..3ab6c9623 100644 --- a/packages/client/test/client/tokenProvider.test.ts +++ b/packages/client/test/client/tokenProvider.test.ts @@ -1,4 +1,5 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; import type { Mock } from 'vitest'; import type { AuthProvider } from '../../src/client/auth.js'; @@ -81,7 +82,7 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); }); - it('should throw UnauthorizedError if retry after onUnauthorized also gets 401', async () => { + it('should throw SdkError(ClientHttpAuthentication) if retry after onUnauthorized also gets 401', async () => { const authProvider: AuthProvider = { token: vi.fn(async () => 'still-bad'), onUnauthorized: vi.fn(async () => {}) @@ -93,7 +94,9 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }) .mockResolvedValueOnce({ ok: false, status: 401, headers: new Headers(), text: async () => 'unauthorized' }); - await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + const error = await transport.send(message).catch(e => e); + expect(error).toBeInstanceOf(SdkError); + expect((error as SdkError).code).toBe(SdkErrorCode.ClientHttpAuthentication); expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(1); }); From 50b386a18923f0d1eacb9be0fc396466ee323d26 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 11:00:20 +0000 Subject: [PATCH 06/13] fix: address round-5 review comments on 401 edge cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes from claude[bot] review: 1. _startOrAuthSse 405 return doesn't reset _authRetryInFlight — 401 → onUnauthorized → retry → 405 would leave the flag set forever, disabling onUnauthorized for subsequent send() 401s. Added reset before the 405 return. 2. SSE onerror handler — when onUnauthorized rejects in the connect path, the error went through .then(resolve, reject) without calling this.onerror. Every other error path in both transports calls both. Added this.onerror?.() before reject. 3. 401 with no authProvider drained body twice — the 401 block ran unconditionally, drained at line 232/285/521, fell through (no authProvider = no throw), then the generic error handler drained again (empty) and produced 'HTTP 401: null'. Gated the entire 401 block on this._authProvider (matching pre-PR structure) so no-auth 401s hit the generic error directly with intact body text. --- packages/client/src/client/sse.ts | 13 +++++++------ packages/client/src/client/streamableHttp.ts | 17 +++++++---------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index e714ad4d2..4b62baca2 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -154,7 +154,10 @@ export class SSEClientTransport implements Transport { this._authProvider .onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }) .then(() => this._startOrAuth()) - .then(resolve, reject) + .then(resolve, error => { + this.onerror?.(error); + reject(error); + }) .finally(() => { this._authRetryInFlight = false; }); @@ -261,14 +264,14 @@ export class SSEClientTransport implements Transport { const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { - if (response.status === 401) { + if (response.status === 401 && this._authProvider) { if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; } - if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + if (this._authProvider.onUnauthorized && !this._authRetryInFlight) { this._authRetryInFlight = true; await this._authProvider.onUnauthorized({ response, @@ -285,9 +288,7 @@ export class SSEClientTransport implements Transport { status: 401 }); } - if (this._authProvider) { - throw new UnauthorizedError(); - } + throw new UnauthorizedError(); } const text = await response.text?.().catch(() => null); diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index dea9d3882..710aa7f57 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -219,14 +219,14 @@ export class StreamableHTTPClientTransport implements Transport { }); if (!response.ok) { - if (response.status === 401) { + if (response.status === 401 && this._authProvider) { if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); this._resourceMetadataUrl = resourceMetadataUrl; this._scope = scope; } - if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + if (this._authProvider.onUnauthorized && !this._authRetryInFlight) { this._authRetryInFlight = true; await this._authProvider.onUnauthorized({ response, @@ -243,9 +243,7 @@ export class StreamableHTTPClientTransport implements Transport { status: 401 }); } - if (this._authProvider) { - throw new UnauthorizedError(); - } + throw new UnauthorizedError(); } await response.text?.().catch(() => {}); @@ -253,6 +251,7 @@ export class StreamableHTTPClientTransport implements Transport { // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { + this._authRetryInFlight = false; return; } @@ -508,7 +507,7 @@ export class StreamableHTTPClientTransport implements Transport { } if (!response.ok) { - if (response.status === 401) { + if (response.status === 401 && this._authProvider) { // Store WWW-Authenticate params for interactive finishAuth() path if (response.headers.has('www-authenticate')) { const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); @@ -516,7 +515,7 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = scope; } - if (this._authProvider?.onUnauthorized && !this._authRetryInFlight) { + if (this._authProvider.onUnauthorized && !this._authRetryInFlight) { this._authRetryInFlight = true; await this._authProvider.onUnauthorized({ response, @@ -533,9 +532,7 @@ export class StreamableHTTPClientTransport implements Transport { status: 401 }); } - if (this._authProvider) { - throw new UnauthorizedError(); - } + throw new UnauthorizedError(); } const text = await response.text?.().catch(() => null); From da55f55dfcdc193576304f4cd7e3919fbd0f098c Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 11:25:13 +0000 Subject: [PATCH 07/13] docs: add dual-mode auth example to validate AuthProvider decomposition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Demonstrates two auth setups through the same authProvider slot: MODE A (host-managed): Enclosing app owns the token. Minimal AuthProvider with { token, onUnauthorized } — onUnauthorized signals the host UI and throws instead of refreshing, since the host owns the token lifecycle. MODE B (user-configured): OAuth credentials supplied directly. Passes a ClientCredentialsProvider; transport adapts it to AuthProvider via adaptOAuthProvider (synthesizing token()/onUnauthorized()). Same connectAndList() caller code for both — the transport abstracts the difference. Validates the decomposition holds with zero branching in user code. --- examples/client/src/dualModeAuth.ts | 115 ++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 examples/client/src/dualModeAuth.ts diff --git a/examples/client/src/dualModeAuth.ts b/examples/client/src/dualModeAuth.ts new file mode 100644 index 000000000..75f5c1bec --- /dev/null +++ b/examples/client/src/dualModeAuth.ts @@ -0,0 +1,115 @@ +#!/usr/bin/env node + +/** + * Two auth patterns through the same `authProvider` option. + * + * The transport accepts either a minimal `AuthProvider` (just `token()` + + * optional `onUnauthorized()`) or a full `OAuthClientProvider`, adapting + * the latter automatically. This means your connect/call code is identical + * regardless of which pattern fits your deployment. + * + * HOST-MANAGED — token lives in an enclosing app + * The app fetches and stores tokens; the MCP client just reads them. + * On 401, there is nothing to refresh — signal the UI and throw so the + * user can re-authenticate through the host's flow. + * + * USER-CONFIGURED — OAuth credentials supplied directly + * Pass a built-in or custom OAuthClientProvider. The transport handles + * the full OAuth flow: token refresh on 401, or redirect for interactive + * authorization. + */ + +import type { AuthProvider } from '@modelcontextprotocol/client'; +import { Client, ClientCredentialsProvider, StreamableHTTPClientTransport, UnauthorizedError } from '@modelcontextprotocol/client'; + +// --- Stubs for host-app integration points --------------------------------- + +/** Whatever the host app uses to store session state (e.g., cookies, keychain, in-memory). */ +interface HostSessionStore { + getMcpToken(): string | undefined; +} + +/** Whatever the host app uses to surface UI prompts. */ +interface HostUi { + showReauthPrompt(message: string): void; +} + +// --- MODE A: Host-managed auth --------------------------------------------- + +function createHostManagedTransport(serverUrl: URL, session: HostSessionStore, ui: HostUi): StreamableHTTPClientTransport { + const authProvider: AuthProvider = { + // Called before every request — just read whatever the host has. + token: async () => session.getMcpToken(), + + // Called on 401 — don't refresh (the host owns the token), signal the UI and bail. + // The transport will retry once after this returns, so we throw to stop it: + // the user needs to act before a retry makes sense. + onUnauthorized: async () => { + ui.showReauthPrompt('MCP connection lost — click to reconnect'); + throw new UnauthorizedError('Host token rejected — user action required'); + } + }; + + return new StreamableHTTPClientTransport(serverUrl, { authProvider }); +} + +// --- MODE B: User-configured OAuth ----------------------------------------- + +function createUserConfiguredTransport(serverUrl: URL, clientId: string, clientSecret: string): StreamableHTTPClientTransport { + // Built-in OAuth provider — the transport adapts it to AuthProvider internally. + // On 401, adaptOAuthProvider synthesizes onUnauthorized → handleOAuthUnauthorized, + // which runs token refresh (or redirect for interactive flows). + const authProvider = new ClientCredentialsProvider({ clientId, clientSecret }); + + return new StreamableHTTPClientTransport(serverUrl, { authProvider }); +} + +// --- Same caller code for both modes --------------------------------------- + +async function connectAndList(transport: StreamableHTTPClientTransport): Promise { + const client = new Client({ name: 'dual-mode-example', version: '1.0.0' }, { capabilities: {} }); + await client.connect(transport); + + const tools = await client.listTools(); + console.log('Tools:', tools.tools.map(t => t.name).join(', ') || '(none)'); + + await transport.close(); +} + +// --- Driver ---------------------------------------------------------------- + +async function main() { + const serverUrl = new URL(process.env.MCP_SERVER_URL || 'http://localhost:3000/mcp'); + const mode = process.argv[2] || 'host'; + + let transport: StreamableHTTPClientTransport; + + if (mode === 'host') { + // Simulate a host app with a session-stored token and a UI hook. + const session: HostSessionStore = { getMcpToken: () => process.env.MCP_TOKEN }; + const ui: HostUi = { showReauthPrompt: msg => console.error(`[UI] ${msg}`) }; + transport = createHostManagedTransport(serverUrl, session, ui); + } else if (mode === 'oauth') { + const clientId = process.env.OAUTH_CLIENT_ID; + const clientSecret = process.env.OAUTH_CLIENT_SECRET; + if (!clientId || !clientSecret) { + console.error('OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET required for oauth mode'); + process.exit(1); + } + transport = createUserConfiguredTransport(serverUrl, clientId, clientSecret); + } else { + console.error(`Unknown mode: ${mode}. Use 'host' or 'oauth'.`); + process.exit(1); + } + + // Same connect/list code regardless of mode — the transport abstracts the difference. + await connectAndList(transport); +} + +try { + await main(); +} catch (error) { + console.error('Error:', error); + // eslint-disable-next-line unicorn/no-process-exit + process.exit(1); +} From cae3d78a34f539d832f2218e5c00a600daf518ad Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 11:25:13 +0000 Subject: [PATCH 08/13] refactor: make isOAuthClientProvider check more explicit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Check typeof === 'function' on two required methods (tokens + clientInformation) instead of bare 'in' operator. Slightly more robust — verifies they're actually callable, not just properties with those names. Same semantics, reads cleaner. --- packages/client/src/client/auth.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/client/src/client/auth.ts b/packages/client/src/client/auth.ts index d26bb5727..7200a14af 100644 --- a/packages/client/src/client/auth.ts +++ b/packages/client/src/client/auth.ts @@ -84,9 +84,14 @@ export interface AuthProvider { /** * Type guard distinguishing `OAuthClientProvider` from a minimal `AuthProvider`. * Transports use this at construction time to classify the `authProvider` option. + * + * Checks for `tokens()` + `clientInformation()` — two required `OAuthClientProvider` + * methods that a minimal `AuthProvider` `{ token: ... }` would never have. */ export function isOAuthClientProvider(provider: AuthProvider | OAuthClientProvider | undefined): provider is OAuthClientProvider { - return provider !== undefined && 'tokens' in provider && 'clientMetadata' in provider; + if (provider == null) return false; + const p = provider as OAuthClientProvider; + return typeof p.tokens === 'function' && typeof p.clientInformation === 'function'; } /** From 51c75981df44badd49940a1715f7e427626089c7 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 13:27:59 +0000 Subject: [PATCH 09/13] fix: separate retry flags for GET-SSE vs POST; forward send() options on retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two fixes from claude[bot] round-6 review: 1. Shared _authRetryInFlight between _startOrAuthSse() and send() created a race: if the fire-and-forget GET SSE gets 401 and sets the flag while awaiting onUnauthorized(), a concurrent POST send() that also gets 401 would see flag=true and throw ClientHttpAuthentication without ever attempting its own re-auth. The old _hasCompletedAuthFlow was only set in send() — I introduced the regression when adding 401 handling to _startOrAuthSse. Split into _authRetryInFlight (send path) and _sseAuthRetryInFlight (GET-SSE path). 2. Pre-existing: send() 401/403 retries called this.send(message) without forwarding the options parameter, dropping onresumptiontoken on the retried request. Added options to both call sites. --- packages/client/src/client/streamableHttp.ts | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 710aa7f57..1d7423e88 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -148,7 +148,8 @@ export class StreamableHTTPClientTransport implements Transport { private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; - private _authRetryInFlight = false; // Circuit breaker: single retry per operation on 401 + private _authRetryInFlight = false; // Circuit breaker for send() 401 retry + private _sseAuthRetryInFlight = false; // Circuit breaker for _startOrAuthSse() 401 retry — separate so concurrent GET/POST 401s don't interfere private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; @@ -226,8 +227,8 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = scope; } - if (this._authProvider.onUnauthorized && !this._authRetryInFlight) { - this._authRetryInFlight = true; + if (this._authProvider.onUnauthorized && !this._sseAuthRetryInFlight) { + this._sseAuthRetryInFlight = true; await this._authProvider.onUnauthorized({ response, serverUrl: this._url, @@ -238,7 +239,7 @@ export class StreamableHTTPClientTransport implements Transport { return this._startOrAuthSse(options); } await response.text?.().catch(() => {}); - if (this._authRetryInFlight) { + if (this._sseAuthRetryInFlight) { throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { status: 401 }); @@ -251,7 +252,7 @@ export class StreamableHTTPClientTransport implements Transport { // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { - this._authRetryInFlight = false; + this._sseAuthRetryInFlight = false; return; } @@ -261,10 +262,10 @@ export class StreamableHTTPClientTransport implements Transport { }); } - this._authRetryInFlight = false; + this._sseAuthRetryInFlight = false; this._handleSseStream(response.body, options, true); } catch (error) { - this._authRetryInFlight = false; + this._sseAuthRetryInFlight = false; this.onerror?.(error as Error); throw error; } @@ -524,7 +525,7 @@ export class StreamableHTTPClientTransport implements Transport { }); await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + return this.send(message, options); } await response.text?.().catch(() => {}); if (this._authRetryInFlight) { @@ -572,7 +573,7 @@ export class StreamableHTTPClientTransport implements Transport { throw new UnauthorizedError(); } - return this.send(message); + return this.send(message, options); } } From 784ad6956ad823f49f2453932bff6d4682352ac7 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 13:44:59 +0000 Subject: [PATCH 10/13] test: add wire-level integration tests for both auth modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Proof-of-life that both auth shapes work against a real HTTP server: - MODE A (minimal AuthProvider): { token: () => 'token' } → server sees Authorization: Bearer token - MODE A 401: onUnauthorized signals UI and throws → caller sees the thrown error (the host-managed pattern where the enclosing app handles reauth) - MODE B (OAuthClientProvider): passed directly, adapter synthesizes token() from tokens() → server sees Authorization: Bearer - Combined: same constructor option slot, same send() call, both shapes hit the same server Uses real node:http server (not fetch mocks) to verify the Authorization header actually reaches the wire. --- .../client/test/client/tokenProvider.test.ts | 137 +++++++++++++++++- 1 file changed, 135 insertions(+), 2 deletions(-) diff --git a/packages/client/test/client/tokenProvider.test.ts b/packages/client/test/client/tokenProvider.test.ts index 3ab6c9623..d6ef35bde 100644 --- a/packages/client/test/client/tokenProvider.test.ts +++ b/packages/client/test/client/tokenProvider.test.ts @@ -1,8 +1,12 @@ -import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import type { IncomingMessage, Server } from 'node:http'; +import { createServer } from 'node:http'; + +import type { JSONRPCMessage, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; +import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; import type { Mock } from 'vitest'; -import type { AuthProvider } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -180,3 +184,132 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); }); }); + +describe('AuthProvider integration — both modes against a real server', () => { + let server: Server; + let serverUrl: URL; + let capturedRequests: IncomingMessage[]; + let transport: StreamableHTTPClientTransport; + + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'ping', params: {}, id: '1' }; + + beforeEach(async () => { + capturedRequests = []; + server = createServer((req, res) => { + capturedRequests.push(req); + if (req.method === 'POST') { + // Consume body then respond 202 Accepted + req.on('data', () => {}); + req.on('end', () => res.writeHead(202).end()); + } else { + // GET SSE — reject so the transport skips it + res.writeHead(405).end(); + } + }); + serverUrl = await listenOnRandomPort(server); + }); + + afterEach(async () => { + await transport?.close().catch(() => {}); + await new Promise(resolve => server.close(() => resolve())); + }); + + it('MODE A: minimal AuthProvider { token } sends Authorization header', async () => { + const authProvider: AuthProvider = { token: async () => 'mode-a-token' }; + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); + + await transport.send(message); + + expect(capturedRequests).toHaveLength(1); + expect(capturedRequests[0]!.headers.authorization).toBe('Bearer mode-a-token'); + }); + + it('MODE A: onUnauthorized signals and throws — caller sees the error', async () => { + const uiSignal = vi.fn(); + const authProvider: AuthProvider = { + token: async () => 'rejected-token', + onUnauthorized: async () => { + uiSignal('show-reauth-prompt'); + throw new UnauthorizedError('user action required'); + } + }; + + // Server that rejects with 401 + await new Promise(resolve => server.close(() => resolve())); + server = createServer((req, res) => { + capturedRequests.push(req); + req.on('data', () => {}); + req.on('end', () => res.writeHead(401).end()); + }); + serverUrl = await listenOnRandomPort(server); + + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); + + await expect(transport.send(message)).rejects.toThrow('user action required'); + expect(uiSignal).toHaveBeenCalledWith('show-reauth-prompt'); + }); + + it('MODE B: OAuthClientProvider is adapted — tokens() becomes token() on the wire', async () => { + // Minimal OAuthClientProvider — the transport should adapt it via adaptOAuthProvider + const oauthProvider: OAuthClientProvider = { + get redirectUrl() { + return undefined; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: [], grant_types: ['client_credentials'] }; + }, + clientInformation(): OAuthClientInformation { + return { client_id: 'test-client' }; + }, + tokens(): OAuthTokens { + return { access_token: 'mode-b-oauth-token', token_type: 'bearer' }; + }, + saveTokens() {}, + redirectToAuthorization() { + throw new Error('not used'); + }, + saveCodeVerifier() {}, + codeVerifier() { + throw new Error('not used'); + } + }; + + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider: oauthProvider }); + + await transport.send(message); + + expect(capturedRequests).toHaveLength(1); + expect(capturedRequests[0]!.headers.authorization).toBe('Bearer mode-b-oauth-token'); + }); + + it('both modes use the same option slot and same send() call', async () => { + // Mode A + const transportA = new StreamableHTTPClientTransport(serverUrl, { + authProvider: { token: async () => 'a-token' } + }); + await transportA.send(message); + await transportA.close(); + + // Mode B — same constructor, same option name, different shape + const transportB = new StreamableHTTPClientTransport(serverUrl, { + authProvider: { + get redirectUrl() { + return undefined; + }, + get clientMetadata(): OAuthClientMetadata { + return { redirect_uris: [] }; + }, + clientInformation: () => ({ client_id: 'x' }), + tokens: () => ({ access_token: 'b-token', token_type: 'bearer' }), + saveTokens() {}, + redirectToAuthorization() {}, + saveCodeVerifier() {}, + codeVerifier: () => '' + } satisfies OAuthClientProvider + }); + await transportB.send(message); + await transportB.close(); + + expect(capturedRequests.map(r => r.headers.authorization)).toEqual(['Bearer a-token', 'Bearer b-token']); + }); +}); From b31d8d932bad56cbb0f52b7e2eb8fe22663a68cd Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 13:51:26 +0000 Subject: [PATCH 11/13] chore: use process.exitCode instead of process.exit in examples Removes eslint-disable suppression. process.exitCode = 1 lets the event loop drain before exit; process.exit(1) kills immediately and can cut off pending writes. --- examples/client/src/dualModeAuth.ts | 3 +-- examples/client/src/simpleTokenProvider.ts | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/client/src/dualModeAuth.ts b/examples/client/src/dualModeAuth.ts index 75f5c1bec..4dd1eaded 100644 --- a/examples/client/src/dualModeAuth.ts +++ b/examples/client/src/dualModeAuth.ts @@ -110,6 +110,5 @@ try { await main(); } catch (error) { console.error('Error:', error); - // eslint-disable-next-line unicorn/no-process-exit - process.exit(1); + process.exitCode = 1; } diff --git a/examples/client/src/simpleTokenProvider.ts b/examples/client/src/simpleTokenProvider.ts index 7b5f1a4c1..ce68fde5a 100644 --- a/examples/client/src/simpleTokenProvider.ts +++ b/examples/client/src/simpleTokenProvider.ts @@ -51,6 +51,5 @@ try { await main(); } catch (error) { console.error('Error running client:', error); - // eslint-disable-next-line unicorn/no-process-exit - process.exit(1); + process.exitCode = 1; } From 71a926013916ed51f4b15588640b33949ab863fe Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 14:54:21 +0000 Subject: [PATCH 12/13] refactor: replace _authRetryInFlight class field with isAuthRetry parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stops the 10-comment whack-a-mole around flag lifecycle. A mutable boolean class field is the wrong primitive for 'retry once per operation' when operations are concurrent and recursive — every reset point creates a race, every missed reset creates a stuck flag. Now all four 401 paths use parameter-passed isAuthRetry: - StreamableHTTP _startOrAuthSse(options, isAuthRetry = false): recursion passes true. No class field, no reset sites. - StreamableHTTP send() delegates to private _send(message, options, isAuthRetry). Recursion passes true. No class field. - SSE _startOrAuth(isAuthRetry = false): onerror callback captures isAuthRetry from closure; retry calls _startOrAuth(true). - SSE send() delegates to private _send(message, isAuthRetry). Per-operation state dies with the stack frame. Concurrent operations cannot observe each other's retry state. 12 reset sites deleted. Also makes SSE onerror fallback consistent with other paths — throws SdkError(ClientHttpAuthentication) for the circuit-breaker case instead of plain UnauthorizedError. Not addressed (noted for auth() cleanup): concurrent 401s still each call onUnauthorized() independently. Deduplicating that (in-flight promise pattern) would be a behavior change. --- packages/client/src/client/sse.ts | 31 +++++++++--------- packages/client/src/client/streamableHttp.ts | 34 +++++++++----------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index 4b62baca2..ebb651f1b 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -98,7 +98,6 @@ export class SSEClientTransport implements Transport { this._fetchWithInit = createFetchWithInit(opts?.fetch, opts?.requestInit); } - private _authRetryInFlight = false; private _last401Response?: Response; private async _commonHeaders(): Promise { @@ -119,7 +118,7 @@ export class SSEClientTransport implements Transport { }); } - private _startOrAuth(): Promise { + private _startOrAuth(isAuthRetry = false): Promise { const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch; return new Promise((resolve, reject) => { this._eventSource = new EventSource(this._url.href, { @@ -148,22 +147,22 @@ export class SSEClientTransport implements Transport { this._eventSource.onerror = event => { if (event.code === 401 && this._authProvider) { - if (this._authProvider.onUnauthorized && this._last401Response && !this._authRetryInFlight) { - this._authRetryInFlight = true; + if (this._authProvider.onUnauthorized && this._last401Response && !isAuthRetry) { const response = this._last401Response; this._authProvider .onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }) - .then(() => this._startOrAuth()) + .then(() => this._startOrAuth(true)) .then(resolve, error => { this.onerror?.(error); reject(error); - }) - .finally(() => { - this._authRetryInFlight = false; }); return; } - const error = new UnauthorizedError(); + const error = isAuthRetry + ? new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { + status: 401 + }) + : new UnauthorizedError(); reject(error); this.onerror?.(error); return; @@ -247,6 +246,10 @@ export class SSEClientTransport implements Transport { } async send(message: JSONRPCMessage): Promise { + return this._send(message, false); + } + + private async _send(message: JSONRPCMessage, isAuthRetry: boolean): Promise { if (!this._endpoint) { throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); } @@ -271,8 +274,7 @@ export class SSEClientTransport implements Transport { this._scope = scope; } - if (this._authProvider.onUnauthorized && !this._authRetryInFlight) { - this._authRetryInFlight = true; + if (this._authProvider.onUnauthorized && !isAuthRetry) { await this._authProvider.onUnauthorized({ response, serverUrl: this._url, @@ -280,10 +282,10 @@ export class SSEClientTransport implements Transport { }); await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + return this._send(message, true); } await response.text?.().catch(() => {}); - if (this._authRetryInFlight) { + if (isAuthRetry) { throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { status: 401 }); @@ -295,12 +297,9 @@ export class SSEClientTransport implements Transport { throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`); } - this._authRetryInFlight = false; - // Release connection - POST responses don't have content we need await response.text?.().catch(() => {}); } catch (error) { - this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 1d7423e88..3d45b60e9 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -148,8 +148,6 @@ export class StreamableHTTPClientTransport implements Transport { private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; - private _authRetryInFlight = false; // Circuit breaker for send() 401 retry - private _sseAuthRetryInFlight = false; // Circuit breaker for _startOrAuthSse() 401 retry — separate so concurrent GET/POST 401s don't interfere private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; @@ -198,7 +196,7 @@ export class StreamableHTTPClientTransport implements Transport { }); } - private async _startOrAuthSse(options: StartSSEOptions): Promise { + private async _startOrAuthSse(options: StartSSEOptions, isAuthRetry = false): Promise { const { resumptionToken } = options; try { @@ -227,8 +225,7 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = scope; } - if (this._authProvider.onUnauthorized && !this._sseAuthRetryInFlight) { - this._sseAuthRetryInFlight = true; + if (this._authProvider.onUnauthorized && !isAuthRetry) { await this._authProvider.onUnauthorized({ response, serverUrl: this._url, @@ -236,10 +233,10 @@ export class StreamableHTTPClientTransport implements Transport { }); await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice - return this._startOrAuthSse(options); + return this._startOrAuthSse(options, true); } await response.text?.().catch(() => {}); - if (this._sseAuthRetryInFlight) { + if (isAuthRetry) { throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { status: 401 }); @@ -252,7 +249,6 @@ export class StreamableHTTPClientTransport implements Transport { // 405 indicates that the server does not offer an SSE stream at GET endpoint // This is an expected case that should not trigger an error if (response.status === 405) { - this._sseAuthRetryInFlight = false; return; } @@ -262,10 +258,8 @@ export class StreamableHTTPClientTransport implements Transport { }); } - this._sseAuthRetryInFlight = false; this._handleSseStream(response.body, options, true); } catch (error) { - this._sseAuthRetryInFlight = false; this.onerror?.(error as Error); throw error; } @@ -475,6 +469,14 @@ export class StreamableHTTPClientTransport implements Transport { async send( message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } + ): Promise { + return this._send(message, options, false); + } + + private async _send( + message: JSONRPCMessage | JSONRPCMessage[], + options: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } | undefined, + isAuthRetry: boolean ): Promise { try { const { resumptionToken, onresumptiontoken } = options || {}; @@ -516,8 +518,7 @@ export class StreamableHTTPClientTransport implements Transport { this._scope = scope; } - if (this._authProvider.onUnauthorized && !this._authRetryInFlight) { - this._authRetryInFlight = true; + if (this._authProvider.onUnauthorized && !isAuthRetry) { await this._authProvider.onUnauthorized({ response, serverUrl: this._url, @@ -525,10 +526,10 @@ export class StreamableHTTPClientTransport implements Transport { }); await response.text?.().catch(() => {}); // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message, options); + return this._send(message, options, true); } await response.text?.().catch(() => {}); - if (this._authRetryInFlight) { + if (isAuthRetry) { throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { status: 401 }); @@ -573,7 +574,7 @@ export class StreamableHTTPClientTransport implements Transport { throw new UnauthorizedError(); } - return this.send(message, options); + return this._send(message, options, isAuthRetry); } } @@ -583,8 +584,6 @@ export class StreamableHTTPClientTransport implements Transport { }); } - // Reset auth loop flag on successful response - this._authRetryInFlight = false; this._lastUpscopingHeader = undefined; // If the response is 202 Accepted, there's no body to process @@ -634,7 +633,6 @@ export class StreamableHTTPClientTransport implements Transport { await response.text?.().catch(() => {}); } } catch (error) { - this._authRetryInFlight = false; this.onerror?.(error as Error); throw error; } From 56ffc7b4dcca60008825457e7498a88fb0afd6ca Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 20 Mar 2026 15:19:43 +0000 Subject: [PATCH 13/13] fix: SSE onerror should not poison EventSource lifetime with isAuthRetry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The isAuthRetry parameter approach works for recursive method calls (send, _startOrAuthSse) but not for the EventSource onerror callback. Passing _startOrAuth(true) on retry permanently captures isAuthRetry=true in the new EventSource's closure — if that EventSource auto-reconnects later (network blip on a long-lived stream) and gets 401, onUnauthorized is skipped and the transport cannot recover. Verified against eventsource lib: non-200 → failConnection (CLOSED, no reconnect); stream end after OPEN → scheduleReconnect → reconnect attempt can get 401 → failConnection → onerror fires. The 'hours later' scenario is real. Fix: retry always calls _startOrAuth() fresh (no parameter). Matches pre-PR _authThenStart() behavior. Trade-off: no circuit breaker on the SSE connect path — if onUnauthorized succeeds but server keeps 401ing, it loops (same as pre-PR). Also fixes double-onerror: two-arg .then(onSuccess, onFail) separates retry failures (inner _startOrAuth already fired onerror) from onUnauthorized failures (not yet reported). Added close + clear _last401Response before retry for hygiene. Two regression tests added, both verified to FAIL against the buggy code: - 401→401→200: onUnauthorized called TWICE, start() resolves - 401→onUnauthorized succeeds→401→onUnauthorized throws: onerror fires ONCE with the thrown error --- packages/client/src/client/sse.ts | 24 ++++----- packages/client/test/client/sse.test.ts | 71 +++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index ebb651f1b..f441e9cdb 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -118,7 +118,7 @@ export class SSEClientTransport implements Transport { }); } - private _startOrAuth(isAuthRetry = false): Promise { + private _startOrAuth(): Promise { const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch; return new Promise((resolve, reject) => { this._eventSource = new EventSource(this._url.href, { @@ -147,22 +147,22 @@ export class SSEClientTransport implements Transport { this._eventSource.onerror = event => { if (event.code === 401 && this._authProvider) { - if (this._authProvider.onUnauthorized && this._last401Response && !isAuthRetry) { + if (this._authProvider.onUnauthorized && this._last401Response) { const response = this._last401Response; - this._authProvider - .onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }) - .then(() => this._startOrAuth(true)) - .then(resolve, error => { + this._last401Response = undefined; + this._eventSource?.close(); + this._authProvider.onUnauthorized({ response, serverUrl: this._url, fetchFn: this._fetchWithInit }).then( + // onUnauthorized succeeded → retry fresh. Its onerror handles its own onerror?.() + reject. + () => this._startOrAuth().then(resolve, reject), + // onUnauthorized failed → not yet reported. + error => { this.onerror?.(error); reject(error); - }); + } + ); return; } - const error = isAuthRetry - ? new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { - status: 401 - }) - : new UnauthorizedError(); + const error = new UnauthorizedError(); reject(error); this.onerror?.(error); return; diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index fd2b184cc..b0b9588f0 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -1624,5 +1624,76 @@ describe('SSEClientTransport', () => { await expect(transport.finishAuth('auth-code')).rejects.toThrow('finishAuth requires an OAuthClientProvider'); }); + + it('SSE connect 401 retry does not poison future 401s — onUnauthorized called on each attempt', async () => { + // Regression: _startOrAuth(true) baked isAuthRetry=true into the retry EventSource's + // onerror closure, so a subsequent 401 (token expiry on reconnect) would throw + // instead of refreshing. Fix: retry always calls _startOrAuth() fresh. + await resourceServer.close(); + + let getAttempt = 0; + resourceServer = createServer((req, res) => { + if (req.method !== 'GET') { + res.writeHead(404).end(); + return; + } + getAttempt++; + if (getAttempt < 3) { + res.writeHead(401).end(); + return; + } + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' + }); + res.write('event: endpoint\n'); + res.write(`data: ${resourceBaseUrl.href}post\n\n`); + }); + resourceBaseUrl = await listenOnRandomPort(resourceServer); + + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized: vi.fn(async () => {}) + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + + await transport.start(); // should resolve on attempt 3 + + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(getAttempt).toBe(3); + }); + + it('retry failure during SSE connect fires onerror exactly once', async () => { + // Regression: when the retry EventSource rejected, its onerror fired inside, then + // the outer .then() rejection handler fired onerror AGAIN for the same error. + // Fix: inner retry chains to .then(resolve, reject) — no outer onerror call. + // onUnauthorized's own failure is handled separately and fires onerror once. + await resourceServer.close(); + + resourceServer = createServer((req, res) => { + if (req.method === 'GET') { + res.writeHead(401).end(); // always 401 + } + }); + resourceBaseUrl = await listenOnRandomPort(resourceServer); + + const onUnauthorized: AuthProvider['onUnauthorized'] = vi + .fn() + .mockResolvedValueOnce(undefined) // first call succeeds → triggers retry + .mockRejectedValueOnce(new Error('refresh failed')); // second call (in retry) throws + const authProvider: AuthProvider = { + token: vi.fn(async () => 'token'), + onUnauthorized + }; + transport = new SSEClientTransport(resourceBaseUrl, { authProvider }); + const onerror = vi.fn(); + transport.onerror = onerror; + + await expect(transport.start()).rejects.toThrow('refresh failed'); + expect(authProvider.onUnauthorized).toHaveBeenCalledTimes(2); + expect(onerror).toHaveBeenCalledTimes(1); + expect(onerror.mock.calls[0]![0].message).toBe('refresh failed'); + }); }); });