diff --git a/.changeset/add-mcp-auth-and-authorization.md b/.changeset/add-mcp-auth-and-authorization.md new file mode 100644 index 000000000..d655e05c2 --- /dev/null +++ b/.changeset/add-mcp-auth-and-authorization.md @@ -0,0 +1,11 @@ +--- +'@modelcontextprotocol/core': minor +'@modelcontextprotocol/server': minor +--- + +Implement generalized authentication and authorization layer for MCP servers. + +- Added `Authenticator` and `BearerTokenAuthenticator` to `@modelcontextprotocol/server`. +- Integrated scope-based authorization checks into `McpServer` for tools, resources, and prompts. +- Fixed asynchronous error propagation in the core `Protocol` class to support proper 401/403 HTTP status mapping in transports. +- Updated `WebStandardStreamableHTTPServerTransport` to correctly map authentication and authorization failures to their respective HTTP status codes. diff --git a/README.md b/README.md index f068ec915..20bd59294 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ Next steps: - Local SDK docs: - [docs/server.md](docs/server.md) – building MCP servers, transports, tools/resources/prompts, sampling, elicitation, tasks, and deployment patterns. + - [docs/auth.md](docs/auth.md) – implementing authentication and authorization in MCP servers. - [docs/client.md](docs/client.md) – building MCP clients: connecting, tools, resources, prompts, server-initiated requests, and error handling - [docs/faq.md](docs/faq.md) – frequently asked questions and troubleshooting - External references: diff --git a/docs/auth.md b/docs/auth.md new file mode 100644 index 000000000..7a1dd030f --- /dev/null +++ b/docs/auth.md @@ -0,0 +1,110 @@ +# Authentication and Authorization + +The MCP TypeScript SDK provides optional, opt-in support for authentication (AuthN) and authorization (AuthZ). This enables you to protect your MCP server resources, tools, and prompts using industry-standard schemes like OAuth 2.1 Bearer tokens. + +## Key Concepts + +- **Authenticator**: Responsible for extracting and validating authentication information from an incoming request. +- **AuthInfo**: A structure containing information about the authenticated entity (e.g., user name, active scopes). +- **Authorizer**: Used by the MCP server to verify if the authenticated entity has the required scopes to access a specific resource, tool, or prompt. +- **Scopes**: Optional strings associated with registered items that define the required permissions. + +## Implementing Authentication + +To enable authentication, provide an `authenticator` in the `ServerOptions` when creating your server. + +### Using Bearer Token Authentication + +The SDK includes a `BearerTokenAuthenticator` for validating OAuth 2.1 Bearer tokens. + +```typescript +import { McpServer, BearerTokenAuthenticator } from "@modelcontextprotocol/server"; + +const server = new McpServer({ + name: "my-authenticated-server", + version: "1.0.0", +}, { + authenticator: new BearerTokenAuthenticator(async (token) => { + // Validate the token (e.g., verify with an OAuth provider) + if (token === "valid-token") { + return { + token, + clientId: "john_doe", + scopes: ["read:resources", "execute:tools"] + }; + } + return undefined; // Invalid token + }) +}); +``` + +## Implementing Authorization + +Authorization is enforced using the `scopes` property when registering tools, resources, or prompts. + +### Scoped Tools + +```typescript +server.tool( + "secure_tool", + { + description: "A tool that requires specific scopes", + scopes: ["execute:tools"] + }, + async (args) => { + return { content: [{ type: "text", text: "Success!" }] }; + } +); +``` + +### Scoped Resources + +```typescript +server.resource( + "secure_resource", + "secure://data", + { scopes: ["read:resources"] }, + async (uri) => { + return { contents: [{ uri: uri.href, text: "Top secret data" }] }; + } +); +``` + +## Middleware Support + +For framework-specific integrations, use the provided middleware to pre-authenticate requests. + +### Express Middleware + +```typescript +import express from "express"; +import { auth } from "@modelcontextprotocol/express"; + +const app = express(); +app.use(auth({ authenticator })); + +app.post("/mcp", (req, res) => { + // req.auth is now populated + transport.handleRequest(req, res); +}); +``` + +### Hono Middleware + +```typescript +import { Hono } from "hono"; +import { auth } from "@modelcontextprotocol/hono"; + +const app = new Hono(); +app.use("/mcp/*", auth({ authenticator })); + +app.all("/mcp", async (c) => { + const authInfo = c.get("mcpAuthInfo"); + return transport.handleRequest(c.req.raw, { authInfo }); +}); +``` + +## Error Handling + +- **401 Unauthorized**: Returned when authentication is required but missing or invalid. Includes `WWW-Authenticate: Bearer` header. +- **403 Forbidden**: Returned when the authenticated entity lacks the required scopes. diff --git a/docs/server.md b/docs/server.md index 3c246ac12..43ea3a730 100644 --- a/docs/server.md +++ b/docs/server.md @@ -11,6 +11,7 @@ Building a server takes three steps: 1. Create an {@linkcode @modelcontextprotocol/server!server/mcp.McpServer | McpServer} and register your [tools, resources, and prompts](#tools-resources-and-prompts). 2. Create a transport — [Streamable HTTP](#streamable-http) for remote servers or [stdio](#stdio) for local, process‑spawned integrations. 3. Wire the transport into your HTTP framework (or use stdio directly) and call `server.connect(transport)`. +1. (Optional) Configure [authentication and authorization](#authentication-and-authorization) to protect your server. The sections below cover each of these. For a feature‑rich starting point, see [`simpleStreamableHttp.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/main/examples/server/src/simpleStreamableHttp.ts) — remove what you don't need and register your own tools, resources, and prompts. For stateless or JSON‑response‑mode alternatives, see the examples linked in [Transports](#transports) below. @@ -444,6 +445,25 @@ Task-based execution enables "call-now, fetch-later" patterns for long-running o > [!WARNING] > The tasks API is experimental and may change without notice. +## Authentication and Authorization + +The MCP TypeScript SDK provides optional, opt-in support for authentication (AuthN) and authorization (AuthZ). For a comprehensive guide, see the [Authentication and Authorization guide](./auth.md). + +Quick example: + +```ts +const server = new McpServer({ name: 'my-server', version: '1.0.0' }, { + authenticator: new BearerTokenAuthenticator(async (token) => { + if (token === 'secret') return { token, clientId: 'admin', scopes: ['all'] }; + return undefined; + }) +}); + +server.tool('secure-tool', { scopes: ['all'] }, async (args) => { + return { content: [{ type: 'text', text: 'Success' }] }; +}); +``` + ## Deployment ### DNS rebinding protection diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index b82731582..90619b2d9 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -704,16 +704,24 @@ export abstract class Protocol { }; const _onmessage = this._transport?.onmessage; - this._transport.onmessage = (message, extra) => { - _onmessage?.(message, extra); - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - this._onresponse(message); - } else if (isJSONRPCRequest(message)) { - this._onrequest(message, extra); - } else if (isJSONRPCNotification(message)) { - this._onnotification(message); - } else { - this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); + this._transport.onmessage = async (message, extra) => { + try { + if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { + await _onmessage?.(message, extra); + this._onresponse(message); + } else if (isJSONRPCRequest(message)) { + await this._onrequest(message, extra); + } else if (isJSONRPCNotification(message)) { + await this._onnotification(message); + } else { + await _onmessage?.(message, extra); + this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); + } + } catch (error) { + if (error instanceof ProtocolError && (error.code === ProtocolErrorCode.Unauthorized || error.code === ProtocolErrorCode.Forbidden)) { + throw error; + } + this._onerror(error instanceof Error ? error : new Error(String(error))); } }; @@ -758,7 +766,7 @@ export abstract class Protocol { .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); } - private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { + protected async _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): Promise { const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; // Capture the current transport at request time to ensure responses go to the correct client @@ -838,7 +846,7 @@ export abstract class Protocol { const ctx = this.buildContext(baseCtx, extra); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() + return Promise.resolve() .then(() => { // If this request asked for task creation, check capability first if (taskCreationParams) { @@ -879,6 +887,10 @@ export abstract class Protocol { return; } + if (error instanceof ProtocolError && (error.message.includes('Unauthorized') || error.message.includes('Forbidden'))) { + throw error; + } + const errorResponse: JSONRPCErrorResponse = { jsonrpc: '2.0', id: request.id, @@ -903,7 +915,13 @@ export abstract class Protocol { : capturedTransport?.send(errorResponse)); } ) - .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) + .catch(error => { + if (error instanceof ProtocolError && (error.message.includes('Unauthorized') || error.message.includes('Forbidden'))) { + throw error; + } + // Do not report as protocol error if it's already an auth error we're escaping + this._onerror(new Error(`Failed to send response: ${error}`)); + }) .finally(() => { this._requestHandlerAbortControllers.delete(request.id); }); diff --git a/packages/core/src/shared/transport.ts b/packages/core/src/shared/transport.ts index a04e054ba..6241175ef 100644 --- a/packages/core/src/shared/transport.ts +++ b/packages/core/src/shared/transport.ts @@ -114,7 +114,7 @@ export interface Transport { * * The {@linkcode MessageExtraInfo.requestInfo | requestInfo} can be used to get the original request information (headers, etc.) */ - onmessage?: (message: T, extra?: MessageExtraInfo) => void; + onmessage?: (message: T, extra?: MessageExtraInfo) => void | Promise; /** * The session ID generated for this connection. diff --git a/packages/core/src/types/types.ts b/packages/core/src/types/types.ts index 6ac79777b..a15496a5e 100644 --- a/packages/core/src/types/types.ts +++ b/packages/core/src/types/types.ts @@ -233,6 +233,8 @@ export enum ProtocolErrorCode { InternalError = -32_603, // MCP-specific error codes + Unauthorized = 401, + Forbidden = 403, ResourceNotFound = -32_002, UrlElicitationRequired = -32_042 } diff --git a/packages/middleware/express/src/express.ts b/packages/middleware/express/src/express.ts index 252502952..989827f56 100644 --- a/packages/middleware/express/src/express.ts +++ b/packages/middleware/express/src/express.ts @@ -2,6 +2,8 @@ import type { Express } from 'express'; import express from 'express'; import { hostHeaderValidation, localhostHostValidation } from './middleware/hostHeaderValidation.js'; +export { auth } from './middleware/auth.js'; +export type { AuthMiddlewareOptions } from './middleware/auth.js'; /** * Options for creating an MCP Express application. diff --git a/packages/middleware/express/src/index.ts b/packages/middleware/express/src/index.ts index 2d7d20a64..e7dbbd229 100644 --- a/packages/middleware/express/src/index.ts +++ b/packages/middleware/express/src/index.ts @@ -1,2 +1,3 @@ export * from './express.js'; +export { auth } from './middleware/auth.js'; export * from './middleware/hostHeaderValidation.js'; diff --git a/packages/middleware/express/src/middleware/auth.ts b/packages/middleware/express/src/middleware/auth.ts new file mode 100644 index 000000000..cee89e759 --- /dev/null +++ b/packages/middleware/express/src/middleware/auth.ts @@ -0,0 +1,58 @@ +import { Request, Response, NextFunction, RequestHandler } from 'express'; +import { Authenticator, AuthInfo } from '@modelcontextprotocol/server'; + +/** + * Options for the MCP Express authentication middleware. + */ +export interface AuthMiddlewareOptions { + /** + * The authenticator to use for validating requests. + */ + authenticator: Authenticator; +} + +/** + * Creates an Express middleware for MCP authentication. + * + * This middleware extracts authentication information from the request using the provided authenticator + * and attaches it to the request object as `req.auth`. The MCP Express transport will then + * pick up this information automatically. + * + * @param options - Middleware options + * @returns An Express middleware function + * + * @example + * ```ts + * const authenticator = new BearerTokenAuthenticator((token) => Promise.resolve({ token, clientId: 'user', scopes: ['read'] })); + * app.use(auth({ authenticator })); + * ``` + */ +export function auth(options: AuthMiddlewareOptions): RequestHandler { + return async (req: Request & { auth?: AuthInfo }, res: Response, next: NextFunction) => { + try { + const headers: Record = {}; + for (const [key, value] of Object.entries(req.headers)) { + if (typeof value === 'string') { + headers[key] = value; + } else if (Array.isArray(value)) { + headers[key] = value.join(', '); + } + } + + const authInfo = await options.authenticator.authenticate({ + method: req.method, + headers, + }); + if (authInfo) { + req.auth = authInfo; + } + next(); + } catch (error) { + // If authentication fails, we let the MCP server handle it later, + // or the developer can choose to reject here. + // By default, we just proceed to allow the MCP server to decide (e.g., if auth is optional). + console.error('[MCP Express Auth Middleware] Authentication failed:', error); + next(); + } + }; +} diff --git a/packages/middleware/hono/src/hono.ts b/packages/middleware/hono/src/hono.ts index eda3e5d8f..01e14333a 100644 --- a/packages/middleware/hono/src/hono.ts +++ b/packages/middleware/hono/src/hono.ts @@ -2,6 +2,8 @@ import type { Context } from 'hono'; import { Hono } from 'hono'; import { hostHeaderValidation, localhostHostValidation } from './middleware/hostHeaderValidation.js'; +export { auth } from './middleware/auth.js'; +export type { AuthMiddlewareOptions } from './middleware/auth.js'; /** * Options for creating an MCP Hono application. diff --git a/packages/middleware/hono/src/index.ts b/packages/middleware/hono/src/index.ts index a8c65a2e9..02a72075b 100644 --- a/packages/middleware/hono/src/index.ts +++ b/packages/middleware/hono/src/index.ts @@ -1,2 +1,3 @@ export * from './hono.js'; +export { auth } from './middleware/auth.js'; export * from './middleware/hostHeaderValidation.js'; diff --git a/packages/middleware/hono/src/middleware/auth.ts b/packages/middleware/hono/src/middleware/auth.ts new file mode 100644 index 000000000..3c76583e2 --- /dev/null +++ b/packages/middleware/hono/src/middleware/auth.ts @@ -0,0 +1,58 @@ +import { Context, Next } from 'hono'; +import { Authenticator, AuthInfo } from '@modelcontextprotocol/server'; + +/** + * Options for the MCP Hono authentication middleware. + */ +export interface AuthMiddlewareOptions { + /** + * The authenticator to use for validating requests. + */ + authenticator: Authenticator; +} + +/** + * Creates a Hono middleware for MCP authentication. + * + * This middleware extracts authentication information from the raw request using the provided authenticator + * and attaches it to the Hono context as `mcpAuthInfo`. + * + * @param options - Middleware options + * @returns A Hono middleware function + * + * @example + * ```ts + * const authenticator = new BearerTokenAuthenticator({ + * validate: async (token) => ({ name: 'user', scopes: ['read'] }) + * }); + * app.use('/mcp/*', auth({ authenticator })); + * + * app.all('/mcp', async (c) => { + * const authInfo = c.get('mcpAuthInfo'); + * return transport.handleRequest(c.req.raw, { authInfo }); + * }); + * ``` + */ +export function auth(options: AuthMiddlewareOptions) { + return async (c: Context, next: Next) => { + try { + const headers: Record = {}; + c.req.raw.headers.forEach((v, k) => { + headers[k] = v; + }); + + const authInfo = await options.authenticator.authenticate({ + method: c.req.method, + headers, + }); + + if (authInfo) { + c.set('mcpAuthInfo', authInfo); + } + await next(); + } catch (error) { + // Proceed to allow MCP server to handle it or if auth is optional. + await next(); + } + }; +} diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index c1558e445..8d6fc3802 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -22,6 +22,7 @@ interface McpServerInternal { inputSchema: AnySchema | undefined, outputSchema: AnySchema | undefined, annotations: ToolAnnotations | undefined, + scopes: string[] | undefined, execution: ToolExecution | undefined, _meta: Record | undefined, handler: AnyToolHandler @@ -83,6 +84,7 @@ export class ExperimentalMcpServerTasks { description?: string; outputSchema?: OutputArgs; annotations?: ToolAnnotations; + scopes?: string[]; execution?: TaskToolExecution; _meta?: Record; }, @@ -97,6 +99,7 @@ export class ExperimentalMcpServerTasks { inputSchema: InputArgs; outputSchema?: OutputArgs; annotations?: ToolAnnotations; + scopes?: string[]; execution?: TaskToolExecution; _meta?: Record; }, @@ -111,6 +114,7 @@ export class ExperimentalMcpServerTasks { inputSchema?: InputArgs; outputSchema?: OutputArgs; annotations?: ToolAnnotations; + scopes?: string[]; execution?: TaskToolExecution; _meta?: Record; }, @@ -131,6 +135,7 @@ export class ExperimentalMcpServerTasks { config.inputSchema, config.outputSchema, config.annotations, + config.scopes, execution, config._meta, handler as AnyToolHandler diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 1a8dbf143..cd6b71249 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,3 +1,6 @@ +export * from './server/auth/authenticator.js'; +export * from './server/auth/bearer.js'; +export * from './server/auth/authorizer.js'; export * from './server/completable.js'; export * from './server/mcp.js'; export * from './server/middleware/hostHeaderValidation.js'; diff --git a/packages/server/src/server/auth/authenticator.ts b/packages/server/src/server/auth/authenticator.ts new file mode 100644 index 000000000..3c258820c --- /dev/null +++ b/packages/server/src/server/auth/authenticator.ts @@ -0,0 +1,39 @@ +import { AuthInfo, RequestId } from '@modelcontextprotocol/core'; + +/** + * Interface for authenticating MCP requests. + */ +export interface Authenticator { + /** + * Authenticates an incoming request. + * + * @param request - Information about the request being made, including headers if available. + * @returns Information about the authenticated entity, or `undefined` if authentication failed. + */ + authenticate(request: AuthenticateRequest): Promise; + + /** + * Returns the name of the authentication scheme (e.g., 'Bearer'). + */ + readonly scheme: string; +} + +/** + * Information provided to authenticators to validate a request. + */ +export interface AuthenticateRequest { + /** + * The JSON-RPC ID of the request. + */ + requestId?: RequestId; + + /** + * The method being called. + */ + method?: string; + + /** + * Any headers associated with the request (e.g., from an HTTP transport). + */ + headers?: Record; +} diff --git a/packages/server/src/server/auth/authorizer.ts b/packages/server/src/server/auth/authorizer.ts new file mode 100644 index 000000000..4657f0c12 --- /dev/null +++ b/packages/server/src/server/auth/authorizer.ts @@ -0,0 +1,26 @@ +import { AuthInfo } from '@modelcontextprotocol/core'; + +/** + * Validates if the given `AuthInfo` has the required scopes. + */ +export class Authorizer { + /** + * Checks if the authenticated entity is authorized based on required scopes. + * + * @param authInfo - Information about the authenticated entity. + * @param requiredScopes - Scopes required for the operation. + * @returns `true` if authorized, `false` otherwise. + */ + static isAuthorized(authInfo: AuthInfo | undefined, requiredScopes: string[] | undefined): boolean { + if (!requiredScopes || requiredScopes.length === 0) { + return true; + } + + if (!authInfo || !authInfo.scopes) { + return false; + } + + // All required scopes must be present in the authInfo's scopes. + return requiredScopes.every(scope => authInfo.scopes?.includes(scope)); + } +} diff --git a/packages/server/src/server/auth/bearer.ts b/packages/server/src/server/auth/bearer.ts new file mode 100644 index 000000000..88f5ce63b --- /dev/null +++ b/packages/server/src/server/auth/bearer.ts @@ -0,0 +1,52 @@ +import { AuthInfo } from '@modelcontextprotocol/core'; +import { AuthenticateRequest, Authenticator } from './authenticator.js'; + +/** + * Validates a Bearer token. + * + * @param token - The Bearer token to validate. + * @returns Information about the authenticated entity, or `undefined` if validation failed. + */ +export type BearerTokenValidator = (token: string) => Promise; + +/** + * An authenticator for the Bearer authentication scheme. + */ +export class BearerTokenAuthenticator implements Authenticator { + /** + * Creates a new `BearerTokenAuthenticator` with the given validator. + * + * @param _validator - Function to validate the Bearer token. + */ + constructor(private readonly _validator: BearerTokenValidator) {} + + /** + * Returns the name of the authentication scheme. + */ + readonly scheme = 'Bearer'; + + /** + * Authenticates an incoming request by extracting the Bearer token from the `Authorization` header. + * + * @param request - Information about the request. + * @returns Information about the authenticated entity, or `undefined` if authentication failed. + */ + async authenticate(request: AuthenticateRequest): Promise { + const authHeader = request.headers?.['authorization']; + if (!authHeader) { + return undefined; + } + + const match = authHeader.match(/^Bearer\s+(.+)$/i); + if (!match) { + return undefined; + } + + const token = match[1]; + if (!token) { + return undefined; + } + + return await this._validator(token); + } +} diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 05136f5b6..a8487585a 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -44,8 +44,9 @@ import { validateAndWarnToolName } from '@modelcontextprotocol/core'; -import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; +import { Authenticator } from './auth/authenticator.js'; +import { Authorizer } from './auth/authorizer.js'; import { getCompleter, isCompletable } from './completable.js'; import type { ServerOptions } from './server.js'; import { Server } from './server.js'; @@ -81,6 +82,13 @@ export class McpServer { this.server = new Server(serverInfo, options); } + /** + * Returns the authenticator for this server, if one was provided. + */ + get authenticator(): Authenticator | undefined { + return this.server.authenticator; + } + /** * Access experimental features. * @@ -200,6 +208,13 @@ export class McpServer { return await this.handleAutomaticTaskPolling(tool, request, ctx); } + // Authorization check + if (tool.scopes && tool.scopes.length > 0) { + if (!ctx.http?.authInfo || !Authorizer.isAuthorized(ctx.http.authInfo, tool.scopes)) { + throw new ProtocolError(ProtocolErrorCode.Forbidden, 'Forbidden'); + } + } + // Normal execution path const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); const result = await this.executeToolHandler(tool, args, ctx); @@ -213,8 +228,14 @@ export class McpServer { await this.validateToolOutput(tool, result, request.params.name); return result; } catch (error) { - if (error instanceof ProtocolError && error.code === ProtocolErrorCode.UrlElicitationRequired) { - throw error; // Return the error to the caller without wrapping in CallToolResult + if (error instanceof ProtocolError) { + if ( + error.code === ProtocolErrorCode.Forbidden || + error.code === ProtocolErrorCode.Unauthorized || + error.code === ProtocolErrorCode.UrlElicitationRequired + ) { + throw error; + } } return this.createToolError(error instanceof Error ? error.message : String(error)); } @@ -258,7 +279,7 @@ export class McpServer { const parseResult = await parseSchemaAsync(tool.inputSchema, args ?? {}); if (!parseResult.success) { - const errorMessage = parseResult.error.issues.map((i: { message: string }) => i.message).join(', '); + const errorMessage = (parseResult as any).error.issues.map((i: { message: string }) => i.message).join(', '); throw new ProtocolError( ProtocolErrorCode.InvalidParams, `Input validation error: Invalid arguments for tool ${toolName}: ${errorMessage}` @@ -295,7 +316,7 @@ export class McpServer { // if the tool has an output schema, validate structured content const parseResult = await parseSchemaAsync(tool.outputSchema, result.structuredContent); if (!parseResult.success) { - const errorMessage = parseResult.error.issues.map((i: { message: string }) => i.message).join(', '); + const errorMessage = (parseResult as { error: any }).error.issues.map((i: { message: string }) => i.message).join(', '); throw new ProtocolError( ProtocolErrorCode.InvalidParams, `Output validation error: Invalid structured content for tool ${toolName}: ${errorMessage}` @@ -495,6 +516,14 @@ export class McpServer { if (!resource.enabled) { throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Resource ${uri} disabled`); } + + // Authorization check + if (resource.scopes && resource.scopes.length > 0) { + if (!ctx.http?.authInfo || !Authorizer.isAuthorized(ctx.http.authInfo, resource.scopes)) { + throw new ProtocolError(ProtocolErrorCode.Forbidden, 'Forbidden'); + } + } + return resource.readCallback(uri, ctx); } @@ -502,6 +531,13 @@ export class McpServer { for (const template of Object.values(this._registeredResourceTemplates)) { const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); if (variables) { + // Authorization check + if (template.scopes && template.scopes.length > 0) { + if (!ctx.http?.authInfo || !Authorizer.isAuthorized(ctx.http.authInfo, template.scopes)) { + throw new ProtocolError(ProtocolErrorCode.Forbidden, 'Forbidden'); + } + } + return template.readCallback(uri, variables, ctx); } } @@ -554,6 +590,13 @@ export class McpServer { throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); } + // Authorization check + if (prompt.scopes && prompt.scopes.length > 0) { + if (!ctx.http?.authInfo || !Authorizer.isAuthorized(ctx.http.authInfo, prompt.scopes)) { + throw new ProtocolError(ProtocolErrorCode.Forbidden, 'Forbidden'); + } + } + // Handler encapsulates parsing and callback invocation with proper types return prompt.handler(request.params.arguments, ctx); }); @@ -580,19 +623,25 @@ export class McpServer { * ); * ``` */ - registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: string, + config: ResourceMetadata & { scopes?: string[] }, + readCallback: ReadResourceCallback + ): RegisteredResource; registerResource( name: string, uriOrTemplate: ResourceTemplate, - config: ResourceMetadata, + config: ResourceMetadata & { scopes?: string[] }, readCallback: ReadResourceTemplateCallback ): RegisteredResourceTemplate; registerResource( name: string, uriOrTemplate: string | ResourceTemplate, - config: ResourceMetadata, + config: ResourceMetadata & { scopes?: string[] }, readCallback: ReadResourceCallback | ReadResourceTemplateCallback ): RegisteredResource | RegisteredResourceTemplate { + const { scopes, ...metadata } = config; if (typeof uriOrTemplate === 'string') { if (this._registeredResources[uriOrTemplate]) { throw new Error(`Resource ${uriOrTemplate} is already registered`); @@ -602,7 +651,8 @@ export class McpServer { name, (config as BaseMetadata).title, uriOrTemplate, - config, + metadata, + scopes, readCallback as ReadResourceCallback ); @@ -618,7 +668,8 @@ export class McpServer { name, (config as BaseMetadata).title, uriOrTemplate, - config, + metadata, + scopes, readCallback as ReadResourceTemplateCallback ); @@ -633,6 +684,7 @@ export class McpServer { title: string | undefined, uri: string, metadata: ResourceMetadata | undefined, + scopes: string[] | undefined, readCallback: ReadResourceCallback ): RegisteredResource { const registeredResource: RegisteredResource = { @@ -652,6 +704,7 @@ export class McpServer { if (updates.name !== undefined) registeredResource.name = updates.name; if (updates.title !== undefined) registeredResource.title = updates.title; if (updates.metadata !== undefined) registeredResource.metadata = updates.metadata; + if (updates.scopes !== undefined) registeredResource.scopes = updates.scopes; if (updates.callback !== undefined) registeredResource.readCallback = updates.callback; if (updates.enabled !== undefined) registeredResource.enabled = updates.enabled; this.sendResourceListChanged(); @@ -666,12 +719,14 @@ export class McpServer { title: string | undefined, template: ResourceTemplate, metadata: ResourceMetadata | undefined, + scopes: string[] | undefined, readCallback: ReadResourceTemplateCallback ): RegisteredResourceTemplate { const registeredResourceTemplate: RegisteredResourceTemplate = { resourceTemplate: template, title, metadata, + scopes, readCallback, enabled: true, disable: () => registeredResourceTemplate.update({ enabled: false }), @@ -685,6 +740,7 @@ export class McpServer { if (updates.title !== undefined) registeredResourceTemplate.title = updates.title; if (updates.template !== undefined) registeredResourceTemplate.resourceTemplate = updates.template; if (updates.metadata !== undefined) registeredResourceTemplate.metadata = updates.metadata; + if (updates.scopes !== undefined) registeredResourceTemplate.scopes = updates.scopes; if (updates.callback !== undefined) registeredResourceTemplate.readCallback = updates.callback; if (updates.enabled !== undefined) registeredResourceTemplate.enabled = updates.enabled; this.sendResourceListChanged(); @@ -707,6 +763,7 @@ export class McpServer { title: string | undefined, description: string | undefined, argsSchema: AnySchema | undefined, + scopes: string[] | undefined, callback: PromptCallback ): RegisteredPrompt { // Track current schema and callback for handler regeneration @@ -717,6 +774,7 @@ export class McpServer { title, description, argsSchema, + scopes, handler: createPromptHandler(name, argsSchema, callback), enabled: true, disable: () => registeredPrompt.update({ enabled: false }), @@ -746,6 +804,7 @@ export class McpServer { } if (updates.enabled !== undefined) registeredPrompt.enabled = updates.enabled; + if (updates.scopes !== undefined) registeredPrompt.scopes = updates.scopes; this.sendPromptListChanged(); } }; @@ -775,6 +834,7 @@ export class McpServer { inputSchema: AnySchema | undefined, outputSchema: AnySchema | undefined, annotations: ToolAnnotations | undefined, + scopes: string[] | undefined, execution: ToolExecution | undefined, _meta: Record | undefined, handler: AnyToolHandler @@ -792,6 +852,7 @@ export class McpServer { outputSchema, annotations, execution, + scopes, _meta, handler: handler, executor: createToolExecutor(inputSchema, handler), @@ -874,6 +935,7 @@ export class McpServer { inputSchema?: InputArgs; outputSchema?: OutputArgs; annotations?: ToolAnnotations; + scopes?: string[]; _meta?: Record; }, cb: ToolCallback @@ -882,7 +944,7 @@ export class McpServer { throw new Error(`Tool ${name} is already registered`); } - const { title, description, inputSchema, outputSchema, annotations, _meta } = config; + const { title, description, inputSchema, outputSchema, annotations, scopes, _meta } = config; return this._createRegisteredTool( name, @@ -891,6 +953,7 @@ export class McpServer { inputSchema, outputSchema, annotations, + scopes, { taskSupport: 'forbidden' }, _meta, cb as ToolCallback @@ -929,6 +992,7 @@ export class McpServer { title?: string; description?: string; argsSchema?: Args; + scopes?: string[]; }, cb: PromptCallback ): RegisteredPrompt { @@ -936,13 +1000,14 @@ export class McpServer { throw new Error(`Prompt ${name} is already registered`); } - const { title, description, argsSchema } = config; + const { title, description, argsSchema, scopes } = config; const registeredPrompt = this._createRegisteredPrompt( name, title, description, argsSchema, + scopes, cb as PromptCallback ); @@ -1073,10 +1138,17 @@ export type BaseToolCallback = BaseToolCallback; +/** + * Handler for a tool that creates a task. + */ +export type McpToolTaskHandler = { + createTask: BaseToolCallback; +}; + /** * Supertype that can handle both regular tools (simple callback) and task-based tools (task handler object). */ -export type AnyToolHandler = ToolCallback | ToolTaskHandler; +export type AnyToolHandler = ToolCallback | McpToolTaskHandler; /** * Internal executor type that encapsulates handler invocation with proper types. @@ -1091,6 +1163,7 @@ export type RegisteredTool = { annotations?: ToolAnnotations; execution?: ToolExecution; _meta?: Record; + scopes?: string[]; handler: AnyToolHandler; /** @hidden */ executor: ToolExecutor; @@ -1104,6 +1177,7 @@ export type RegisteredTool = { paramsSchema?: AnySchema; outputSchema?: AnySchema; annotations?: ToolAnnotations; + scopes?: string[]; _meta?: Record; callback?: ToolCallback; enabled?: boolean; @@ -1168,6 +1242,7 @@ export type RegisteredResource = { name: string; title?: string; metadata?: ResourceMetadata; + scopes?: string[]; readCallback: ReadResourceCallback; enabled: boolean; enable(): void; @@ -1177,6 +1252,7 @@ export type RegisteredResource = { title?: string; uri?: string | null; metadata?: ResourceMetadata; + scopes?: string[]; callback?: ReadResourceCallback; enabled?: boolean; }): void; @@ -1196,6 +1272,7 @@ export type RegisteredResourceTemplate = { resourceTemplate: ResourceTemplate; title?: string; metadata?: ResourceMetadata; + scopes?: string[]; readCallback: ReadResourceTemplateCallback; enabled: boolean; enable(): void; @@ -1205,6 +1282,7 @@ export type RegisteredResourceTemplate = { title?: string; template?: ResourceTemplate; metadata?: ResourceMetadata; + scopes?: string[]; callback?: ReadResourceTemplateCallback; enabled?: boolean; }): void; @@ -1231,6 +1309,7 @@ export type RegisteredPrompt = { title?: string; description?: string; argsSchema?: AnySchema; + scopes?: string[]; /** @hidden */ handler: PromptHandler; enabled: boolean; @@ -1241,6 +1320,7 @@ export type RegisteredPrompt = { title?: string; description?: string; argsSchema?: Args; + scopes?: string[]; callback?: PromptCallback; enabled?: boolean; }): void; @@ -1262,10 +1342,10 @@ function createPromptHandler( return async (args, ctx) => { const parseResult = await parseSchemaAsync(argsSchema, args); if (!parseResult.success) { - const errorMessage = parseResult.error.issues.map((i: { message: string }) => i.message).join(', '); + const errorMessage = parseResult.error.issues.map(i => i.message).join(', '); throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid arguments for prompt ${name}: ${errorMessage}`); } - return typedCallback(parseResult.data as SchemaOutput, ctx); + return typedCallback(parseResult.data, ctx); }; } else { const typedCallback = callback as (ctx: ServerContext) => GetPromptResult | Promise; diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 00d3e6f52..c8723599b 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -12,6 +12,9 @@ import type { Implementation, InitializeRequest, InitializeResult, + JSONRPCErrorResponse, + JSONRPCRequest, + JSONRPCResultResponse, JsonSchemaType, jsonSchemaValidator, ListRootsRequest, @@ -53,8 +56,8 @@ import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; +import { Authenticator } from './auth/authenticator.js'; import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/server/_shims'; - import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; export type ServerOptions = ProtocolOptions & { @@ -77,6 +80,11 @@ export type ServerOptions = ProtocolOptions & { * @default {@linkcode DefaultJsonSchemaValidator} ({@linkcode index.AjvJsonSchemaValidator | AjvJsonSchemaValidator} on Node.js, {@linkcode index.CfWorkerJsonSchemaValidator | CfWorkerJsonSchemaValidator} on Cloudflare Workers) */ jsonSchemaValidator?: jsonSchemaValidator; + + /** + * Optional authenticator for incoming requests. + */ + authenticator?: Authenticator; }; /** @@ -92,6 +100,7 @@ export class Server extends Protocol { private _capabilities: ServerCapabilities; private _instructions?: string; private _jsonSchemaValidator: jsonSchemaValidator; + private _authenticator?: Authenticator; private _experimental?: { tasks: ExperimentalServerTasks }; /** @@ -110,6 +119,7 @@ export class Server extends Protocol { this._capabilities = options?.capabilities ?? {}; this._instructions = options?.instructions; this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); + this._authenticator = options?.authenticator; this.setRequestHandler('initialize', request => this._oninitialize(request)); this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); @@ -119,6 +129,13 @@ export class Server extends Protocol { } } + /** + * Returns the authenticator for this server, if one was provided. + */ + get authenticator(): Authenticator | undefined { + return this._authenticator; + } + private _registerLoggingHandler(): void { this.setRequestHandler('logging/setLevel', async (request, ctx) => { const transportSessionId: string | undefined = @@ -148,6 +165,7 @@ export class Server extends Protocol { ? { ...ctx.http, req: transportInfo?.requestInfo, + authInfo: transportInfo?.authInfo, closeSSE: transportInfo?.closeSSEStream, closeStandaloneSSE: transportInfo?.closeStandaloneSSEStream } @@ -155,6 +173,26 @@ export class Server extends Protocol { }; } + protected override async _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): Promise { + if (this._authenticator && request.method !== 'initialize' && request.method !== 'ping') { + const authInfo = await this._authenticator.authenticate({ + headers: Object.fromEntries(extra?.requestInfo?.headers.entries() ?? []) + }); + + if (!authInfo) { + throw new ProtocolError(ProtocolErrorCode.Unauthorized, 'Unauthorized'); + } + + // Inject authInfo into extra for buildContext + if (!extra) { + extra = {}; + } + extra.authInfo = authInfo; + } + + return super._onrequest(request, extra); + } + /** * Access experimental features. * @@ -210,23 +248,26 @@ export class Server extends Protocol { const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise => { const validatedRequest = parseSchema(CallToolRequestSchema, request); if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + const errorMessage = validatedRequest.error.issues.map(i => i.message).join(', '); throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); } const { params } = validatedRequest.data; - - const result = await Promise.resolve(handler(request, ctx)); + let result: ResultTypeMap[M]; + try { + result = await Promise.resolve(handler(request, ctx)); + } catch (error) { + if (error instanceof ProtocolError) { + throw error; + } + throw error; + } // When task creation is requested, validate and return CreateTaskResult if (params.task) { const taskValidationResult = parseSchema(CreateTaskResultSchema, result); if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); + const errorMessage = taskValidationResult.error.issues.map(i => i.message).join(', '); throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); } return taskValidationResult.data; @@ -235,8 +276,7 @@ export class Server extends Protocol { // For non-task requests, validate against CallToolResultSchema const validationResult = parseSchema(CallToolResultSchema, result); if (!validationResult.success) { - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + const errorMessage = validationResult.error.issues.map(i => i.message).join(', '); throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); } @@ -506,11 +546,12 @@ export class Server extends Protocol { // These may appear even without tools/toolChoice in the current request when // a previous sampling request returned tool_use and this is a follow-up with results. if (params.messages.length > 0) { - const lastMessage = params.messages.at(-1)!; - const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; - const hasToolResults = lastContent.some(c => c.type === 'tool_result'); + const lastMessage = params.messages[params.messages.length - 1]; + if (lastMessage && lastMessage.content) { + const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; + const hasToolResults = lastContent.some(c => c.type === 'tool_result'); - const previousMessage = params.messages.length > 1 ? params.messages.at(-2) : undefined; + const previousMessage = params.messages.length > 1 ? params.messages[params.messages.length - 2] : undefined; const previousContent = previousMessage ? Array.isArray(previousMessage.content) ? previousMessage.content @@ -537,12 +578,13 @@ export class Server extends Protocol { const toolResultIds = new Set( lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId) ); - if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) { + if (toolUseIds.size !== toolResultIds.size || !Array.from(toolUseIds).every(id => toolResultIds.has(id))) { throw new ProtocolError( ProtocolErrorCode.InvalidParams, 'ids of tool_result blocks and tool_use blocks from previous message do not match' ); } + } } } @@ -561,54 +603,47 @@ export class Server extends Protocol { * @returns The result of the elicitation request. */ async elicitInput(params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions): Promise { - const mode = (params.mode ?? 'form') as 'form' | 'url'; - - switch (mode) { - case 'url': { - if (!this._clientCapabilities?.elicitation?.url) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support url elicitation.'); - } + if (params.mode === 'url') { + if (!this._clientCapabilities?.elicitation?.url) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support url elicitation.'); + } - const urlParams = params as ElicitRequestURLParams; - return this._requestWithSchema({ method: 'elicitation/create', params: urlParams }, ElicitResultSchema, options); + return this._requestWithSchema({ method: 'elicitation/create', params }, ElicitResultSchema, options); + } else { + if (!this._clientCapabilities?.elicitation?.form) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support form elicitation.'); } - case 'form': { - if (!this._clientCapabilities?.elicitation?.form) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support form elicitation.'); - } - const formParams: ElicitRequestFormParams = - params.mode === 'form' ? (params as ElicitRequestFormParams) : { ...(params as ElicitRequestFormParams), mode: 'form' }; + const formParams: ElicitRequestFormParams = params.mode === 'form' ? params : { ...params, mode: 'form' }; - const result = await this._requestWithSchema( - { method: 'elicitation/create', params: formParams }, - ElicitResultSchema, - options - ); + const result = await this._requestWithSchema( + { method: 'elicitation/create', params: formParams }, + ElicitResultSchema, + options + ); - if (result.action === 'accept' && result.content && formParams.requestedSchema) { - try { - const validator = this._jsonSchemaValidator.getValidator(formParams.requestedSchema as JsonSchemaType); - const validationResult = validator(result.content); - - if (!validationResult.valid) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Elicitation response content does not match requested schema: ${validationResult.errorMessage}` - ); - } - } catch (error) { - if (error instanceof ProtocolError) { - throw error; - } + if (result.action === 'accept' && result.content && formParams.requestedSchema) { + try { + const validator = this._jsonSchemaValidator.getValidator(formParams.requestedSchema as JsonSchemaType); + const validationResult = validator(result.content); + + if (!validationResult.valid) { throw new ProtocolError( - ProtocolErrorCode.InternalError, - `Error validating elicitation response: ${error instanceof Error ? error.message : String(error)}` + ProtocolErrorCode.InvalidParams, + `Elicitation response content does not match requested schema: ${validationResult.errorMessage}` ); } + } catch (error: unknown) { + if (error instanceof ProtocolError) { + throw error; + } + throw new ProtocolError( + ProtocolErrorCode.InternalError, + `Error validating elicitation response: ${error instanceof Error ? error.message : String(error)}` + ); } - return result; } + return result; } } diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index 74e689892..7d277723c 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -15,6 +15,8 @@ import { isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessageSchema, + ProtocolError, + ProtocolErrorCode, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; @@ -714,7 +716,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { if (this._enableJsonResponse) { // For JSON response mode, return a Promise that resolves when all responses are ready - return new Promise(resolve => { + return new Promise(async resolve => { this._streamMapping.set(streamId, { resolveJson: resolve, cleanup: () => { @@ -729,7 +731,21 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { } for (const message of messages) { - this.onmessage?.(message, { authInfo: options?.authInfo, requestInfo }); + try { + await Promise.resolve(this.onmessage?.(message, { authInfo: options?.authInfo, requestInfo })); + } catch (error) { + if (error instanceof ProtocolError) { + if (error.message.includes('Unauthorized')) { + resolve(this.createJsonErrorResponse(401, error.code, 'Unauthorized', { headers: { 'WWW-Authenticate': 'Bearer' } })); + return; + } + if (error.message.includes('Forbidden')) { + resolve(this.createJsonErrorResponse(403, error.code, error.message)); + return; + } + } + throw error; + } } }); } @@ -799,13 +815,48 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { }; } - this.onmessage?.(message, { authInfo: options?.authInfo, requestInfo, closeSSEStream, closeStandaloneSSEStream }); + try { + await Promise.resolve(this.onmessage?.(message, { + authInfo: options?.authInfo, + requestInfo, + closeSSEStream, + closeStandaloneSSEStream + })); + } catch (error) { + if (error instanceof ProtocolError) { + if (error.code === ProtocolErrorCode.Unauthorized) { + const response = this.createJsonErrorResponse(401, error.code, 'Unauthorized', { headers: { 'WWW-Authenticate': 'Bearer' } }); + this._streamMapping.delete(streamId); + return response; + } + if (error.code === ProtocolErrorCode.Forbidden) { + const response = this.createJsonErrorResponse(403, error.code, error.message); + this._streamMapping.delete(streamId); + return response; + } + if (error.code === ProtocolErrorCode.UrlElicitationRequired) { + throw error; + } + } + console.error('Transport caught error in onmessage:', error); + // Standard tools should return a CallToolResult with isError: true. + // For onmessage we only rethrow auth-related errors and UrlElicitationRequired. + throw error; + } } // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses // This will be handled by the send() method when responses are ready return new Response(readable, { status: 200, headers }); } catch (error) { + if (error instanceof ProtocolError) { + if (error.code === ProtocolErrorCode.Unauthorized) { + return this.createJsonErrorResponse(401, error.code, 'Unauthorized', { headers: { 'WWW-Authenticate': 'Bearer' } }); + } + if (error.code === ProtocolErrorCode.Forbidden) { + return this.createJsonErrorResponse(403, error.code, error.message); + } + } // return JSON-RPC formatted error this.onerror?.(error as Error); return this.createJsonErrorResponse(400, -32_700, 'Parse error', { data: String(error) }); @@ -888,8 +939,8 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { async close(): Promise { // Close all SSE connections - for (const { cleanup } of this._streamMapping.values()) { - cleanup(); + for (const mapping of Array.from(this._streamMapping.values())) { + mapping.cleanup(); } this._streamMapping.clear(); @@ -982,7 +1033,9 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { this._requestResponseMap.set(requestId, message); - const relatedIds = [...this._requestToStreamMapping.entries()].filter(([_, sid]) => sid === streamId).map(([id]) => id); + const relatedIds = Array.from(this._requestToStreamMapping.entries()) + .filter(([, sid]) => sid === streamId) + .map(([id]) => id); // Check if we have responses for all requests using this connection const allResponsesReady = relatedIds.every(id => this._requestResponseMap.has(id)); diff --git a/packages/server/test/server/auth.test.ts b/packages/server/test/server/auth.test.ts new file mode 100644 index 000000000..990c5968e --- /dev/null +++ b/packages/server/test/server/auth.test.ts @@ -0,0 +1,76 @@ +import { describe, it, expect } from "vitest"; +import { BearerTokenAuthenticator } from "../../src/server/auth/bearer.js"; +import { Authorizer } from "../../src/server/auth/authorizer.js"; + +describe("BearerTokenAuthenticator", () => { + it("should authenticate with a valid token", async () => { + const authenticator = new BearerTokenAuthenticator(async (token) => { + if (token === "valid-token") { + return { name: "test-user", token: "valid-token", clientId: "test-client", scopes: ["read"] }; + } + return undefined; + }); + + const authInfo = await authenticator.authenticate({ + requestId: 1, + method: "test", + headers: { + authorization: "Bearer valid-token", + }, + }); + expect(authInfo).toEqual({ name: "test-user", token: "valid-token", clientId: "test-client", scopes: ["read"] }); + }); + + it("should return undefined with an invalid token", async () => { + const authenticator = new BearerTokenAuthenticator(async (_) => undefined); + + const authInfo = await authenticator.authenticate({ + requestId: 1, + method: "test", + headers: { + authorization: "Bearer invalid-token", + }, + }); + expect(authInfo).toBeUndefined(); + }); + + it("should return undefined when Authorization header is missing", async () => { + const authenticator = new BearerTokenAuthenticator(async (_) => ({ + name: "test-user", + token: "test-token", + clientId: "test-client", + scopes: [], + })); + + const authInfo = await authenticator.authenticate({ + requestId: 1, + method: "test", + headers: {}, + }); + expect(authInfo).toBeUndefined(); + }); +}); + +describe("Authorizer", () => { + it("should authorize when no scopes are required", () => { + const authInfo = { name: "test-user", token: "test-token", clientId: "test-client", scopes: [] }; + expect(Authorizer.isAuthorized(authInfo, undefined)).toBe(true); + expect(Authorizer.isAuthorized(authInfo, [])).toBe(true); + }); + + it("should authorize when all required scopes are present", () => { + const authInfo = { name: "test-user", token: "test-token", clientId: "test-client", scopes: ["read", "write"] }; + expect(Authorizer.isAuthorized(authInfo, ["read"])).toBe(true); + expect(Authorizer.isAuthorized(authInfo, ["read", "write"])).toBe(true); + }); + + it("should not authorize when a required scope is missing", () => { + const authInfo = { name: "test-user", token: "test-token", clientId: "test-client", scopes: ["read"] }; + expect(Authorizer.isAuthorized(authInfo, ["write"])).toBe(false); + expect(Authorizer.isAuthorized(authInfo, ["read", "write"])).toBe(false); + }); + + it("should not authorize if authInfo is missing but scopes are required", () => { + expect(Authorizer.isAuthorized(undefined, ["read"])).toBe(false); + }); +}); diff --git a/packages/server/test/server/auth_integration.test.ts b/packages/server/test/server/auth_integration.test.ts new file mode 100644 index 000000000..1f99c8981 --- /dev/null +++ b/packages/server/test/server/auth_integration.test.ts @@ -0,0 +1,158 @@ +import { describe, it, expect, beforeEach } from "vitest"; +import { McpServer } from "../../src/index.js"; +import { WebStandardStreamableHTTPServerTransport } from "../../src/index.js"; +import { BearerTokenAuthenticator } from "../../src/index.js"; +import { randomUUID } from "node:crypto"; + +describe("Auth Integration", () => { + let server: McpServer; + let transport: WebStandardStreamableHTTPServerTransport; + let sessionId: string; + + const TEST_MESSAGES = { + initialize: { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-11-25", + capabilities: {}, + }, + id: "init-1", + }, + publicTool: { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "public", arguments: {} }, + id: "call-public", + }, + privateTool: { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "private", arguments: {} }, + id: "call-private", + }, + }; + + beforeEach(async () => { + const authenticator = new BearerTokenAuthenticator(async (token) => { + if (token === "admin-token") { + return { token, clientId: "admin", scopes: ["admin", "read"] }; + } + if (token === "user-token") { + return { token, clientId: "user", scopes: ["read"] }; + } + return undefined; + }); + + server = new McpServer( + { name: "test-auth-server", version: "1.0.0" }, + { authenticator } + ); + + server.registerTool("public", {}, async () => ({ content: [{ type: "text", text: "public" }] })); + server.registerTool("private", { scopes: ["admin"] }, async () => ({ content: [{ type: "text", text: "private" }] })); + + transport = new WebStandardStreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + }); + + await server.connect(transport); + }); + + async function initialize(): Promise { + const request = new Request("http://localhost/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + const response = await transport.handleRequest(request); + if (!response.ok) { + throw new Error(`Failed to initialize: ${response.status} ${await response.text()}`); + } + return response.headers.get("mcp-session-id")!; + } + + it("should return 401 for requests without a token", async () => { + sessionId = await initialize(); + const request = new Request("http://localhost/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "mcp-session-id": sessionId, + }, + body: JSON.stringify(TEST_MESSAGES.publicTool), + }); + const response = await transport.handleRequest(request); + expect(response.status).toBe(401); + expect(response.headers.get("WWW-Authenticate")).toBe("Bearer"); + }); + + it("should return 401 for requests with an invalid token", async () => { + sessionId = await initialize(); + const request = new Request("http://localhost/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "mcp-session-id": sessionId, + "Authorization": "Bearer invalid", + }, + body: JSON.stringify(TEST_MESSAGES.publicTool), + }); + const response = await transport.handleRequest(request); + expect(response.status).toBe(401); + }); + + it("should allow access to public tools with valid user token", async () => { + sessionId = await initialize(); + const request = new Request("http://localhost/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "mcp-session-id": sessionId, + "Authorization": "Bearer user-token", + }, + body: JSON.stringify(TEST_MESSAGES.publicTool), + }); + const response = await transport.handleRequest(request); + expect(response.status).toBe(200); + }); + + it("should return 403 for private tools with insufficient scopes", async () => { + sessionId = await initialize(); + const request = new Request("http://localhost/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "mcp-session-id": sessionId, + "Authorization": "Bearer user-token", + }, + body: JSON.stringify(TEST_MESSAGES.privateTool), + }); + const response = await transport.handleRequest(request); + expect(response.status).toBe(403); + }); + + it("should allow access to private tools with sufficient scopes", async () => { + sessionId = await initialize(); + const request = new Request("http://localhost/mcp", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "mcp-session-id": sessionId, + "Authorization": "Bearer admin-token", + }, + body: JSON.stringify(TEST_MESSAGES.privateTool), + }); + const response = await transport.handleRequest(request); + expect(response.status).toBe(200); + }); +});