From 68e4eb286c48caf5f0d81cae5bc88a163a51e92a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 14:17:26 +0000 Subject: [PATCH 01/14] Tighten requirements-catalog divergence notes to match current SDK behaviour - resources:annotations: drop the stale lastModified divergence; the model now carries the field and it round-trips. - lifecycle:capability:client-not-declared: narrow to the one remaining gap (the deprecated send_roots_list_changed path); the handler half is correct by construction. - lifecycle:pre-initialization-ordering: mark removed_in=2026-07-28; the initialize handshake is gone in the new spec. - client-transport:http:session-404-reinitialize: reword as an intentional cross-SDK choice (404 surfaced to caller), not a missed MUST. - test_resources: drop the now-redundant lastModified xfail. --- tests/interaction/_requirements.py | 23 ++++++++++---------- tests/interaction/lowlevel/test_resources.py | 6 ++--- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index d376f0b9f..bf7f8ceec 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -178,13 +178,15 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "The client does not check its own declared capabilities before sending notifications or " - "serving callbacks; nothing prevents a caller from violating the spec's MUST." + "The handler half is correct by construction -- the client derives its declared " + "capabilities from the callbacks registered at construction, so it cannot serve a " + "capability it did not declare. Only the deprecated send_roots_list_changed notification " + "is ungated: a caller can send it without having registered a roots callback." ), ), deferred=( - "Not implemented in the SDK: the client does not check its own declared capabilities before " - "sending notifications or serving callbacks." + "Not implemented in the SDK: the deprecated send_roots_list_changed notification is not " + "gated on a declared roots capability." ), ), "lifecycle:capability:server-not-advertised": Requirement( @@ -300,6 +302,8 @@ def __post_init__(self) -> None: "Before initialization completes, the client sends no requests other than pings, and the " "server sends no requests other than pings and logging." ), + removed_in="2026-07-28", + note="initialize handshake removed at 2026-07-28; per-request _meta envelope replaces it.", divergence=Divergence( note=( "The server's send methods (create_message / elicit_form / list_roots) do not check " @@ -1071,12 +1075,6 @@ def __post_init__(self) -> None: "resources:annotations": Requirement( source=f"{SPEC_BASE_URL}/server/resources#annotations", behavior="Resource annotations supplied by the server round-trip to the client in the list result.", - divergence=Divergence( - note=( - "The SDK Annotations model is missing the schema's lastModified field; MCPModel uses the " - "pydantic default extra='ignore', so the value is silently dropped on parse." - ), - ), ), "resources:capability:declared": Requirement( source=f"{SPEC_BASE_URL}/server/resources#capabilities", @@ -3163,8 +3161,9 @@ def __post_init__(self) -> None: transports=("streamable-http",), divergence=Divergence( note=( - "The client surfaces the 404 as an error to the caller instead of re-initializing a new " - "session; the spec's MUST is not satisfied." + "The 404 is intentionally surfaced to the caller; this matches the TypeScript, C#, and " + "Go SDKs. The 2025-11-25 MUST was removed in the 2026 spec (SEP-2567), and the transport " + "layer cannot safely re-initialize without replaying the caller's request." ), ), deferred=( diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 44ab33e64..a94278940 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -39,10 +39,8 @@ async def test_list_resources_returns_registered_resources(connect: Connect) -> None: """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. - The fully-populated entry includes annotations, so the snapshot also proves they round-trip. - The SDK's Annotations model omits the schema's lastModified field (see the divergence on - resources:annotations); the input is built via model_validate with lastModified set so the - snapshot pins the drop and will fail once the SDK adds the field. + The fully-populated entry includes annotations (audience, priority, last_modified), so the + snapshot also proves they round-trip. """ async def list_resources( From bad42e284177e920fabb909e469e83cd335dc325 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 14:18:11 +0000 Subject: [PATCH 02/14] Emit RFC 6750 scope= in WWW-Authenticate and validate token audience BearerAuthBackend / RequireAuthMiddleware now produce spec-conformant challenges and reject tokens issued for a different resource server. - A request with no credentials gets a bare `Bearer` challenge (with scope/resource_metadata only), not error="invalid_token" -- RFC 6750 Section 3.1 says the error attribute SHOULD NOT appear when no authentication information was presented. - A malformed/unknown token, an expired token, or a token whose audience does not match the configured resource_server_url is answered 401 invalid_token with a specific error_description, carried via a new InvalidTokenUser marker so the middleware can distinguish it from no-credentials. - All challenges (401 and the 403 insufficient_scope path) now advertise the required scopes in a `scope=` parameter, which the SDK client already reads to drive step-up. - New check_token_audience() helper canonicalises default ports before comparing, and is wired through both the lowlevel and MCPServer Starlette stacks via the auth settings' resource_server_url. Docs and migration guide updated; the corresponding interaction-suite divergence entries are now closed. --- docs/advanced/authorization.md | 9 +- docs/migration.md | 4 + src/mcp/server/auth/middleware/bearer_auth.py | 99 ++++++++++----- src/mcp/server/lowlevel/server.py | 2 +- src/mcp/server/mcpserver/server.py | 4 +- src/mcp/shared/auth_utils.py | 28 ++++- tests/interaction/_requirements.py | 28 ----- tests/interaction/auth/_harness.py | 13 +- .../interaction/auth/test_authorize_token.py | 12 +- tests/interaction/auth/test_bearer.py | 117 ++++++++++++------ .../auth/middleware/test_bearer_auth.py | 77 +++++++++--- tests/shared/test_auth_utils.py | 19 ++- 12 files changed, 277 insertions(+), 135 deletions(-) diff --git a/docs/advanced/authorization.md b/docs/advanced/authorization.md index 2afb3d5a0..87ecc17b8 100644 --- a/docs/advanced/authorization.md +++ b/docs/advanced/authorization.md @@ -27,7 +27,7 @@ The SDK has no opinion about what a valid token looks like. You tell it, by impl `AuthSettings` is the public face of your resource server: * `issuer_url`: the authorization server that issues your tokens. -* `resource_server_url`: the public URL of this MCP endpoint. It names *which* resource a token is for, and it's where the discovery document lives. +* `resource_server_url`: the public URL of this MCP endpoint. It names *which* resource a token is for, and it's where the discovery document lives. When your verifier returns an `AccessToken.resource`, the SDK rejects the token unless it matches this URL — a token issued for a different resource never reaches a tool. * `required_scopes`: every token must carry all of them. !!! tip @@ -61,14 +61,13 @@ You registered one tool. The second route is the SDK's. This document is how a client that has never heard of your server finds its way in: it reads `authorization_servers` and goes there for a token. You wrote none of it. !!! check - Call `/mcp` with no token (or with one your verifier returned `None` for) and the request is - stopped at the door: + Call `/mcp` with no token and the request is stopped at the door: ```text HTTP/1.1 401 Unauthorized - WWW-Authenticate: Bearer error="invalid_token", error_description="Authentication required", resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" + WWW-Authenticate: Bearer scope="notes:read", resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" - {"error": "invalid_token", "error_description": "Authentication required"} + {} ``` Nothing was parsed and no tool ran. And that `resource_metadata` pointer in `WWW-Authenticate` is diff --git a/docs/migration.md b/docs/migration.md index 42d420bf0..cab83f9d9 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1462,6 +1462,10 @@ issuer inconsistent with what clients compare against under RFC 8414 / RFC 9207. already-built `AnyHttpUrl` object still normalizes at construction; pass a string to get the preserved form. +### Bearer tokens with a mismatched audience are rejected + +`BearerAuthBackend` now compares `AccessToken.resource` against `AuthSettings.resource_server_url` and answers a token whose RFC 8707 resource indicator does not name this server with `401 invalid_token`. The check is canonical-URI equality, so a token issued for `https://host/` is not accepted by a server at `https://host/mcp`. It is skipped when either side is `None` — populate `AccessToken.resource` only when your verifier surfaces the underlying audience claim. `BearerAuthBackend.__init__` gains a keyword-only `resource_server_url: AnyHttpUrl | None = None`, wired automatically from `AuthSettings`; pass it only if you construct the backend directly. + ### Lowlevel `Server`: `subscribe` capability now correctly reported Previously, the lowlevel `Server` hardcoded `subscribe=False` in resource capabilities even when a `subscribe_resource()` handler was registered. The `subscribe` capability is now dynamically set to `True` when an `on_subscribe_resource` handler is provided. Clients that previously didn't see `subscribe: true` in capabilities will now see it when a handler is registered, which may change client behavior. diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index ba66e9422..d80e13f15 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -3,11 +3,12 @@ from typing import Any, TypedDict from pydantic import AnyHttpUrl -from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser +from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser, SimpleUser from starlette.requests import HTTPConnection from starlette.types import Receive, Scope, Send from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.shared.auth_utils import check_token_audience class AuthenticatedUser(SimpleUser): @@ -19,6 +20,27 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class InvalidTokenUser(BaseUser): + """Marker for a request that presented a Bearer token the verifier rejected, + that has expired, or whose audience does not match this resource server. + Carries the human-readable reason for the WWW-Authenticate error_description.""" + + def __init__(self, reason: str) -> None: + self.reason = reason + + @property + def is_authenticated(self) -> bool: + return False + + @property + def display_name(self) -> str: + return "" + + @property + def identity(self) -> str: + return "" + + class AuthorizationContext(TypedDict): client_id: str issuer: str | None @@ -46,27 +68,30 @@ def authorization_context(user: AuthenticatedUser) -> AuthorizationContext: class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" - def __init__(self, token_verifier: TokenVerifier): + def __init__(self, token_verifier: TokenVerifier, *, resource_server_url: AnyHttpUrl | None = None) -> None: self.token_verifier = token_verifier + self.resource_server_url = resource_server_url - async def authenticate(self, conn: HTTPConnection): + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: auth_header = next( (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), None, ) if not auth_header or not auth_header.lower().startswith("bearer "): - return None - - token = auth_header[7:] # Remove "Bearer " prefix + return None # no credentials presented → bare challenge per RFC 6750 §3 - # Validate the token with the verifier + token = auth_header[7:] auth_info = await self.token_verifier.verify_token(token) - - if not auth_info: - return None - - if auth_info.expires_at and auth_info.expires_at < int(time.time()): - return None + if auth_info is None: + return AuthCredentials(), InvalidTokenUser("The access token is malformed or unknown") + if auth_info.expires_at is not None and auth_info.expires_at < int(time.time()): + return AuthCredentials(), InvalidTokenUser("The access token has expired") + if ( + self.resource_server_url is not None + and auth_info.resource is not None + and not check_token_audience(auth_info.resource, self.resource_server_url) + ): + return AuthCredentials(), InvalidTokenUser("The access token was issued for a different resource") return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) @@ -97,35 +122,47 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: auth_user = scope.get("user") + if isinstance(auth_user, InvalidTokenUser): + await self._send_auth_error(send, status_code=401, error="invalid_token", description=auth_user.reason) + return if not isinstance(auth_user, AuthenticatedUser): - await self._send_auth_error( - send, status_code=401, error="invalid_token", description="Authentication required" - ) + await self._send_auth_error(send, status_code=401) return - auth_credentials = scope.get("auth") - + auth_credentials = scope["auth"] for required_scope in self.required_scopes: - # auth_credentials should always be provided; this is just paranoia - if auth_credentials is None or required_scope not in auth_credentials.scopes: + if required_scope not in auth_credentials.scopes: await self._send_auth_error( - send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}" + send, + status_code=403, + error="insufficient_scope", + description="The access token lacks a required scope", ) return await self.app(scope, receive, send) - async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None: - """Send an authentication error response with WWW-Authenticate header.""" - # Build WWW-Authenticate header value - www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] + async def _send_auth_error( + self, send: Send, *, status_code: int, error: str | None = None, description: str | None = None + ) -> None: + """Send a Bearer challenge. RFC 6750 §3: error/error_description only when a token + was presented; scope advertises what is required; resource_metadata for discovery.""" + parts: list[str] = [] + if error is not None: + parts.append(f'error="{error}"') + if description is not None: + parts.append(f'error_description="{description}"') + if self.required_scopes: + parts.append(f'scope="{" ".join(self.required_scopes)}"') if self.resource_metadata_url: - www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') - - www_authenticate = f"Bearer {', '.join(www_auth_parts)}" - - # Send response - body = {"error": error, "error_description": description} + parts.append(f'resource_metadata="{self.resource_metadata_url}"') + www_authenticate = f"Bearer {', '.join(parts)}" if parts else "Bearer" + + body: dict[str, str] = {} + if error is not None: + body["error"] = error + if description is not None: + body["error_description"] = description body_bytes = json.dumps(body).encode() await send( diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index c10ff82f3..30173238a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -710,7 +710,7 @@ def streamable_http_app( middleware = [ Middleware( AuthenticationMiddleware, - backend=BearerAuthBackend(token_verifier), + backend=BearerAuthBackend(token_verifier, resource_server_url=auth.resource_server_url), ), Middleware(AuthContextMiddleware), ] diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 029512a78..511c19c8b 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -1000,7 +1000,9 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no # extract auth info from request (but do not require it) Middleware( AuthenticationMiddleware, - backend=BearerAuthBackend(self._token_verifier), + backend=BearerAuthBackend( + self._token_verifier, resource_server_url=self.settings.auth.resource_server_url + ), ), # Add the auth context middleware to store # authenticated user in a contextvar diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 3ba880f40..d1e3b3c82 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -5,12 +5,15 @@ from pydantic import AnyUrl, HttpUrl +_DEFAULT_PORTS = {"http": 80, "https": 443} + def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: """Convert server URL to canonical resource URL per RFC 8707. RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". - Returns absolute URI with lowercase scheme/host for canonical form. + Returns absolute URI with lowercase scheme/host and the scheme's default port + elided (RFC 3986 §6.2.3) for canonical form. Args: url: Server URL to convert @@ -23,9 +26,13 @@ def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: # Parse the URL and remove fragment, create canonical form parsed = urlsplit(url_str) - canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment="")) - - return canonical + scheme = parsed.scheme.lower() + netloc = parsed.netloc.lower() + # RFC 3986 §6.2.3: an explicit default port is equivalent to omitting it. + if parsed.port is not None and _DEFAULT_PORTS.get(scheme) == parsed.port: + userinfo, sep, hostport = netloc.rpartition("@") + netloc = f"{userinfo}{sep}{hostport.rsplit(':', 1)[0]}" + return urlunsplit(parsed._replace(scheme=scheme, netloc=netloc, fragment="")) def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool: @@ -65,6 +72,19 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> return requested_path.startswith(configured_path) +def check_token_audience(token_resource: str, server_resource: str | HttpUrl | AnyUrl) -> bool: + """Return True iff a token's RFC 8707 resource indicator identifies this server. + + Server-side audience validation is canonical-URI equality (authorization.mdx + Token Audience Binding): a token for a parent or sibling path on the same + origin is NOT for this server. Contrast check_resource_allowed, which is the + client-side hierarchical question and intentionally more permissive. + """ + return resource_url_from_server_url(token_resource).rstrip("/") == resource_url_from_server_url( + server_resource + ).rstrip("/") + + def calculate_token_expiry(expires_in: int | str | None) -> float | None: """Calculate token expiry timestamp from expires_in seconds. diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index bf7f8ceec..840a94f2e 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2622,12 +2622,6 @@ def __post_init__(self) -> None: behavior="The resource server validates that the token audience matches its resource identifier.", transports=("streamable-http",), note="Auth is enforced at the HTTP layer.", - divergence=Divergence( - note=( - "BearerAuthBackend never inspects AccessToken.resource; a token issued for a different " - "resource is accepted. Spec MUST." - ), - ), ), "hosting:auth:authinfo-propagates": Requirement( source="sdk", @@ -2640,18 +2634,12 @@ def __post_init__(self) -> None: behavior="An expired token returns 401 invalid_token.", transports=("streamable-http",), note="Auth is enforced at the HTTP layer; 401 is an HTTP status code.", - divergence=Divergence( - note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", - ), ), "hosting:auth:invalid-401": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", transports=("streamable-http",), note="Auth is enforced at the HTTP layer; 401 is an HTTP status code.", - divergence=Divergence( - note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", - ), ), "hosting:auth:metadata-endpoints": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", @@ -2671,15 +2659,6 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), note="Auth is enforced at the HTTP layer; 401 is an HTTP status code.", - divergence=Divergence( - note=( - "The SDK never emits a `scope` parameter in any WWW-Authenticate challenge — neither the " - "discovery-time 401 (#protected-resource-metadata-discovery-requirements SHOULD) nor the " - "runtime 403 (#runtime-insufficient-scope-errors SHOULD); and for the no-credentials case " - 'it emits error="invalid_token", which RFC 6750 Section 3.1 says SHOULD NOT appear when no ' - "authentication information was presented." - ), - ), ), "hosting:auth:prm:authorization-servers-field": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", @@ -2706,13 +2685,6 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), note="Auth is enforced at the HTTP layer; 403 is an HTTP status code.", - divergence=Divergence( - note=( - 'The SDK emits error="insufficient_scope" and error_description but never the `scope` ' - "parameter the spec SHOULD include; the SDK client reads `scope` from this header to drive " - "step-up (utils.py extract_scope_from_www_auth) — a resource-server/client asymmetry." - ), - ), ), "hosting:auth:as:authorize-requires-pkce": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index 4fd1110c9..db55049b0 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -277,9 +277,9 @@ def shim( class _FirstChallenge: """ASGI shim that answers the first request to a path with 401 + a given WWW-Authenticate. - Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry - parameters (such as `scope=`) that the SDK's own bearer middleware cannot be configured - to emit, so client behaviour driven by those parameters is reachable end to end. Reserve + Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry a + `scope=` value that differs from the gate's `required_scopes` (which is all the real + middleware can emit), so client scope-selection priority is reachable end to end. Reserve this pattern for behaviour the real server cannot be made to produce. """ @@ -312,9 +312,10 @@ def step_up_shim(www_authenticate: str, *, on_nth_authenticated_post: int = 2) - """Build an `app_shim` that 403s the Nth authenticated POST to `/mcp` with the given challenge. Subsequent requests pass through. Used to drive the client's `insufficient_scope` step-up - handling: the SDK's bearer middleware never emits `scope=` in its 403 challenge (see the - divergence on `hosting:auth:scope-403`), so the test supplies the 403 itself. Reserve this - pattern for behaviour the real server cannot be made to produce. + handling: the real middleware's 403 carries `scope=` from the gate's static + `required_scopes`, but step-up tests need a wider scope than the gate would emit so the + client's scope-union logic has something to add. Reserve this pattern for behaviour the + real server cannot be made to produce. The default `on_nth_authenticated_post=2` targets the `notifications/initialized` POST: the first authenticated POST is the auth flow's retry of the original initialize request (yielded diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py index d4eb591b5..08c4c9142 100644 --- a/tests/interaction/auth/test_authorize_token.py +++ b/tests/interaction/auth/test_authorize_token.py @@ -328,12 +328,12 @@ async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_adve async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_metadata() -> None: """When the 401 challenge carries `scope=`, that value is requested instead of the PRM scopes. - The SDK's bearer middleware never emits `scope=` in WWW-Authenticate (see the divergence - on `hosting:auth:scope-403`), so the test supplies the first 401 itself via - `first_challenge_shim` and disables token verification so the post-auth retry succeeds - regardless of the granted scope. PRM advertises `["from-prm"]` (it mirrors - `required_scopes`); the challenge says `from-header`; the authorize URL must carry - `from-header`. + The SDK's bearer middleware emits `scope=` with the configured `required_scopes`, which + here would be `from-prm` — the same value PRM advertises — so the test supplies the first + 401 itself via `first_challenge_shim` to put a *different* value in the challenge, and + disables token verification so the post-auth retry succeeds regardless of the granted + scope. PRM advertises `["from-prm"]`; the challenge says `from-header`; the authorize URL + must carry `from-header`. """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py index 55029c9f4..120690df3 100644 --- a/tests/interaction/auth/test_bearer.py +++ b/tests/interaction/auth/test_bearer.py @@ -3,8 +3,8 @@ These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, -the `WWW-Authenticate` header structure, and that a wrong-audience token reaches the MCP -endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. +the `WWW-Authenticate` header structure, and that a token with no audience claim reaches the +MCP endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. """ import time @@ -40,6 +40,14 @@ expires_at=_FUTURE, resource="https://other.example/mcp", ), + "tok-parent-aud": AccessToken( + token="tok-parent-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="http://127.0.0.1:8000/", + ), + "tok-no-aud": AccessToken(token="tok-no-aud", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), } @@ -80,41 +88,37 @@ async def test_a_request_with_no_authorization_header_is_challenged_with_resourc ) -> None: """No `Authorization` header → 401 with a `WWW-Authenticate` carrying `resource_metadata`. - The snapshot pins current behaviour: the SDK collapses the no-header, unknown-token, and - expired-token cases into one challenge (`error="invalid_token"`, no `scope` parameter). The - spec says the discovery-time challenge SHOULD include `scope` and RFC 6750 says the - no-credentials case SHOULD NOT carry an error code; both gaps are recorded as the divergence - on this requirement. Asserting the dict equals an exact key set also pins that no parameter - appears twice. + RFC 6750 §3: a no-credentials challenge carries no error code. The snapshot pins the + full header (parameter order included); asserting the dict equals an exact key set also + pins that no parameter appears twice. """ response = await post_mcp(protected) assert response.status_code == 401 assert response.headers["www-authenticate"] == snapshot( - 'Bearer error="invalid_token", error_description="Authentication required", ' - 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + 'Bearer scope="mcp:read", resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' ) assert parse_www_authenticate(response.headers["www-authenticate"]) == { - "error": "invalid_token", - "error_description": "Authentication required", + "scope": REQUIRED_SCOPE, "resource_metadata": RESOURCE_METADATA_URL, } - assert response.json() == snapshot({"error": "invalid_token", "error_description": "Authentication required"}) + assert response.json() == snapshot({}) @requirement("hosting:auth:invalid-401") async def test_an_unrecognized_bearer_token_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: """A token the verifier does not recognize is answered 401 `invalid_token`. - The challenge is identical to the no-header case (the backend returns `None` for both); the - missing `scope` parameter is the recorded divergence on this requirement. + The challenge is distinct from the no-header case: a bearer token was presented, so RFC + 6750 §3.1's `error` and `error_description` apply. """ response = await post_mcp(protected, bearer="tok-unknown") assert response.status_code == 401 assert parse_www_authenticate(response.headers["www-authenticate"]) == { "error": "invalid_token", - "error_description": "Authentication required", + "error_description": "The access token is malformed or unknown", + "scope": REQUIRED_SCOPE, "resource_metadata": RESOURCE_METADATA_URL, } @@ -124,48 +128,89 @@ async def test_an_expired_token_is_answered_401(protected: httpx.AsyncClient) -> """A token whose `expires_at` is in the past is answered 401 `invalid_token`. The expiry check is the bearer backend's, against the wall clock; the test seeds a concrete - past timestamp so no time mocking is involved. The missing `scope` parameter is the recorded - divergence on this requirement. + past timestamp so no time mocking is involved. """ response = await post_mcp(protected, bearer="tok-expired") assert response.status_code == 401 - assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "The access token has expired", + "scope": REQUIRED_SCOPE, + "resource_metadata": RESOURCE_METADATA_URL, + } @requirement("hosting:auth:scope-403") -async def test_a_token_missing_a_required_scope_is_answered_403_insufficient_scope_without_a_scope_param( +async def test_a_token_missing_a_required_scope_is_answered_403_with_the_required_scope_in_the_challenge( protected: httpx.AsyncClient, ) -> None: - """A token lacking the required scope is answered 403 `insufficient_scope`, with no `scope` parameter. + """A token lacking the required scope is answered 403 `insufficient_scope` with `scope=` naming what's needed. - The spec's runtime-insufficient-scope guidance says the challenge SHOULD include `scope` - naming the required scope; the SDK never emits it, recorded as the divergence on this - requirement. The SDK client reads `scope` from this header to drive step-up, so the gap is - a resource-server/client asymmetry. + The SDK client reads `scope` from this header to drive step-up, so the parameter is the + contract between resource server and client. """ response = await post_mcp(protected, bearer="tok-noscope") assert response.status_code == 403 - parsed = parse_www_authenticate(response.headers["www-authenticate"]) - assert parsed == { + assert parse_www_authenticate(response.headers["www-authenticate"]) == { "error": "insufficient_scope", - "error_description": f"Required scope: {REQUIRED_SCOPE}", + "error_description": "The access token lacks a required scope", + "scope": REQUIRED_SCOPE, "resource_metadata": RESOURCE_METADATA_URL, } - assert "scope" not in parsed @requirement("hosting:auth:aud-validation") -async def test_a_token_with_a_mismatched_audience_is_accepted(protected: httpx.AsyncClient) -> None: - """A token whose `resource` does not match the server's resource identifier is accepted. +async def test_a_token_with_a_mismatched_audience_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: + """A token whose `resource` does not match the server's resource identifier is answered 401. - The spec mandates the resource server validate the token's audience; the bearer backend - never inspects `AccessToken.resource`, so the request passes the gate and the MCP endpoint - serves it. This pins current behaviour with the divergence recorded on the requirement. + Spec-mandated: the resource server MUST validate the token's audience and reject tokens + not issued specifically for it. """ response = await post_mcp(protected, bearer="tok-wrong-aud") + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "The access token was issued for a different resource", + "scope": REQUIRED_SCOPE, + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_for_a_parent_path_on_the_same_origin_is_answered_401_invalid_token( + protected: httpx.AsyncClient, +) -> None: + """A token whose audience is the same origin but a parent path is answered 401. + + This is the discriminating case for canonical-URI equality: under hierarchical prefix + semantics a token for `http://host/` would be accepted by a server at `http://host/mcp`; + under audience binding it must be rejected. The cross-origin case above cannot catch a + regression to prefix semantics. + """ + response = await post_mcp(protected, bearer="tok-parent-aud") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "The access token was issued for a different resource", + "scope": REQUIRED_SCOPE, + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_without_a_resource_claim_passes_the_audience_check(protected: httpx.AsyncClient) -> None: + """A token whose `AccessToken.resource` is unset passes the audience check. + + SDK-defined pass-through: the SDK cannot distinguish a verifier that performed its own + audience check and chose not to surface the claim from a token that genuinely carries + none, so `resource is None` is accepted. This pins that policy. + """ + response = await post_mcp(protected, bearer="tok-no-aud") + assert response.status_code == 200 assert response.headers["content-type"].startswith("text/event-stream") # The body is finite SSE: a result event followed by stream close. Pull the JSON-RPC response @@ -186,4 +231,6 @@ async def test_an_access_token_in_the_query_string_is_not_accepted(protected: ht response = await post_mcp(protected, query={"access_token": "tok-valid"}) assert response.status_code == 401 - assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + parsed = parse_www_authenticate(response.headers["www-authenticate"]) + assert "error" not in parsed + assert parsed["scope"] == REQUIRED_SCOPE diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index bd14e294c..b76e0f380 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -6,10 +6,15 @@ import pytest from starlette.authentication import AuthCredentials from starlette.datastructures import Headers -from starlette.requests import Request +from starlette.requests import Request, empty_receive from starlette.types import Message, Receive, Scope, Send -from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.middleware.bearer_auth import ( + AuthenticatedUser, + BearerAuthBackend, + InvalidTokenUser, + RequireAuthMiddleware, +) from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, ProviderTokenVerifier @@ -126,7 +131,9 @@ async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizat assert result is None async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): - """Test authentication with invalid token.""" + """A Bearer token the verifier rejects yields an InvalidTokenUser carrying the + reason, so RequireAuthMiddleware can send error="invalid_token" rather than a + bare challenge (RFC 6750 §3.1 distinguishes no-credentials from bad-credentials).""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( { @@ -135,14 +142,24 @@ async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServer } ) result = await backend.authenticate(request) - assert result is None + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert credentials.scopes == [] + assert isinstance(user, InvalidTokenUser) + assert user.reason == "The access token is malformed or unknown" + # BaseUser interface obligations — Starlette may render these + assert user.is_authenticated is False + assert user.display_name == "" + assert user.identity == "" async def test_expired_token( self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], expired_access_token: AccessToken, ): - """Test authentication with expired token.""" + """An expired token yields an InvalidTokenUser whose reason names expiry, so the + WWW-Authenticate error_description tells the client why (RFC 6750 §3.1).""" backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) request = Request( @@ -152,7 +169,10 @@ async def test_expired_token( } ) result = await backend.authenticate(request) - assert result is None + assert result is not None + _, user = result + assert isinstance(user, InvalidTokenUser) + assert user.reason == "The access token has expired" async def test_valid_token( self, @@ -344,17 +364,18 @@ async def send(message: Message) -> None: assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called - async def test_no_auth_credentials(self, valid_access_token: AccessToken): - """Test middleware with no auth credentials in scope.""" + async def test_invalid_token_user_gets_401_with_error_description(self): + """When the backend marked the request with InvalidTokenUser, the middleware + sends 401 with error="invalid_token" and the carried reason as error_description + (RFC 6750 §3.1) — distinct from the bare challenge sent when no token was presented.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + scope: Scope = { + "type": "http", + "user": InvalidTokenUser("The access token has expired"), + "auth": AuthCredentials(), + } - # Create a user with read/write scopes - user = AuthenticatedUser(valid_access_token) - - scope: Scope = {"type": "http", "user": user} # No auth credentials - - # Create dummy async functions for receive and send async def receive() -> Message: # pragma: no cover return {"type": "http.request"} @@ -365,11 +386,12 @@ async def send(message: Message) -> None: await middleware(scope, receive, send) - # Check that a 403 response was sent assert len(sent_messages) == 2 assert sent_messages[0]["type"] == "http.response.start" - assert sent_messages[0]["status"] == 403 - assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) + assert sent_messages[0]["status"] == 401 + www_authenticate = dict(sent_messages[0]["headers"])[b"www-authenticate"] + assert b'error="invalid_token"' in www_authenticate + assert b'error_description="The access token has expired"' in www_authenticate assert not app.called async def test_has_required_scopes(self, valid_access_token: AccessToken): @@ -446,3 +468,24 @@ async def send(message: Message) -> None: # pragma: no cover assert app.scope == scope assert app.receive == receive assert app.send == send + + +@pytest.mark.anyio +async def test_unauthenticated_request_with_no_required_scopes_gets_bare_bearer_challenge(): + """RFC 6750 §3: when no credentials were presented and the server has nothing to + advertise (no required scopes, no resource_metadata), the WWW-Authenticate header + is the bare scheme name with no parameters.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=[]) + scope: Scope = {"type": "http"} + + sent_messages: list[Message] = [] + + async def send(message: Message) -> None: + sent_messages.append(message) + + await middleware(scope, empty_receive, send) + + assert sent_messages[0]["status"] == 401 + assert dict(sent_messages[0]["headers"])[b"www-authenticate"] == b"Bearer" + assert not app.called diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 5ae0e22b0..23f1e0620 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -2,7 +2,7 @@ from pydantic import HttpUrl -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url +from mcp.shared.auth_utils import check_resource_allowed, check_token_audience, resource_url_from_server_url # Tests for resource_url_from_server_url function @@ -34,6 +34,23 @@ def test_resource_url_from_server_url_preserves_port(): assert resource_url_from_server_url("http://example.com:8080/") == "http://example.com:8080/" +def test_resource_url_from_server_url_strips_default_port(): + """An explicit default port is equivalent to omitting it (RFC 3986 §6.2.3).""" + assert resource_url_from_server_url("https://example.com:443/mcp") == "https://example.com/mcp" + assert resource_url_from_server_url("http://example.com:80/mcp") == "http://example.com/mcp" + # Only the scheme's own default is stripped — :80 on https is significant. + assert resource_url_from_server_url("https://example.com:80/mcp") == "https://example.com:80/mcp" + # IPv6 brackets survive the rewrite. + assert resource_url_from_server_url("https://[::1]:443/mcp") == "https://[::1]/mcp" + + +def test_check_token_audience_ignores_default_port(): + """A token issued for `https://h:443/mcp` is for the server at `https://h/mcp`.""" + assert check_token_audience("https://h:443/mcp", "https://h/mcp") is True + assert check_token_audience("https://h/mcp", "https://h:443/mcp") is True + assert check_token_audience("https://h:8443/mcp", "https://h/mcp") is False + + def test_resource_url_from_server_url_lowercase_scheme_and_host(): """Scheme and host should be lowercase for canonical form.""" assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" From 47639a2b36326afdb62e0863942ef757cdbe3cb8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:15:41 +0000 Subject: [PATCH 03/14] Stop replying to cancelled requests; map unhandled handler exceptions to -32603 A request cancelled via notifications/cancelled now gets no response: when the handler scope's cancel is caught, JSONRPCDispatcher returns instead of writing an error. The sender retired its own waiter when it cancelled, so no reply is needed to unblock it. An unhandled exception in a request handler now produces JSON-RPC error -32603 (INTERNAL_ERROR) with the opaque message "Internal server error" instead of code 0 carrying str(exc). The exception is still logged server-side. To send a specific code/message, raise MCPError; pydantic ValidationError still maps to INVALID_PARAMS. handler_exception_to_error_data is now total: it returns MappedError(error: ErrorData, unexpected: bool) for every Exception. Callers gate logger.exception / raise_handler_exceptions on the unexpected flag rather than re-deriving the rung set, so a handler that deliberately raises MCPError(code=INTERNAL_ERROR) is not treated as a crash. JSONRPCDispatcher, DirectDispatcher, and the modern HTTP entry's _to_jsonrpc_response all call the one helper; DirectDispatcher no longer hand-rolls its own ladder. The interaction suite's protocol:cancel:in-flight and protocol:error:internal-error requirements drop their divergence entries and modern-error-surface arm exclusions; the two prompt-validation divergence notes are reworded for the new -32603 surface. docs/migration.md gains a section covering both behaviour changes. --- docs/advanced/authorization.md | 2 +- docs/migration.md | 19 +- src/mcp/server/_streamable_http_modern.py | 11 +- src/mcp/shared/direct_dispatcher.py | 24 +-- src/mcp/shared/jsonrpc_dispatcher.py | 64 ++++--- tests/client/test_session.py | 9 +- tests/docs_src/test_authorization.py | 16 +- tests/docs_src/test_deprecated.py | 3 +- tests/interaction/_requirements.py | 32 +--- .../interaction/lowlevel/test_cancellation.py | 65 +++---- tests/interaction/lowlevel/test_tools.py | 9 +- tests/interaction/mcpserver/test_prompts.py | 24 +-- .../transports/test_hosting_http_modern.py | 5 +- tests/server/mcpserver/test_server.py | 6 +- tests/server/test_cancel_handling.py | 68 +++---- tests/server/test_completion_with_context.py | 9 +- tests/server/test_runner.py | 4 +- tests/shared/test_jsonrpc_dispatcher.py | 180 ++++++++++-------- tests/shared/test_streamable_http.py | 7 +- 19 files changed, 274 insertions(+), 283 deletions(-) diff --git a/docs/advanced/authorization.md b/docs/advanced/authorization.md index 87ecc17b8..2c80fb5c6 100644 --- a/docs/advanced/authorization.md +++ b/docs/advanced/authorization.md @@ -27,7 +27,7 @@ The SDK has no opinion about what a valid token looks like. You tell it, by impl `AuthSettings` is the public face of your resource server: * `issuer_url`: the authorization server that issues your tokens. -* `resource_server_url`: the public URL of this MCP endpoint. It names *which* resource a token is for, and it's where the discovery document lives. When your verifier returns an `AccessToken.resource`, the SDK rejects the token unless it matches this URL — a token issued for a different resource never reaches a tool. +* `resource_server_url`: the public URL of this MCP endpoint. It names *which* resource a token is for, and it's where the discovery document lives. When your verifier returns an `AccessToken.resource`, the SDK rejects the token unless it matches this URL, so a token issued for a different resource never reaches a tool. * `required_scopes`: every token must carry all of them. !!! tip diff --git a/docs/migration.md b/docs/migration.md index cab83f9d9..aadd3cc71 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -34,6 +34,21 @@ elicitation required, invalid parameters). For tool *execution* failures the calling LLM should see and react to, raise any other exception or return `CallToolResult(is_error=True, ...)` directly; that path is unchanged. +### Unhandled handler exceptions return `INTERNAL_ERROR`; cancelled requests get no reply + +An unhandled exception in a request handler now produces JSON-RPC error `-32603` +(`INTERNAL_ERROR`) with the opaque message `"Internal server error"`. v1 returned +code `0` with `str(exc)` as the message, leaking handler internals to the peer; the +exception is still logged server-side via `logger.exception`. To send a specific +code and message, raise `MCPError` (unchanged); a pydantic `ValidationError` is +still mapped to `INVALID_PARAMS`. + +A request cancelled via `notifications/cancelled` now receives no response at all, +per the spec's SHOULD. v1 answered the cancelled request with an error +(`code=0, message="Request cancelled"`). The sender's awaiting call already fails +with anyio cancellation when its scope is cancelled, so no reply is needed to +unblock it. + ### `streamablehttp_client` removed The deprecated `streamablehttp_client` function has been removed. Use `streamable_http_client` instead. @@ -1388,9 +1403,9 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr Behavior changes: -- **Callbacks and notifications now run concurrently.** In v1 the receive loop processed one inbound message at a time, so callbacks ran inline and in order. Now each delivery starts in arrival order but runs as its own task. Server-initiated request callbacks (`sampling`, `elicitation`, `roots`) no longer block other traffic, may themselves send requests without deadlocking, and are interrupted if the server sends `notifications/cancelled` (the request is then answered with an error). Notification callbacks (`logging_callback`, `progress_callback`, `message_handler`) may interleave, and a `progress_callback` may run after the request it reports on has returned; there is no built-in bound on concurrent deliveries. Transport-level errors reach `message_handler` the same way, and a `message_handler` that raises is logged rather than fatal to the session. Callbacks that need strict sequencing must coordinate themselves. +- **Callbacks and notifications now run concurrently.** In v1 the receive loop processed one inbound message at a time, so callbacks ran inline and in order. Now each delivery starts in arrival order but runs as its own task. Server-initiated request callbacks (`sampling`, `elicitation`, `roots`) no longer block other traffic, may themselves send requests without deadlocking, and are interrupted if the server sends `notifications/cancelled` (the callback is interrupted; no response is sent for the cancelled request). Notification callbacks (`logging_callback`, `progress_callback`, `message_handler`) may interleave, and a `progress_callback` may run after the request it reports on has returned; there is no built-in bound on concurrent deliveries. Transport-level errors reach `message_handler` the same way, and a `message_handler` that raises is logged rather than fatal to the session. Callbacks that need strict sequencing must coordinate themselves. - **Timeouts**: a timed-out or abandoned request is now followed by `notifications/cancelled`, so the server stops the handler instead of leaving it running. -- **A raising request callback** is answered with `code=0` and the exception text; v1 flattened every callback exception to `INVALID_PARAMS`. For a specific error response, return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: pydantic's `ValidationError` is still answered with `INVALID_PARAMS`, as in v1. +- **A raising request callback** is answered with `INTERNAL_ERROR` (`-32603`) and a generic message — the exception text is logged client-side, not sent; v1 flattened every callback exception to `INVALID_PARAMS`. For a specific error response, return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: pydantic's `ValidationError` is still answered with `INVALID_PARAMS`, as in v1. - **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. After the connection has closed it raises `MCPError` (`CONNECTION_CLOSED`) instead. `send_notification` before entry still works. - **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** No client transport ever serialized these hints; progress and response correlation via `progressToken` and the request id is unaffected. - **Client callbacks now receive `mcp.client.ClientRequestContext`** (its `request_id` is always populated); the private `mcp.shared._context.RequestContext` generic is deleted. Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`. diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index e36ac7dd4..a9e1461a3 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -27,7 +27,6 @@ import anyio from anyio.streams.memory import MemoryObjectSendStream from mcp_types import ( - INTERNAL_ERROR, INVALID_REQUEST, PARSE_ERROR, ClientCapabilities, @@ -135,17 +134,15 @@ async def _to_jsonrpc_response( """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. The exception-to-wire boundary for the modern HTTP entry, composed around - `serve_one`. `MCPError` and `ValidationError` map via the shared - `handler_exception_to_error_data` ladder; any other exception is logged and - surfaced as `INTERNAL_ERROR` so handler internals never reach the wire. + `serve_one`; the shared `handler_exception_to_error_data` ladder owns the + mapping. """ try: result = await coro except Exception as exc: - error = handler_exception_to_error_data(exc) - if error is None: + error, unexpected = handler_exception_to_error_data(exc) + if unexpected: logger.exception("request handler raised") - error = ErrorData(code=INTERNAL_ERROR, message="Internal server error") return JSONRPCError(jsonrpc="2.0", id=request_id, error=error) return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index fd3e69d49..55ee1dc40 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -24,12 +24,12 @@ import anyio import anyio.abc -from mcp_types import CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, REQUEST_TIMEOUT, RequestId -from pydantic import ValidationError +from mcp_types import CONNECTION_CLOSED, INTERNAL_ERROR, REQUEST_TIMEOUT, RequestId from mcp.shared._compat import resync_tracer from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.jsonrpc_dispatcher import handler_exception_to_error_data from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext @@ -232,21 +232,15 @@ async def _dispatch_request( dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id) try: return await self._on_request(dctx, method, params) - except MCPError: - raise - except ValidationError as e: - # Same shape JSONRPCDispatcher writes, so runner-over-direct - # tests see what runner-over-JSONRPC would. - raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") from e except Exception as e: - # Single owner of the in-proc exception-to-error policy (mirrors - # JSONRPCDispatcher / `_streamable_http_modern._to_jsonrpc_response` - # for the wire paths). True chains the original for in-process - # debugging; False sanitizes to match the wire path's leak guard. - if self._raise_handler_exceptions: + error, unexpected = handler_exception_to_error_data(e) + if unexpected and self._raise_handler_exceptions: + # In-process debugging: chain the real exception so the + # traceback survives. Never reaches the wire. raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e - logger.exception("request handler raised") - raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None + if unexpected: + logger.exception("request handler raised") + raise MCPError(code=error.code, message=error.message, data=error.data) from None except TimeoutError: raise MCPError( code=REQUEST_TIMEOUT, diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 64fcd3298..37ef6721d 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -12,7 +12,7 @@ from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field from functools import partial -from typing import Any, Generic, Literal, cast +from typing import Any, Generic, Literal, NamedTuple, cast import anyio import anyio.abc @@ -67,21 +67,37 @@ the handler's scope; `"signal"` only sets `ctx.cancel_requested`.""" -def handler_exception_to_error_data(exc: BaseException) -> ErrorData | None: - """Map a handler-raised exception to its wire `ErrorData`. +class MappedError(NamedTuple): + error: ErrorData + unexpected: bool + """True iff ``exc`` hit the opaque fallthrough (no recognized rung). Callers + gate `logger.exception` / `_raise_handler_exceptions` on this — never on + ``error.code``, since handlers may deliberately raise + ``MCPError(code=INTERNAL_ERROR)``.""" - The two rungs every dispatcher shares: an `MCPError` carries its own - `ErrorData`; a pydantic `ValidationError` is the spec's INVALID_PARAMS - with empty ``data`` (no pydantic text on the wire). Returns ``None`` for - any other exception so each caller applies its own catch-all - - `JSONRPCDispatcher` currently pins ``code=0`` for v1 compat, - the modern HTTP entry uses `INTERNAL_ERROR`. + +def handler_exception_to_error_data(exc: Exception) -> MappedError: + """Map a handler-raised exception to its wire `ErrorData` plus an + ``unexpected`` flag. + + `MCPError` carries its own `ErrorData`; pydantic `ValidationError` is the + spec's INVALID_PARAMS with empty ``data`` (no pydantic text on the wire); + everything else is INTERNAL_ERROR with an opaque message so handler + internals never reach the peer. The single source of truth for both the + wire shape and the "was this unexpected?" classification — callers do not + re-derive the rung set. """ if isinstance(exc, MCPError): - return exc.error + return MappedError(exc.error, unexpected=False) if isinstance(exc, ValidationError): - return ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") - return None + return MappedError( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), + unexpected=False, + ) + return MappedError( + ErrorData(code=INTERNAL_ERROR, message="Internal server error"), + unexpected=True, + ) def progress_token_from_params(params: Mapping[str, Any] | None) -> ProgressToken | None: @@ -693,11 +709,11 @@ async def _handle_request( if scope.cancelled_caught: # anyio absorbs the scope's own cancel at __exit__, and # `cancelled_caught` (unlike `cancel_called`) guarantees the - # result write above did not happen - no double response. - # TODO(L38): spec says SHOULD NOT respond after cancel; - # the existing server always has, so match that for now. - answer_write_started = True - await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) + # result write above did not happen. Spec: receivers SHOULD NOT + # respond to a cancelled request; the sender retired its own + # waiter when it cancelled (`send_raw_request` finally), so no + # reply is needed to unblock it. + return except anyio.get_cancelled_exc_class(): # Shutdown: answer the request so the peer isn't left waiting - unless # an answer write already started (it may have reached the transport; @@ -712,16 +728,12 @@ async def _handle_request( ) raise except Exception as e: - error = handler_exception_to_error_data(e) - if error is not None: - await self._write_error(req.id, error) - else: + error, unexpected = handler_exception_to_error_data(e) + if unexpected: logger.exception("handler for %r raised", req.method) - # TODO(L58): code=0 pins existing-server compat; JSON-RPC says - # INTERNAL_ERROR. Revisit per the suite's divergence entry. - await self._write_error(req.id, ErrorData(code=0, message=str(e))) - if self._raise_handler_exceptions: - raise + await self._write_error(req.id, error) + if unexpected and self._raise_handler_exceptions: + raise # No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id. def _allocate_id(self) -> int: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 83893e36f..5eaa5eff7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -956,9 +956,10 @@ async def call() -> None: @pytest.mark.anyio -async def test_raising_sampling_callback_answers_with_code_zero(): - """A raising sampling callback is answered with code 0 and `str(exc)` (SDK-defined). - Raw streams because the assertion is the outbound `JSONRPCError` envelope itself.""" +async def test_raising_sampling_callback_answers_with_internal_error(): + """A raising sampling callback is answered with INTERNAL_ERROR and an opaque message (spec-mandated + code; SDK-defined redaction). Raw streams because the assertion is the outbound `JSONRPCError` + envelope itself.""" async def boom(ctx: object, params: object) -> types.CreateMessageResult: raise RuntimeError("sampling boom") @@ -973,7 +974,7 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult: ) out = await from_client.receive() assert isinstance(out.message, JSONRPCError) - assert out.message.error == types.ErrorData(code=0, message="sampling boom") + assert out.message.error == types.ErrorData(code=INTERNAL_ERROR, message="Internal server error") @pytest.mark.anyio diff --git a/tests/docs_src/test_authorization.py b/tests/docs_src/test_authorization.py index 4c7554ed7..a843c3bf0 100644 --- a/tests/docs_src/test_authorization.py +++ b/tests/docs_src/test_authorization.py @@ -55,25 +55,27 @@ async def test_the_metadata_document_is_built_from_auth_settings() -> None: async def test_a_request_without_a_token_never_reaches_the_protocol() -> None: - """The `!!! check`: no `Authorization` header means a 401 that points at the metadata document.""" + """The `!!! check`: no `Authorization` header means a 401 whose `WWW-Authenticate` points at the metadata.""" transport = httpx.ASGITransport(app=tutorial001.mcp.streamable_http_app()) async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:8000") as http_client: response = await http_client.post("/mcp", json={}) assert response.status_code == 401 - assert response.json() == {"error": "invalid_token", "error_description": "Authentication required"} + assert response.json() == {} assert response.headers["www-authenticate"] == ( - 'Bearer error="invalid_token", error_description="Authentication required", ' - 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + 'Bearer scope="notes:read", resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' ) -async def test_a_token_the_verifier_rejects_gets_the_same_401() -> None: - """tutorial001: `verify_token` returning `None` and a missing header are indistinguishable to the caller.""" +async def test_a_rejected_token_is_named_invalid_token() -> None: + """tutorial001: a token your verifier returns `None` for is a 401 with an RFC 6750 `invalid_token` error.""" transport = httpx.ASGITransport(app=tutorial001.mcp.streamable_http_app()) async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:8000") as http_client: response = await http_client.post("/mcp", json={}, headers={"Authorization": "Bearer not-a-real-token"}) assert response.status_code == 401 - assert response.json() == {"error": "invalid_token", "error_description": "Authentication required"} + assert response.json() == { + "error": "invalid_token", + "error_description": "The access token is malformed or unknown", + } async def test_get_access_token_is_none_outside_an_authenticated_request() -> None: diff --git a/tests/docs_src/test_deprecated.py b/tests/docs_src/test_deprecated.py index 892a8f362..d717baa8c 100644 --- a/tests/docs_src/test_deprecated.py +++ b/tests/docs_src/test_deprecated.py @@ -17,7 +17,6 @@ from mcp.client import ClientRequestContext from mcp.server import MCPServer from mcp.server.mcpserver import Context -from mcp.shared.exceptions import NoBackChannelError pytestmark = pytest.mark.anyio @@ -54,7 +53,7 @@ async def test_create_message_warns_and_then_raises_on_a_modern_connection() -> MCPDeprecationWarning, match=r"^The sampling capability is deprecated as of 2026-07-28 \(SEP-2577\)\.$", ), - pytest.raises(NoBackChannelError) as exc, + pytest.raises(MCPError) as exc, ): await client.call_tool("ask_model", {"prompt": "hi"}) assert str(exc.value) == ( diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 840a94f2e..d0e6693d2 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -511,14 +511,6 @@ def __post_init__(self) -> None: "A cancellation notification for an in-flight request stops the server-side handler, and the " "receiver does not send a response for the cancelled request." ), - divergence=Divergence( - note=( - "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " - "request; both seats send an error response (code 0, 'Request cancelled') instead — the " - "server for cancelled client requests, and the client for cancelled server-initiated " - "requests — which is what unblocks the sender's pending call." - ), - ), arm_exclusions=( ArmExclusion(reason="requires-session", transport="streamable-http-stateless"), ArmExclusion(reason="requires-session", spec_version="2026-07-28"), @@ -582,23 +574,6 @@ def __post_init__(self) -> None: "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " "-32603 Internal error." ), - divergence=Divergence( - note=( - "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " - "leaks str(exc) as the error message." - ), - ), - arm_exclusions=( - ArmExclusion( - reason="modern-error-surface", - spec_version="2026-07-28", - note=( - "The modern entry maps Exception->INTERNAL_ERROR (-32603) with an opaque message, so the " - "2026 arm SATISFIES this requirement; the test pins the legacy code-0 divergence and " - "needs an era-aware assertion before re-admission." - ), - ), - ), ), "protocol:error:invalid-params": Requirement( source=f"{SPEC_BASE_URL}/basic#responses", @@ -1259,10 +1234,9 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " - "which the low-level server converts to error code 0 with the exception text as the message." + "which the dispatcher converts to an opaque -32603 Internal error rather than -32602." ), ), - arm_exclusions=(ArmExclusion(reason="modern-error-surface", spec_version="2026-07-28"),), ), "prompts:get:multi-message": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", @@ -1305,7 +1279,6 @@ def __post_init__(self) -> None: "mcpserver:prompt:args-validation": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", - arm_exclusions=(ArmExclusion(reason="modern-error-surface", spec_version="2026-07-28"),), ), "mcpserver:prompt:decorated": Requirement( source="sdk", @@ -1334,10 +1307,9 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " - "ValueError, which the low-level server converts to error code 0." + "ValueError, which the dispatcher converts to an opaque -32603 Internal error." ), ), - arm_exclusions=(ArmExclusion(reason="modern-error-surface", spec_version="2026-07-28"),), ), # ═══════════════════════════════════════════════════════════════════════════ # Completion diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 247e1135a..334178c6f 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -1,9 +1,9 @@ """Cancellation interactions against the low-level Server, driven through the public Client API. -There is no client-side cancellation API: cancelling means sending a CancelledNotification -carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so -these tests capture the id from inside the blocked handler before cancelling. The handler blocks -on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. +The client-side cancellation idiom is scope-cancel: cancel an `anyio.CancelScope` enclosing the +pending `client.call_tool(...)` await. The dispatcher writes the courtesy `notifications/cancelled` +on abandonment, so the test does not hand-build the notification or learn the request id. The +handler blocks on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. """ import anyio @@ -14,7 +14,6 @@ REQUEST_TIMEOUT, CallToolResult, EmptyResult, - ErrorData, Implementation, InitializeResult, JSONRPCNotification, @@ -40,21 +39,17 @@ @requirement("protocol:cancel:in-flight") @requirement("protocol:cancel:handler-abort-propagates") async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: - """Cancelling an in-flight request interrupts its handler and fails the pending call. + """Cancelling an in-flight request interrupts its handler and fails the pending call locally. - The server answers the cancelled request with an error response (the spec says it should - not respond at all; see the divergence note on the requirement), so the caller's pending - request raises rather than hanging. + Spec-mandated: receivers SHOULD NOT respond to a cancelled request, so the caller's await is + interrupted by anyio cancellation (not an `MCPError` reply). The wire-level "no response is + sent" half is asserted by the dispatcher unit test; this test stays above the wire. """ started = anyio.Event() handler_cancelled = anyio.Event() - request_ids: list[types.RequestId] = [] - errors: list[ErrorData] = [] async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "block" - assert ctx.request_id is not None - request_ids.append(ctx.request_id) started.set() try: await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event @@ -67,31 +62,24 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: with anyio.fail_after(5): - async with anyio.create_task_group() as task_group: + with anyio.CancelScope() as scope: + async with anyio.create_task_group() as task_group: - async def call_and_capture_error() -> None: - with pytest.raises(MCPError) as exc_info: + async def call() -> None: await client.call_tool("block", {}) - errors.append(exc_info.value.error) - - task_group.start_soon(call_and_capture_error) - await started.wait() - await client.session.send_notification( - types.CancelledNotification( - params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") - ) - ) - - await handler_cancelled.wait() + raise NotImplementedError # unreachable: the scope is cancelled - assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) + task_group.start_soon(call) + await started.wait() + scope.cancel() + assert scope.cancelled_caught # the await failed via anyio cancel, not MCPError + await handler_cancelled.wait() # the receiver actually stopped work @requirement("protocol:cancel:server-survives") async def test_session_serves_requests_after_cancellation(connect: Connect) -> None: """A request cancelled mid-flight does not poison the session: the next request succeeds.""" started = anyio.Event() - request_ids: list[types.RequestId] = [] async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -106,8 +94,6 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: if params.name == "echo": return CallToolResult(content=[TextContent(text="still alive")]) - assert ctx.request_id is not None - request_ids.append(ctx.request_id) started.set() await anyio.Event().wait() # blocks until cancelled raise NotImplementedError # unreachable @@ -116,18 +102,17 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: with anyio.fail_after(5): - async with anyio.create_task_group() as task_group: + with anyio.CancelScope() as scope: + async with anyio.create_task_group() as task_group: - async def call_and_swallow_cancellation_error() -> None: - with pytest.raises(MCPError): + async def call() -> None: await client.call_tool("block", {}) + raise NotImplementedError # unreachable: the scope is cancelled - task_group.start_soon(call_and_swallow_cancellation_error) - await started.wait() - await client.session.send_notification( - types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) - ) - + task_group.start_soon(call) + await started.wait() + scope.cancel() + assert scope.cancelled_caught result = await client.call_tool("echo", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 861dd75e4..a46f13fc8 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -5,6 +5,7 @@ import pytest from inline_snapshot import snapshot from mcp_types import ( + INTERNAL_ERROR, INVALID_PARAMS, AudioContent, CallToolResult, @@ -96,11 +97,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:error:internal-error") -async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: +async def test_call_tool_uncaught_exception_becomes_internal_error_with_opaque_message(connect: Connect) -> None: """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. - The low-level server reports it with code 0 and the exception text as the message; see the - divergence note on the requirement. + Spec-mandated for the code (-32603 Internal error). SDK-defined for the message: the exception + text is logged server-side and never reaches the wire. """ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: @@ -113,7 +114,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara with pytest.raises(MCPError) as exc_info: await client.call_tool("explode", {}) - assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + assert exc_info.value.error == snapshot(ErrorData(code=INTERNAL_ERROR, message="Internal server error")) @requirement("tools:list:basic") diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 58c8b48c7..155236232 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -3,6 +3,7 @@ import pytest from inline_snapshot import snapshot from mcp_types import ( + INTERNAL_ERROR, ErrorData, GetPromptResult, ListPromptsResult, @@ -77,8 +78,8 @@ def greet(name: str) -> str: async def test_get_unknown_prompt_is_error(connect: Connect) -> None: """Getting a prompt name that was never registered fails with a JSON-RPC error. - The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on - the requirement). + The spec reserves -32602 for this case; the SDK reports an opaque -32603 (see the divergence + note on the requirement). """ mcp = MCPServer("prompter") @@ -91,7 +92,7 @@ def greet(name: str) -> str: with pytest.raises(MCPError) as exc_info: await client.get_prompt("nope") - assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + assert exc_info.value.error == snapshot(ErrorData(code=INTERNAL_ERROR, message="Internal server error")) @requirement("prompts:get:missing-required-args") @@ -99,8 +100,7 @@ async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: """Getting a prompt without one of its required arguments fails with a JSON-RPC error. The missing argument is detected before the prompt function is called, but the spec's -32602 - Invalid params is reported as error code 0 with the bare exception text (see the divergence - note on the requirement). + Invalid params is reported as an opaque -32603 (see the divergence note on the requirement). """ mcp = MCPServer("prompter") @@ -113,7 +113,7 @@ def greet(name: str) -> str: with pytest.raises(MCPError) as exc_info: await client.get_prompt("greet") - assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + assert exc_info.value.error == snapshot(ErrorData(code=INTERNAL_ERROR, message="Internal server error")) @requirement("mcpserver:prompt:args-validation") @@ -121,23 +121,23 @@ async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_func """An argument that fails the function signature's type validation is rejected before the function runs. The decorated function is wrapped in pydantic's validate_call, so a value that cannot be - coerced to the parameter's annotation fails before the body executes. The function body - raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable - rendering-error prefix; the body of the message is raw pydantic output and is not asserted. + coerced to the parameter's annotation fails before the body executes. The error response is + opaque, so a closure-captured list (not the error message) proves the body never ran. """ + called: list[object] = [] mcp = MCPServer("prompter") @mcp.prompt() def repeat(phrase: str, count: int) -> str: """A registered prompt; type validation rejects the call before the function runs.""" - raise NotImplementedError + raise NotImplementedError(called.append((phrase, count))) async with connect(mcp) as client: with pytest.raises(MCPError) as exc_info: await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) - assert exc_info.value.error.code == 0 - assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + assert exc_info.value.error == snapshot(ErrorData(code=INTERNAL_ERROR, message="Internal server error")) + assert called == [] @requirement("mcpserver:prompt:optional-args") diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index a8f1f53c7..176730b40 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -184,9 +184,8 @@ async def test_modern_handler_exception_maps_to_internal_error_without_leaking_t """A handler exception on the 2026-07-28 path returns -32603 with a generic message. Spec-mandated for the code: -32603 is the JSON-RPC Internal error code. SDK-defined for the - message: the 2026-07-28 entry deliberately does not echo ``str(exc)`` (the legacy dispatcher's - code-0 leak is the recorded divergence on ``protocol:error:internal-error``). Asserted at the - wire because the SDK client surfaces only the error object, not the HTTP status it travelled on. + message: the entry deliberately does not echo ``str(exc)``. Asserted at the wire because the + SDK client surfaces only the error object, not the HTTP status it travelled on. """ async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 70855f44b..13d05d36f 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1475,8 +1475,9 @@ async def test_get_unknown_prompt(self): mcp = MCPServer() async with Client(mcp, mode="legacy") as client: - with pytest.raises(MCPError, match="Unknown prompt"): + with pytest.raises(MCPError) as exc_info: await client.get_prompt("unknown") + assert exc_info.value.error.code == INTERNAL_ERROR async def test_get_prompt_missing_args(self): """Test error when required arguments are missing.""" @@ -1486,8 +1487,9 @@ async def test_get_prompt_missing_args(self): def prompt_fn(name: str) -> str: ... # pragma: no branch async with Client(mcp, mode="legacy") as client: - with pytest.raises(MCPError, match="Missing required arguments"): + with pytest.raises(MCPError) as exc_info: await client.get_prompt("prompt_fn") + assert exc_info.value.error.code == INTERNAL_ERROR async def test_resource_decorator_rfc6570_reserved_expansion(): diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 3d32adb3c..e7fed4073 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,8 +6,6 @@ CallToolRequest, CallToolRequestParams, CallToolResult, - CancelledNotification, - CancelledNotificationParams, ClientCapabilities, Implementation, InitializeRequestParams, @@ -22,7 +20,6 @@ from mcp import Client from mcp.server import Server, ServerRequestContext -from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage @@ -30,10 +27,8 @@ async def test_server_remains_functional_after_cancel(): """Verify server can handle new requests after a cancellation.""" - # Track tool calls call_count = 0 ev_first_call = anyio.Event() - first_request_id = None async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: return ListToolsResult( @@ -47,53 +42,38 @@ async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestP ) async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: - nonlocal call_count, first_request_id - if params.name == "test_tool": - call_count += 1 - if call_count == 1: - first_request_id = ctx.request_id - ev_first_call.set() - await anyio.sleep(5) # First call is slow - return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + nonlocal call_count + assert params.name == "test_tool" + call_count += 1 + if call_count == 1: + ev_first_call.set() + await anyio.Event().wait() # blocks until cancelled + return CallToolResult(content=[TextContent(type="text", text=f"Call number: {call_count}")]) server = Server("test-server", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async with Client(server, mode="legacy") as client: - # First request (will be cancelled) - async def first_request(): - try: - await client.session.send_request( - CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), - CallToolResult, - ) - pytest.fail("First request should have been cancelled") # pragma: no cover - except MCPError: - pass # Expected - - # Start first request - async with anyio.create_task_group() as tg: - tg.start_soon(first_request) - - # Wait for it to start - await ev_first_call.wait() - - # Cancel it - assert first_request_id is not None - await client.session.send_notification( - CancelledNotification( - params=CancelledNotificationParams(request_id=first_request_id, reason="Testing server recovery"), - ) - ) + with anyio.fail_after(5): + with anyio.CancelScope() as scope: + async with anyio.create_task_group() as tg: + + async def first_request() -> None: + await client.session.send_request( + CallToolRequest(params=CallToolRequestParams(name="test_tool", arguments={})), + CallToolResult, + ) + raise NotImplementedError # unreachable: the scope is cancelled + + tg.start_soon(first_request) + await ev_first_call.wait() + scope.cancel() + assert scope.cancelled_caught - # Second request (should work normally) - result = await client.call_tool("test_tool", {}) + # Second request (should work normally) + result = await client.call_tool("test_tool", {}) - # Verify second request completed successfully assert len(result.content) == 1 - # Type narrowing for pyright content = result.content[0] - assert content.type == "text" assert isinstance(content, TextContent) assert content.text == "Call number: 2" assert call_count == 2 diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index 79e4223ed..e00839ea2 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -2,6 +2,7 @@ import pytest from mcp_types import ( + INTERNAL_ERROR, CompleteRequestParams, CompleteResult, Completion, @@ -9,7 +10,7 @@ ResourceTemplateReference, ) -from mcp import Client +from mcp import Client, MCPError from mcp.server import Server, ServerRequestContext @@ -139,14 +140,14 @@ async def handle_completion(ctx: ServerRequestContext, params: CompleteRequestPa async with Client(server, mode="legacy") as client: # Try to complete table without database context - should raise error - with pytest.raises(Exception) as exc_info: + with pytest.raises(MCPError) as exc_info: await client.complete( ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), argument={"name": "table", "value": ""}, ) - # Verify error message - assert "Please select a database first" in str(exc_info.value) + # The handler's bare ValueError surfaces as an opaque INTERNAL_ERROR. + assert exc_info.value.error.code == INTERNAL_ERROR # Now complete with proper context - should work normally result_with_context = await client.complete( diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index ed9662f08..52187f262 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -776,8 +776,8 @@ async def bad_return(ctx: Ctx, params: PaginatedRequestParams | None) -> int: async with connected_runner(server) as (client, _): with pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) - assert exc.value.error.code == 0 - assert "int" in exc.value.error.message + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "Internal server error" @pytest.mark.anyio diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 82d16bc4b..79f45542e 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -104,8 +104,8 @@ async def call(method: str) -> None: @pytest.mark.anyio -async def test_handler_raising_exception_sends_code_zero_with_str_message(): - """Matches the existing server's `_handle_request`: code=0, message=str(e).""" +async def test_handler_raising_exception_sends_internal_error_with_opaque_message(): + """Spec-mandated: an unrecognised handler exception becomes -32603 with an opaque message.""" async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: raise RuntimeError("kaboom") @@ -113,19 +113,29 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", None) - assert exc.value.error.code == 0 - assert exc.value.error.message == "kaboom" + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "Internal server error" assert exc.value.__cause__ is None # cause does not survive the wire @pytest.mark.anyio -async def test_peer_cancel_interrupt_mode_writes_cancelled_error_response(): - """Matches the existing server: a peer-cancelled request is answered with code=0.""" +async def test_peer_cancel_interrupt_mode_writes_no_response(): + """Spec-mandated: a peer-cancelled request is interrupted and the receiver writes no response. + + Scripted at the wire: the handler-exit event proves the cancel reached the running handler; + a follow-up request's response being the *first* thing on the ordered server→client stream + proves nothing was emitted for the cancelled id. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() handler_exited = anyio.Event() seen_ctx: list[DCtx] = [] - async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "probe": + return {"ok": True} seen_ctx.append(ctx) handler_started.set() try: @@ -134,22 +144,34 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | handler_exited.set() raise NotImplementedError - seen_error: list[ErrorData] = [] - async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): - with anyio.fail_after(5): - async with anyio.create_task_group() as tg: # pragma: no branch - - async def call_then_record() -> None: - with pytest.raises(MCPError) as exc: - await client.send_raw_request("slow", None) - seen_error.append(exc.value.error) + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe - tg.start_soon(call_then_record) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="slow"))) await handler_started.wait() - await client.notify("notifications/cancelled", {"requestId": 1}) + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ) + ) + ) await handler_exited.wait() - assert seen_ctx[0].cancel_requested.is_set() - assert seen_error == [ErrorData(code=0, message="Request cancelled")] + assert seen_ctx[0].cancel_requested.is_set() + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="probe"))) + first = await s2c_recv.receive() + assert isinstance(first, SessionMessage) + assert isinstance(first.message, JSONRPCResponse) + assert first.message.id == 2 + assert first.message.result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() @pytest.mark.anyio @@ -371,7 +393,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> sent = s2c_recv.receive_nowait() assert isinstance(sent, SessionMessage) assert isinstance(sent.message, JSONRPCError) - assert sent.message.error.code == 0 + assert sent.message.error.code == INTERNAL_ERROR finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): s.close() @@ -1585,17 +1607,22 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, _server, _crec, srec): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: # pragma: no branch - - async def call() -> None: - with pytest.raises(MCPError): - await client.send_raw_request("slow", None) - - tg.start_soon(call) - await handler_started.wait() - await client.notify("notifications/cancelled", {"requestId": 1}) - await handler_exited.wait() - await srec.notified.wait() + with anyio.CancelScope() as call_scope: + async with anyio.create_task_group() as tg: + + async def call() -> None: + # cancel_on_abandon=False so retiring the await doesn't add a second + # courtesy cancel to `srec.notifications`. + await client.send_raw_request("slow", None, {"cancel_on_abandon": False}) + raise NotImplementedError # unreachable: the scope is cancelled + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + await srec.notified.wait() + call_scope.cancel() + assert call_scope.cancelled_caught assert srec.notifications == [("notifications/cancelled", {"requestId": 1})] @@ -2050,20 +2077,23 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, _server, _crec, srec): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: # pragma: no branch + with anyio.CancelScope() as call_scope: + async with anyio.create_task_group() as tg: - async def call() -> None: - with pytest.raises(MCPError): - await client.send_raw_request("slow", None) + async def call() -> None: + await client.send_raw_request("slow", None, {"cancel_on_abandon": False}) + raise NotImplementedError # unreachable: the scope is cancelled - tg.start_soon(call) - await handler_started.wait() - await client.notify("notifications/cancelled", {"requestId": True}) - # Once the teed notification is observed, the correlation arm has already run. - await srec.notified.wait() - assert not handler_exited.is_set() - await client.notify("notifications/cancelled", {"requestId": 1}) - await handler_exited.wait() + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": True}) + # Once the teed notification is observed, the correlation arm has already run. + await srec.notified.wait() + assert not handler_exited.is_set() + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + call_scope.cancel() + assert call_scope.cancelled_caught @pytest.mark.anyio @@ -2158,9 +2188,13 @@ async def test_cancelled_correlates_across_string_and_int_request_id_forms(reque c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + handler_exited = anyio.Event() async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - await anyio.sleep_forever() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() raise NotImplementedError async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: @@ -2180,11 +2214,9 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) ) with anyio.fail_after(5): - resp = await s2c_recv.receive() - assert isinstance(resp, SessionMessage) - assert isinstance(resp.message, JSONRPCError) - assert resp.message.id == request_id # response echoes the peer's id form verbatim - assert resp.message.error == ErrorData(code=0, message="Request cancelled") + # The handler exiting proves the cross-form id correlated to the in-flight entry; + # no response is read because a cancelled request gets none. + await handler_exited.wait() tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -2233,7 +2265,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> # Let the first handler task run to completion past the write. await anyio.wait_all_tasks_blocked() assert 7 in server._in_flight # pyright: ignore[reportPrivateUsage] - # The surviving entry must still be cancellable. + # The surviving entry must still be cancellable: the second handler exiting + # proves the cancel reached it (no response is read; a cancelled request gets none). await c2s_send.send( SessionMessage( message=JSONRPCNotification( @@ -2241,11 +2274,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) ) ) - resp2 = await s2c_recv.receive() - assert isinstance(resp2, SessionMessage) - assert isinstance(resp2.message, JSONRPCError) - assert resp2.message.error == ErrorData(code=0, message="Request cancelled") - assert second_exited.is_set() + await second_exited.wait() tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -2296,7 +2325,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> # Let the first handler task run past its pop entirely. await anyio.wait_all_tasks_blocked() assert 7 in server._in_flight # pyright: ignore[reportPrivateUsage] - # The surviving entry must still be cancellable by the peer. + # The surviving entry must still be cancellable by the peer: the second handler + # exiting proves the cancel reached it (no response is read; a cancelled request gets none). await c2s_send.send( SessionMessage( message=JSONRPCNotification( @@ -2304,11 +2334,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) ) ) - resp2 = await s2c_recv.receive() - assert isinstance(resp2, SessionMessage) - assert isinstance(resp2.message, JSONRPCError) - assert resp2.message.error == ErrorData(code=0, message="Request cancelled") - assert second_exited.is_set() + await second_exited.wait() tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -2359,25 +2385,29 @@ async def observe(ctx: Any, call_next: Any) -> Any: async with Client(server, mode="legacy") as client: with anyio.fail_after(5): - async with anyio.create_task_group() as tg: # pragma: no branch + with anyio.CancelScope() as call_scope: + async with anyio.create_task_group() as tg: - async def call() -> None: - with pytest.raises(MCPError): + async def call() -> None: await client.session.send_request( CallToolRequest(params=CallToolRequestParams(name="t", arguments={})), CallToolResult, ) - - tg.start_soon(call) - await handler_started.wait() - assert request_id is not None - await client.session.send_notification( - CancelledNotification( - params=CancelledNotificationParams(request_id=request_id, reason="user clicked stop") + raise NotImplementedError # unreachable: the scope is cancelled + + tg.start_soon(call) + await handler_started.wait() + assert request_id is not None + await client.session.send_notification( + CancelledNotification( + params=CancelledNotificationParams(request_id=request_id, reason="user clicked stop") + ) ) - ) - await cancel_observed.wait() - assert len(observed) == 1 + await cancel_observed.wait() + call_scope.cancel() + assert call_scope.cancelled_caught + # The hand-sent cancel is observed first; abandoning the await may emit a second courtesy + # cancel for the same id, so only the first entry is asserted. assert observed[0][0] == "notifications/cancelled" assert observed[0][1]["requestId"] == request_id assert observed[0][1]["reason"] == "user clicked stop" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index cbce222ec..3bcc8be8e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -23,6 +23,7 @@ from httpx_sse import ServerSentEvent from mcp_types import ( DEFAULT_NEGOTIATED_VERSION, + INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, CallToolRequestParams, @@ -907,11 +908,11 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session @pytest.mark.anyio async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession) -> None: - """A server-side error reaches the client as an MCPError with the handler's message.""" + """A server-side error reaches the client as an opaque INTERNAL_ERROR (the handler's message is not echoed).""" with pytest.raises(MCPError) as exc_info: await initialized_client_session.read_resource(uri="unknown://test-error") - assert exc_info.value.error.code == 0 - assert "Unknown resource: unknown://test-error" in exc_info.value.error.message + assert exc_info.value.error.code == INTERNAL_ERROR + assert exc_info.value.error.message == "Internal server error" @pytest.mark.anyio From a1734460a1711ec98a89be0cf1f2bfa8d1567b2e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 17:38:31 +0000 Subject: [PATCH 04/14] Reject bearer tokens that carry no audience claim When `AuthSettings.resource_server_url` is configured, `BearerAuthBackend` previously ran the RFC 8707 audience comparison only for tokens whose verifier populated `AccessToken.resource`: a token carrying no resource indicator at all was accepted. The MCP authorization spec requires a resource server to only accept tokens issued specifically for it, so the gate now fails closed: a verified token with no `resource` is answered `401 invalid_token` ("The access token carries no audience claim"). `resource_server_url=None` still means there is no audience to enforce. For verifiers that validate the audience themselves and cannot surface the claim (for example a JWT decoder configured with the expected audience), the new `AuthSettings.verifier_validates_audience=True` opts the gate out. The `AuthSettings.enforced_audience` property derives the single value both server wirings pass to `BearerAuthBackend`, whose signature is unchanged. `RefreshToken` gains an optional `resource` field so an authorization server provider can carry the original grant's audience binding through `exchange_refresh_token`; without it every refreshed access token would be audience-unbound and rejected by the hardened gate. The docs tutorials and example servers now populate `AccessToken.resource` (and the client-credentials demo token endpoint honors the RFC 8707 `resource` parameter) so they pass the check they teach. The migration guide entry for audience validation is rewritten for the fail-closed behavior. --- docs/advanced/authorization.md | 6 +-- docs/migration.md | 13 ++++- docs_src/authorization/tutorial001.py | 4 +- docs_src/authorization/tutorial002.py | 4 +- examples/stories/bearer_auth/server.py | 1 + .../oauth_client_credentials/server.py | 13 ++++- .../server_lowlevel.py | 13 ++++- src/mcp/server/auth/middleware/bearer_auth.py | 17 +++--- src/mcp/server/auth/provider.py | 1 + src/mcp/server/auth/settings.py | 20 +++++++ src/mcp/server/lowlevel/server.py | 4 +- src/mcp/server/mcpserver/server.py | 2 +- tests/interaction/auth/_provider.py | 7 ++- tests/interaction/auth/test_bearer.py | 54 +++++++++++++++---- tests/interaction/auth/test_discovery.py | 7 ++- 15 files changed, 133 insertions(+), 33 deletions(-) diff --git a/docs/advanced/authorization.md b/docs/advanced/authorization.md index 2c80fb5c6..335355478 100644 --- a/docs/advanced/authorization.md +++ b/docs/advanced/authorization.md @@ -16,7 +16,7 @@ That's the whole triangle. Everything on this page is the middle bullet. The SDK has no opinion about what a valid token looks like. You tell it, by implementing **`TokenVerifier`**: -```python title="server.py" hl_lines="12-14 19-24" +```python title="server.py" hl_lines="14-16 21-26" --8<-- "docs_src/authorization/tutorial001.py" ``` @@ -27,7 +27,7 @@ The SDK has no opinion about what a valid token looks like. You tell it, by impl `AuthSettings` is the public face of your resource server: * `issuer_url`: the authorization server that issues your tokens. -* `resource_server_url`: the public URL of this MCP endpoint. It names *which* resource a token is for, and it's where the discovery document lives. When your verifier returns an `AccessToken.resource`, the SDK rejects the token unless it matches this URL, so a token issued for a different resource never reaches a tool. +* `resource_server_url`: the public URL of this MCP endpoint. It names *which* resource a token is for, and it's where the discovery document lives. The SDK rejects any token whose `AccessToken.resource` does not match this URL, including one whose verifier left `resource` unset, so a token issued for a different resource (or for no resource) never reaches a tool. If your verifier validates the audience itself and cannot surface the claim, set `verifier_validates_audience=True`. * `required_scopes`: every token must carry all of them. !!! tip @@ -83,7 +83,7 @@ This document is how a client that has never heard of your server finds its way Inside any handler, **`get_access_token()`** is the `AccessToken` your verifier returned for the current request: -```python title="server.py" hl_lines="4 32-35" +```python title="server.py" hl_lines="4 34-37" --8<-- "docs_src/authorization/tutorial002.py" ``` diff --git a/docs/migration.md b/docs/migration.md index aadd3cc71..944d407ee 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1477,9 +1477,18 @@ issuer inconsistent with what clients compare against under RFC 8414 / RFC 9207. already-built `AnyHttpUrl` object still normalizes at construction; pass a string to get the preserved form. -### Bearer tokens with a mismatched audience are rejected +### Bearer tokens are rejected unless their audience names this server -`BearerAuthBackend` now compares `AccessToken.resource` against `AuthSettings.resource_server_url` and answers a token whose RFC 8707 resource indicator does not name this server with `401 invalid_token`. The check is canonical-URI equality, so a token issued for `https://host/` is not accepted by a server at `https://host/mcp`. It is skipped when either side is `None` — populate `AccessToken.resource` only when your verifier surfaces the underlying audience claim. `BearerAuthBackend.__init__` gains a keyword-only `resource_server_url: AnyHttpUrl | None = None`, wired automatically from `AuthSettings`; pass it only if you construct the backend directly. +`BearerAuthBackend` now compares `AccessToken.resource` against `AuthSettings.resource_server_url` and answers any token whose RFC 8707 resource indicator does not name this server — **including a token that carries no resource indicator at all** — with `401 invalid_token`. The comparison is canonical-URI equality, so a token issued for `https://host/` is not accepted by a server at `https://host/mcp`. + +To migrate, do exactly one of: + +- **Populate `AccessToken.resource`** from the token's `aud` claim (an introspection response's `aud`, or the decoded JWT's `aud`) in your `TokenVerifier`. This is the recommended path and what the SDK's examples now show. +- **Set `AuthSettings(verifier_validates_audience=True)`** if your verifier already validates the audience itself and cannot surface it — for example a JWT library configured with `audience=` that fails decoding on a mismatch. This tells the bearer gate not to repeat a check your verifier already performed. Do not set it just to make the `401` go away: with it set, the SDK performs no audience validation of its own at all. + +Leaving `resource_server_url=None` continues to disable the check entirely (there is no audience to compare against), but a protected server should configure it: it is also the value published as RFC 9728 Protected Resource Metadata. If your authorization server does not support RFC 8707 resource indicators, your tokens will not carry an audience — audit that before opting out, because accepting audience-unbound tokens is what the MCP specification's audience-validation MUST exists to prevent. + +`RefreshToken` gains an optional `resource` field so an `OAuthAuthorizationServerProvider` can propagate the original grant's audience binding through `exchange_refresh_token`; without it a refreshed access token would carry no audience and be rejected. `BearerAuthBackend.__init__` gains a keyword-only `resource_server_url: AnyHttpUrl | None = None`, wired automatically from `AuthSettings.enforced_audience`; `None` (the default, and what the SDK passes when `verifier_validates_audience` is set) means no audience is enforced. ### Lowlevel `Server`: `subscribe` capability now correctly reported diff --git a/docs_src/authorization/tutorial001.py b/docs_src/authorization/tutorial001.py index f15f54fd7..65d7427b3 100644 --- a/docs_src/authorization/tutorial001.py +++ b/docs_src/authorization/tutorial001.py @@ -5,7 +5,9 @@ from mcp.server.auth.settings import AuthSettings KNOWN_TOKENS = { - "alice-token": AccessToken(token="alice-token", client_id="alice", scopes=["notes:read"]), + "alice-token": AccessToken( + token="alice-token", client_id="alice", scopes=["notes:read"], resource="http://127.0.0.1:8000/mcp" + ), } diff --git a/docs_src/authorization/tutorial002.py b/docs_src/authorization/tutorial002.py index 55b024f2c..8c3e45c39 100644 --- a/docs_src/authorization/tutorial002.py +++ b/docs_src/authorization/tutorial002.py @@ -6,7 +6,9 @@ from mcp.server.auth.settings import AuthSettings KNOWN_TOKENS = { - "alice-token": AccessToken(token="alice-token", client_id="alice", scopes=["notes:read"]), + "alice-token": AccessToken( + token="alice-token", client_id="alice", scopes=["notes:read"], resource="http://127.0.0.1:8000/mcp" + ), } diff --git a/examples/stories/bearer_auth/server.py b/examples/stories/bearer_auth/server.py index 45c9872c3..dfa6eada5 100644 --- a/examples/stories/bearer_auth/server.py +++ b/examples/stories/bearer_auth/server.py @@ -28,6 +28,7 @@ async def verify_token(self, token: str) -> AccessToken | None: client_id="demo-client", scopes=[REQUIRED_SCOPE], expires_at=int(time.time()) + 3600, + resource=RESOURCE_URL, subject="demo-user", ) diff --git a/examples/stories/oauth_client_credentials/server.py b/examples/stories/oauth_client_credentials/server.py index 7e3d910e8..fa13a15c2 100644 --- a/examples/stories/oauth_client_credentials/server.py +++ b/examples/stories/oauth_client_credentials/server.py @@ -13,7 +13,7 @@ from mcp.server.mcpserver import MCPServer from mcp.shared.auth import OAuthMetadata, OAuthToken from stories._hosting import NO_DNS_REBIND, run_app_from_args -from stories._shared.auth import BASE_URL, auth_settings +from stories._shared.auth import BASE_URL, MCP_URL, auth_settings # DEMO ONLY — never hard-code real credentials. DEMO_CLIENT_ID = "demo-m2m-client" @@ -65,8 +65,17 @@ async def token_endpoint(request: Request) -> JSONResponse: creds = base64.b64decode(request.headers.get("authorization", "").removeprefix("Basic ")).decode() if creds != f"{DEMO_CLIENT_ID}:{DEMO_CLIENT_SECRET}": return JSONResponse({"error": "invalid_client"}, status_code=401) + # RFC 8707 §2.2: this AS protects exactly one resource. Anything else (or a missing + # indicator) is answered with `invalid_target`, and the issued token is audience-bound + # to that one resource so the bearer gate accepts it. Never mint whatever audience the + # client names: a multi-resource AS that does so hands out tokens for resources the + # client was never granted. + if form.get("resource") != MCP_URL: + return JSONResponse({"error": "invalid_target"}, status_code=400) access = f"access_{secrets.token_hex(16)}" - issued[access] = AccessToken(token=access, client_id=DEMO_CLIENT_ID, scopes=[DEMO_SCOPE], expires_at=None) + issued[access] = AccessToken( + token=access, client_id=DEMO_CLIENT_ID, scopes=[DEMO_SCOPE], expires_at=None, resource=MCP_URL + ) body = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=DEMO_SCOPE) return JSONResponse(body.model_dump(exclude_none=True), headers={"cache-control": "no-store"}) diff --git a/examples/stories/oauth_client_credentials/server_lowlevel.py b/examples/stories/oauth_client_credentials/server_lowlevel.py index ba2003ded..a98617cf1 100644 --- a/examples/stories/oauth_client_credentials/server_lowlevel.py +++ b/examples/stories/oauth_client_credentials/server_lowlevel.py @@ -18,7 +18,7 @@ from mcp.server.lowlevel import Server from mcp.shared.auth import OAuthMetadata, OAuthToken from stories._hosting import NO_DNS_REBIND, run_app_from_args -from stories._shared.auth import BASE_URL, auth_settings +from stories._shared.auth import BASE_URL, MCP_URL, auth_settings from .server import DEMO_CLIENT_ID, DEMO_CLIENT_SECRET, DEMO_SCOPE @@ -62,8 +62,17 @@ async def token_endpoint(request: Request) -> JSONResponse: creds = base64.b64decode(request.headers.get("authorization", "").removeprefix("Basic ")).decode() if creds != f"{DEMO_CLIENT_ID}:{DEMO_CLIENT_SECRET}": return JSONResponse({"error": "invalid_client"}, status_code=401) + # RFC 8707 §2.2: this AS protects exactly one resource. Anything else (or a missing + # indicator) is answered with `invalid_target`, and the issued token is audience-bound + # to that one resource so the bearer gate accepts it. Never mint whatever audience the + # client names: a multi-resource AS that does so hands out tokens for resources the + # client was never granted. + if form.get("resource") != MCP_URL: + return JSONResponse({"error": "invalid_target"}, status_code=400) access = f"access_{secrets.token_hex(16)}" - issued[access] = AccessToken(token=access, client_id=DEMO_CLIENT_ID, scopes=[DEMO_SCOPE], expires_at=None) + issued[access] = AccessToken( + token=access, client_id=DEMO_CLIENT_ID, scopes=[DEMO_SCOPE], expires_at=None, resource=MCP_URL + ) body = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=DEMO_SCOPE) return JSONResponse(body.model_dump(exclude_none=True), headers={"cache-control": "no-store"}) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index d80e13f15..8c2a8e9d5 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -69,6 +69,12 @@ class BearerAuthBackend(AuthenticationBackend): """Authentication backend that validates Bearer tokens using a TokenVerifier.""" def __init__(self, token_verifier: TokenVerifier, *, resource_server_url: AnyHttpUrl | None = None) -> None: + """Validate bearer tokens with `token_verifier` and, when `resource_server_url` is set, + enforce that every token's RFC 8707 audience names exactly that resource. + + `resource_server_url=None` means there is no audience to enforce; verification stops at + the verifier's answer and the expiry check. + """ self.token_verifier = token_verifier self.resource_server_url = resource_server_url @@ -86,12 +92,11 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas return AuthCredentials(), InvalidTokenUser("The access token is malformed or unknown") if auth_info.expires_at is not None and auth_info.expires_at < int(time.time()): return AuthCredentials(), InvalidTokenUser("The access token has expired") - if ( - self.resource_server_url is not None - and auth_info.resource is not None - and not check_token_audience(auth_info.resource, self.resource_server_url) - ): - return AuthCredentials(), InvalidTokenUser("The access token was issued for a different resource") + if self.resource_server_url is not None: + if auth_info.resource is None: + return AuthCredentials(), InvalidTokenUser("The access token carries no audience claim") + if not check_token_audience(auth_info.resource, self.resource_server_url): + return AuthCredentials(), InvalidTokenUser("The access token was issued for a different resource") return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index eeb371f1c..0fb52181a 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -46,6 +46,7 @@ class RefreshToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + resource: str | None = None # RFC 8707 resource indicator; propagate to refreshed AccessTokens subject: str | None = None # resource owner; propagate to refreshed AccessTokens diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index ae2083a38..674501a91 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -40,3 +40,23 @@ class AuthSettings(BaseModel): description="The URL of the MCP server to be used as the resource identifier " "and base route to look up OAuth Protected Resource Metadata.", ) + + verifier_validates_audience: bool = Field( + default=False, + description="Set when your TokenVerifier validates the token's audience itself and " + "therefore never populates AccessToken.resource (for example a JWT decoder configured " + "with the expected audience). The bearer gate then skips its own audience check. " + "Leave False to have the SDK reject any token whose resource indicator is absent or " + "names a different server.", + ) + + @property + def enforced_audience(self) -> AnyHttpUrl | None: + """The resource identifier the bearer gate compares each token's audience against. + + `None` when no `resource_server_url` is configured, or when + `verifier_validates_audience` declares that the verifier already did the check -- in + both cases the gate has nothing of its own to enforce. Both server wirings read this, + so it is the single source of the should-the-gate-audience-check decision. + """ + return None if self.verifier_validates_audience else self.resource_server_url diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 30173238a..dcfa2974a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -710,7 +710,7 @@ def streamable_http_app( middleware = [ Middleware( AuthenticationMiddleware, - backend=BearerAuthBackend(token_verifier, resource_server_url=auth.resource_server_url), + backend=BearerAuthBackend(token_verifier, resource_server_url=auth.enforced_audience), ), Middleware(AuthContextMiddleware), ] @@ -732,7 +732,7 @@ def streamable_http_app( if token_verifier: # Determine resource metadata URL resource_metadata_url = None - if auth and auth.resource_server_url: # pragma: no branch + if auth and auth.resource_server_url: # Build compliant metadata URL for WWW-Authenticate header resource_metadata_url = build_resource_metadata_url(auth.resource_server_url) diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 511c19c8b..4c8a44531 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -1001,7 +1001,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no Middleware( AuthenticationMiddleware, backend=BearerAuthBackend( - self._token_verifier, resource_server_url=self.settings.auth.resource_server_url + self._token_verifier, resource_server_url=self.settings.auth.enforced_audience ), ), # Add the auth context middleware to store diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py index 0c54d4fd3..4e422947b 100644 --- a/tests/interaction/auth/_provider.py +++ b/tests/interaction/auth/_provider.py @@ -157,6 +157,7 @@ async def exchange_authorization_code( token=refresh, client_id=client.client_id, scopes=authorization_code.scopes, + resource=authorization_code.resource, ) del self.codes[authorization_code.code] return OAuthToken( @@ -183,9 +184,11 @@ async def exchange_refresh_token( if self._fail_next_refresh: self._fail_next_refresh = False raise TokenError(error="invalid_grant", error_description="refresh denied by harness") - access = self.mint_access_token(client_id=client.client_id, scopes=scopes) + access = self.mint_access_token(client_id=client.client_id, scopes=scopes, resource=refresh_token.resource) new_refresh = f"refresh_{secrets.token_hex(16)}" - self.refresh_tokens[new_refresh] = RefreshToken(token=new_refresh, client_id=client.client_id, scopes=scopes) + self.refresh_tokens[new_refresh] = RefreshToken( + token=new_refresh, client_id=client.client_id, scopes=scopes, resource=refresh_token.resource + ) del self.refresh_tokens[refresh_token.token] return OAuthToken( access_token=access, diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py index 120690df3..483cd1087 100644 --- a/tests/interaction/auth/test_bearer.py +++ b/tests/interaction/auth/test_bearer.py @@ -3,8 +3,9 @@ These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, -the `WWW-Authenticate` header structure, and that a token with no audience claim reaches the -MCP endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. +the `WWW-Authenticate` header structure, and that the audience gate fails closed (a token with +no audience claim is rejected unless `AuthSettings.verifier_validates_audience` opts the gate +out). The flow side of the same 401 is `test_flow.py`'s flagship test. """ import time @@ -28,11 +29,20 @@ _FUTURE = int(time.time()) + 3600 _PAST = int(time.time()) - 3600 +# The audience `auth_settings()` configures as `resource_server_url`. Every fixture token with +# exactly one non-audience defect carries it, so each token isolates the defect its test names. +RESOURCE = "http://127.0.0.1:8000/mcp" TOKENS = { - "tok-valid": AccessToken(token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), - "tok-expired": AccessToken(token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST), - "tok-noscope": AccessToken(token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE), + "tok-valid": AccessToken( + token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE, resource=RESOURCE + ), + "tok-expired": AccessToken( + token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST, resource=RESOURCE + ), + "tok-noscope": AccessToken( + token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE, resource=RESOURCE + ), "tok-wrong-aud": AccessToken( token="tok-wrong-aud", client_id="c", @@ -202,15 +212,39 @@ async def test_a_token_for_a_parent_path_on_the_same_origin_is_answered_401_inva @requirement("hosting:auth:aud-validation") -async def test_a_token_without_a_resource_claim_passes_the_audience_check(protected: httpx.AsyncClient) -> None: - """A token whose `AccessToken.resource` is unset passes the audience check. +async def test_a_token_without_a_resource_claim_is_answered_401_invalid_token( + protected: httpx.AsyncClient, +) -> None: + """A token whose `AccessToken.resource` is unset is answered 401 when an audience is configured. - SDK-defined pass-through: the SDK cannot distinguish a verifier that performed its own - audience check and chose not to surface the claim from a token that genuinely carries - none, so `resource is None` is accepted. This pins that policy. + Spec-mandated (authorization MUST: servers reject tokens that do not include them in the + audience claim). The bearer gate fails closed; the operator-level escape hatch for + verifiers that validate audience internally is `AuthSettings.verifier_validates_audience`. """ response = await post_mcp(protected, bearer="tok-no-aud") + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "The access token carries no audience claim", + "scope": REQUIRED_SCOPE, + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_without_a_resource_claim_passes_when_verifier_validates_audience_is_set() -> None: + """With `verifier_validates_audience=True` the bearer gate skips its own audience check. + + SDK-defined opt-out for the spec's "or otherwise verify" clause: a verifier that validates + the token's audience internally (a JWT decoder configured with the expected audience) and so + never populates `AccessToken.resource`. The body proves the request reached the MCP endpoint. + """ + server = Server("rs") + settings = auth_settings(required_scopes=[REQUIRED_SCOPE]).model_copy(update={"verifier_validates_audience": True}) + async with mounted_app(server, auth=settings, token_verifier=StaticTokenVerifier(TOKENS)) as (http, _): + response = await post_mcp(http, bearer="tok-no-aud") + assert response.status_code == 200 assert response.headers["content-type"].startswith("text/event-stream") # The body is finite SSE: a result event followed by stream close. Pull the JSON-RPC response diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py index 1317fd19d..9bac7585d 100644 --- a/tests/interaction/auth/test_discovery.py +++ b/tests/interaction/auth/test_discovery.py @@ -134,9 +134,14 @@ async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_th provider = InMemoryAuthorizationServerProvider() server = Server("guarded", on_list_tools=list_tools) app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) + # A legacy server publishes no protected-resource metadata, so it also has no resource + # identifier for the bearer gate to enforce: the client never learns a `resource` to bind. + legacy_settings = auth_settings().model_copy(update={"resource_server_url": None}) with anyio.fail_after(5): - async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + async with connect_with_oauth( + server, provider=provider, settings=legacy_settings, app_shim=app_shim, on_request=on_request + ) as ( client, _, ): From 884badc944dc4a7387f0fe9b1fe514f22fa33f34 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 17:47:05 +0000 Subject: [PATCH 05/14] Use spec error codes for unhandled elicitation/create and roots/list A client constructed without an elicitation_callback or list_roots_callback still answers the server's request, via a default callback that returns a JSON-RPC error. Both defaults used -32600 (invalid request). The spec assigns a specific code to each case: - elicitation/create: -32602 (invalid params). A client with no callback declares no elicitation modes, so every incoming request names an undeclared mode, which clients MUST answer with -32602. - roots/list: -32601 (method not found), the code clients SHOULD use when they do not support roots. The default sampling callback keeps -32600: the spec assigns no code to a client that does not support sampling. Error messages are unchanged. Update the affected tests and the interaction-requirement entries that recorded the old codes, fix the code named in the client callbacks doc page, and add a migration note. --- docs/client/callbacks.md | 2 +- docs/migration.md | 6 +++++ src/mcp/client/session.py | 6 +++-- tests/client/test_client.py | 4 +-- tests/client/test_list_roots_callback.py | 4 +-- tests/docs_src/test_client_callbacks.py | 3 ++- tests/docs_src/test_mrtr.py | 4 +-- tests/interaction/_requirements.py | 6 ----- .../interaction/lowlevel/test_elicitation.py | 25 +++++++++++-------- tests/interaction/lowlevel/test_roots.py | 20 +++++++++------ 10 files changed, 47 insertions(+), 33 deletions(-) diff --git a/docs/client/callbacks.md b/docs/client/callbacks.md index db2c4d7cd..93e039bf5 100644 --- a/docs/client/callbacks.md +++ b/docs/client/callbacks.md @@ -105,7 +105,7 @@ Pass all three callbacks and you get `['elicitation', 'sampling', 'roots']`. Pas MCPError: Elicitation not supported ``` - That is a protocol error (`-32600`, *invalid request*), not a tool error: there is nothing for + That is a protocol error (`-32602`, *invalid params*), not a tool error: there is nothing for the model to read and retry. It's why `client_features` is worth having: a well-behaved server checks before it asks. diff --git a/docs/migration.md b/docs/migration.md index 944d407ee..079bee405 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -410,6 +410,12 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer` `Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it. +### Unhandled `elicitation/create` returns `-32602`; unhandled `roots/list` returns `-32601` + +When a server sends `elicitation/create` to a client that registered no `elicitation_callback`, or `roots/list` to a client that registered no `list_roots_callback`, the SDK still answers on the client's behalf with a JSON-RPC error. In v1 both answers used code `-32600` (`INVALID_REQUEST`). They now use the code the spec assigns to each case: `elicitation/create` is answered with `-32602` (`INVALID_PARAMS`), per the [elicitation error-handling section](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation#error-handling) (a client with no callback declared no elicitation modes, and a request for an undeclared mode MUST be answered with `-32602`), and `roots/list` is answered with `-32601` (`METHOD_NOT_FOUND`), per the [roots error-handling section](https://modelcontextprotocol.io/specification/2025-11-25/client/roots#error-handling). The error messages (`Elicitation not supported`, `List roots not supported`) are unchanged, and `sampling/createMessage` without a `sampling_callback` still answers `-32600` — the spec assigns no code to that case. + +Server-side code that branched on `error.code == INVALID_REQUEST` to detect a client without elicitation or roots support should switch to `INVALID_PARAMS` and `METHOD_NOT_FOUND` respectively — or, better, check the client's declared capabilities before sending, which is the condition these codes describe. + ### `InputRequiredResult` handling differs between `Client` and `ClientSession` For protocol 2026-07-28, `tools/call`, `prompts/get`, and `resources/read` may return an `InputRequiredResult` asking the client to supply additional input (sampling, elicitation, roots) and retry. diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fa71d1330..29e83902c 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -149,6 +149,8 @@ async def _default_sampling_callback( context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: + # Unlike elicitation (INVALID_PARAMS) and roots (METHOD_NOT_FOUND) below, the spec assigns no + # error code to a client that does not support sampling; INVALID_REQUEST is the SDK's choice. return types.ErrorData( code=types.INVALID_REQUEST, message="Sampling not supported", @@ -160,7 +162,7 @@ async def _default_elicitation_callback( params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( - code=types.INVALID_REQUEST, + code=types.INVALID_PARAMS, message="Elicitation not supported", ) @@ -169,7 +171,7 @@ async def _default_list_roots_callback( context: ClientRequestContext, ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( - code=types.INVALID_REQUEST, + code=types.METHOD_NOT_FOUND, message="List roots not supported", ) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index a6a9ac6ea..568ee2c62 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -681,7 +681,7 @@ async def elicitation_callback( async def test_call_tool_auto_loop_raises_mcp_error_when_no_callback_registered() -> None: """SDK-defined: with no `elicitation_callback`, the default returns - `ErrorData(INVALID_REQUEST, ...)` and the driver raises it as `MCPError` + `ErrorData(INVALID_PARAMS, ...)` and the driver raises it as `MCPError` rather than retrying.""" server = MCPServer("test") @@ -694,7 +694,7 @@ async def needs_input(ctx: Context) -> str | types.InputRequiredResult: async with Client(server) as client: with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.call_tool("needs_input") - assert exc.value.error.code == types.INVALID_REQUEST + assert exc.value.error.code == types.INVALID_PARAMS async def test_get_prompt_auto_loop_resolves_input_required_via_callbacks() -> None: diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 14ca1577d..cda28377b 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,5 +1,5 @@ import pytest -from mcp_types import INVALID_REQUEST, ListRootsResult, Root, TextContent +from mcp_types import METHOD_NOT_FOUND, ListRootsResult, Root, TextContent from pydantic import FileUrl from mcp import Client @@ -44,4 +44,4 @@ async def test_list_roots(context: Context, message: str): async with Client(server, mode="legacy") as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("test_list_roots", {"message": "test message"}) - assert exc_info.value.error.code == INVALID_REQUEST + assert exc_info.value.error.code == METHOD_NOT_FOUND diff --git a/tests/docs_src/test_client_callbacks.py b/tests/docs_src/test_client_callbacks.py index b615c4700..3b6a085b7 100644 --- a/tests/docs_src/test_client_callbacks.py +++ b/tests/docs_src/test_client_callbacks.py @@ -3,6 +3,7 @@ import pytest from inline_snapshot import snapshot from mcp_types import ( + INVALID_PARAMS, INVALID_REQUEST, CreateMessageRequestParams, CreateMessageResult, @@ -74,7 +75,7 @@ async def test_without_the_callback_the_servers_request_is_refused() -> None: async with Client(tutorial001.mcp, mode="legacy") as client: with pytest.raises(MCPError, match="Elicitation not supported") as exc_info: await client.call_tool("issue_card") - assert exc_info.value.error.code == INVALID_REQUEST + assert exc_info.value.error.code == INVALID_PARAMS async def test_registering_the_callback_declares_the_capability() -> None: diff --git a/tests/docs_src/test_mrtr.py b/tests/docs_src/test_mrtr.py index 4be449edc..7d87b7f2e 100644 --- a/tests/docs_src/test_mrtr.py +++ b/tests/docs_src/test_mrtr.py @@ -4,7 +4,7 @@ from inline_snapshot import snapshot from mcp_types import ( INTERNAL_ERROR, - INVALID_REQUEST, + INVALID_PARAMS, CallToolResult, CreateMessageRequest, CreateMessageRequestParams, @@ -62,7 +62,7 @@ async def test_the_auto_loop_without_a_callback_raises_mcp_error() -> None: async with Client(tutorial001.server) as client: with pytest.raises(MCPError) as exc: await client.call_tool("provision", {"name": "orders"}) - assert exc.value.error.code == INVALID_REQUEST + assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "Elicitation not supported" diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index d0e6693d2..5e49b7ad2 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1724,9 +1724,6 @@ def __post_init__(self) -> None: "An elicitation request to a client that did not declare the elicitation capability is " "answered with -32602 Invalid params." ), - divergence=Divergence( - note="The client's default callback answers with -32600 Invalid request instead of -32602.", - ), arm_exclusions=( ArmExclusion(reason="server-initiated-request", transport="streamable-http-stateless"), ArmExclusion(reason="server-initiated-request", spec_version="2026-07-28"), @@ -1934,9 +1931,6 @@ def __post_init__(self) -> None: "A roots/list request to a client that did not declare the roots capability is answered with " "-32601 Method not found." ), - divergence=Divergence( - note="The client's default callback answers with -32600 Invalid request instead of -32601.", - ), arm_exclusions=( ArmExclusion(reason="server-initiated-request", transport="streamable-http-stateless"), ArmExclusion(reason="server-initiated-request", spec_version="2026-07-28"), diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index b8393dd31..0b687a7b0 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -9,6 +9,7 @@ import pytest from inline_snapshot import snapshot from mcp_types import ( + INVALID_PARAMS, CallToolResult, ElicitCompleteNotification, ElicitCompleteNotificationParams, @@ -160,14 +161,13 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest @requirement("elicitation:form:not-supported") @requirement("elicitation:capability:server-respects-mode") -async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: - """Eliciting from a client that configured no elicitation callback fails with an error. - - The client's default callback answers with an Invalid request error, which the server-side - elicit call raises as an MCPError; the tool reports the code and message it caught. The spec - requires -32602 for an undeclared mode (see the divergence note on the requirement). The - request reaching the client also shows the server does not check the client's declared - elicitation capability before sending (see the divergence on `server-respects-mode`). +async def test_elicit_form_without_callback_fails_with_invalid_params(connect: Connect) -> None: + """Eliciting from a client that configured no elicitation callback fails with `INVALID_PARAMS`. + + Spec-mandated (MUST): an elicitation/create whose mode the client never declared is answered + with -32602 Invalid params, and a client with no callback declared no elicitation modes at all. + The request reaching the client at all also shows the server does not check the client's + declared elicitation capability before sending (see the divergence on `server-respects-mode`). """ async def list_tools( @@ -177,12 +177,15 @@ async def list_tools( tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] ) + errors: list[ErrorData] = [] + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask" try: await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + errors.append(exc.error) + return CallToolResult(content=[TextContent(text=exc.error.message)]) raise NotImplementedError # elicit_form cannot succeed without a client callback server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) @@ -190,7 +193,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: result = await client.call_tool("ask", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + assert result == snapshot(CallToolResult(content=[TextContent(text="Elicitation not supported")])) + (error,) = errors + assert error.code == INVALID_PARAMS @requirement("elicitation:url:action:accept-no-content") diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index bfd6cc90a..f207844f8 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -4,7 +4,7 @@ import mcp_types as types import pytest from inline_snapshot import snapshot -from mcp_types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from mcp_types import INTERNAL_ERROR, METHOD_NOT_FOUND, CallToolResult, ErrorData, ListRootsResult, Root, TextContent from pydantic import FileUrl from mcp import MCPError @@ -80,11 +80,12 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: @requirement("roots:list:not-supported") -async def test_list_roots_without_callback_is_error(connect: Connect) -> None: - """A roots/list request to a client with no roots callback fails with an error the handler can observe. +async def test_list_roots_without_callback_fails_with_method_not_found(connect: Connect) -> None: + """A roots/list request to a client with no roots callback fails with `METHOD_NOT_FOUND`. - The client's default callback answers with INVALID_REQUEST rather than leaving the server - hanging; the spec names -32601 for this case (see the divergence note on the requirement). + Spec-recommended (SHOULD): a client that does not support roots answers roots/list with + -32601 Method not found. The error reaches the requesting server handler as an `MCPError` + rather than leaving it hanging. """ async def list_tools( @@ -92,12 +93,15 @@ async def list_tools( ) -> types.ListToolsResult: return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + errors: list[ErrorData] = [] + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "show_roots" try: await ctx.session.list_roots() # pyright: ignore[reportDeprecated] except MCPError as exc: - return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + errors.append(exc.error) + return CallToolResult(content=[TextContent(text=exc.error.message)]) raise NotImplementedError # list_roots cannot succeed without a client callback server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) @@ -105,7 +109,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: result = await client.call_tool("show_roots", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + assert result == snapshot(CallToolResult(content=[TextContent(text="List roots not supported")])) + (error,) = errors + assert error.code == METHOD_NOT_FOUND @requirement("roots:list:client-error") From 24061fe98be9984e2ad7266e9d87d0b8a57d261b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 18:01:05 +0000 Subject: [PATCH 06/14] Reject non-HTTPS, non-loopback redirect URIs at client registration Two fixes to the optional bundled OAuth authorization server (the `auth_server_provider=` path). The registration endpoint accepted any well-formed URL as a `redirect_uris` entry: cleartext `http://` on a non-loopback host, `javascript:`, `data:`, and URIs carrying a fragment all registered successfully. The MCP authorization specification's Communication Security section requires every redirect URI to be either localhost or HTTPS, and OAuth 2.1 section 2.3 forbids a fragment component. Such an entry is now rejected with `400 invalid_client_metadata`. Loopback is exactly the three forms OAuth 2.1 section 8.4.2 names (`localhost`, `127.0.0.1`, `[::1]`), on any port; query strings remain permitted. This also rejects RFC 8252 private-use schemes such as `com.example.app:/callback`: MCP restricts redirect URIs to HTTPS or loopback, with no carve-out for native apps. The rule lives on the request model: `RegistrationRequest`, until now a dead alias of `OAuthClientMetadata`, becomes a real subclass with a `redirect_uris` field validator, so a forbidden URI fails parsing and takes the handler's existing `invalid_client_metadata` arm rather than needing a post-parse check. `OAuthClientMetadata` itself is unchanged: the client also serializes it when registering against third-party authorization servers whose redirect-URI policies the SDK does not own. The loopback host set moves to a single `LOOPBACK_HOSTS` constant in `mcp.server.auth.provider`, shared with `validate_issuer_url`, which previously inlined the same tuple. Separately, the token endpoint now answers an authorization-code exchange whose `redirect_uri` does not match the one used at `/authorize` with `error=invalid_grant` instead of `invalid_request`. RFC 6749 section 5.2 assigns this case to `invalid_grant` ("does not match the redirection URI used in the authorization request"), and the handler's other authorization-code failures already use it. The exchange was already rejected with HTTP 400; only the `error` field changes. Update the affected tests and the interaction-requirement entries, and add a migration note. Closes #2629 --- docs/migration.md | 8 ++ src/mcp/server/auth/handlers/register.py | 36 +++++-- src/mcp/server/auth/handlers/token.py | 2 +- src/mcp/server/auth/provider.py | 4 + src/mcp/server/auth/routes.py | 4 +- tests/interaction/_requirements.py | 17 +--- tests/interaction/auth/test_as_handlers.py | 98 ++++++++++++++++--- .../mcpserver/auth/test_auth_integration.py | 2 +- 8 files changed, 132 insertions(+), 39 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 079bee405..8c5bc1e0f 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1496,6 +1496,14 @@ Leaving `resource_server_url=None` continues to disable the check entirely (ther `RefreshToken` gains an optional `resource` field so an `OAuthAuthorizationServerProvider` can propagate the original grant's audience binding through `exchange_refresh_token`; without it a refreshed access token would carry no audience and be rejected. `BearerAuthBackend.__init__` gains a keyword-only `resource_server_url: AnyHttpUrl | None = None`, wired automatically from `AuthSettings.enforced_audience`; `None` (the default, and what the SDK passes when `verifier_validates_audience` is set) means no audience is enforced. +### Bundled authorization server: RFC-correct redirect-URI handling + +Two fixes to the optional bundled OAuth authorization server (the `auth_server_provider=` path). + +The token endpoint now answers an authorization-code exchange whose `redirect_uri` does not match the one used at `/authorize` with `error=invalid_grant` instead of `error=invalid_request`. RFC 6749 §5.2 assigns this case to `invalid_grant` ("does not match the redirection URI used in the authorization request"). The exchange was already rejected with HTTP 400; only the `error` field changes. + +The registration endpoint now rejects a `redirect_uris` entry that is neither HTTPS nor a loopback host (`localhost`, `127.0.0.1`, or `[::1]`) with `400 invalid_client_metadata`. Previously any well-formed URL — including cleartext `http://` on a non-loopback host, `javascript:`, and `data:` — was accepted and stored. The MCP authorization specification's Communication Security section requires every redirect URI to be either `localhost` or HTTPS; the SDK accepts the three loopback forms OAuth 2.1 §8.4.2 names. Local development against `http://localhost:*`, `http://127.0.0.1:*`, or `http://[::1]:*` is unaffected. Note that this also rejects RFC 8252 private-use URI schemes (such as `com.example.app:/callback`): MCP restricts redirect URIs to HTTPS or loopback, which is stricter than vanilla OAuth allows for native apps. A redirect URI carrying a fragment component is also rejected (OAuth 2.1 §2.3). Query strings remain permitted. + ### Lowlevel `Server`: `subscribe` capability now correctly reported Previously, the lowlevel `Server` hardcoded `subscribe=False` in resource capabilities even when a `subscribe_resource()` handler was registered. The `subscribe` capability is now dynamically set to `True` when an `on_subscribe_resource` handler is provided. Clients that previously didn't see `subscribe: true` in capabilities will now see it when a handler is registered, which may change client behavior. diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index e565b2738..59ea456ef 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -4,19 +4,43 @@ from typing import Any from uuid import uuid4 -from pydantic import BaseModel, ValidationError +from pydantic import AnyUrl, BaseModel, ValidationError, field_validator from starlette.requests import Request from starlette.responses import Response from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode +from mcp.server.auth.provider import ( + LOOPBACK_HOSTS, + OAuthAuthorizationServerProvider, + RegistrationError, + RegistrationErrorCode, +) from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import JWT_BEARER_GRANT_TYPE, OAuthClientInformationFull, OAuthClientMetadata -# this alias is a no-op; it's just to separate out the types exposed to the -# provider from what we use in the HTTP handler -RegistrationRequest = OAuthClientMetadata + +class RegistrationRequest(OAuthClientMetadata): + """The registration endpoint's inbound client metadata, with server-side redirect-URI policy. + + The MCP authorization spec requires every redirect URI to use HTTPS or target a loopback + host, and OAuth 2.1 section 2.3 forbids a fragment component, so a request carrying a URI + that violates either fails validation and never reaches the provider. The base + `OAuthClientMetadata` stays permissive: the client also serializes it when registering + against third-party authorization servers whose redirect-URI policies the SDK does not own. + """ + + @field_validator("redirect_uris", mode="after") + @classmethod + def _https_or_loopback_without_fragment(cls, v: list[AnyUrl] | None) -> list[AnyUrl] | None: + # None and an empty list both mean there is nothing to check. + for uri in v or []: + if uri.scheme != "https" and uri.host not in LOOPBACK_HOSTS: + raise ValueError(f"redirect_uri must use https or target a loopback host: {uri}") + # `is not None`, not truthiness: a bare `https://x/cb#` parses with fragment == "". + if uri.fragment is not None: + raise ValueError(f"redirect_uri must not include a fragment: {uri}") + return v class RegistrationErrorResponse(BaseModel): @@ -33,7 +57,7 @@ async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: body = await request.body() - client_metadata = OAuthClientMetadata.model_validate_json(body) + client_metadata = RegistrationRequest.model_validate_json(body) # Scope validation is handled below except ValidationError as validation_error: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0e644c378..07f559b46 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -177,7 +177,7 @@ async def handle(self, request: Request): if token_redirect_str != auth_redirect_str: return self.response( TokenErrorResponse( - error="invalid_request", + error="invalid_grant", error_description=("redirect_uri did not match the one used when creating auth code"), ) ) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 0fb52181a..e2fb722f7 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -6,6 +6,10 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken +# OAuth 2.1 §8.4.2: the loopback IP literal `127.0.0.1` or `[::1]`, or the hostname `localhost`. +# Spelled as pydantic's `AnyUrl.host` reports them (an IPv6 literal keeps its brackets). +LOOPBACK_HOSTS = frozenset({"localhost", "127.0.0.1", "[::1]"}) + class AuthorizationParams(BaseModel): state: str | None diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index fa88dddcf..d44c62ced 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -15,7 +15,7 @@ from mcp.server.auth.handlers.revoke import RevocationHandler from mcp.server.auth.handlers.token import TokenHandler from mcp.server.auth.middleware.client_auth import ClientAuthenticator -from mcp.server.auth.provider import OAuthAuthorizationServerProvider +from mcp.server.auth.provider import LOOPBACK_HOSTS, OAuthAuthorizationServerProvider from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import JWT_BEARER_GRANT_TYPE, OAuthMetadata, ProtectedResourceMetadata from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER @@ -32,7 +32,7 @@ def validate_issuer_url(url: AnyHttpUrl): """ # RFC 8414 requires HTTPS, but we allow loopback/localhost HTTP for testing - if url.scheme != "https" and url.host not in ("localhost", "127.0.0.1", "[::1]"): + if url.scheme != "https" and url.host not in LOOPBACK_HOSTS: raise ValueError("Issuer URL must be HTTPS") # No fragments or query parameters allowed diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 5e49b7ad2..31a2af8b7 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2689,28 +2689,15 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), note="Auth is enforced at the HTTP layer; the bundled AS is an ASGI app.", - divergence=Divergence( - note=( - "RFC 6749 §5.2 assigns redirect_uri mismatch at the token endpoint to invalid_grant; " - "the SDK's TokenHandler returns invalid_request (src/mcp/server/auth/handlers/token.py:157). " - "The rejection itself is the security-relevant property and is correct." - ), - ), ), "hosting:auth:as:redirect-uri-scheme": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#communication-security", behavior=( - "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback host." + "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback " + "host (`localhost`, `127.0.0.1`, `[::1]`), and that carry no fragment component." ), transports=("streamable-http",), note="Auth is enforced at the HTTP layer; the bundled AS is an ASGI app.", - divergence=Divergence( - note=( - "Not enforced: the registration handler models redirect_uris as AnyUrl with no scheme or " - "host check, so http://evil.example/callback is accepted and registered. The spec's " - "localhost-or-HTTPS rule is left to the provider implementation." - ), - ), ), "hosting:auth:as:token-cache-headers": Requirement( source="sdk", diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py index 5cb4e92d8..23fae5eca 100644 --- a/tests/interaction/auth/test_as_handlers.py +++ b/tests/interaction/auth/test_as_handlers.py @@ -170,16 +170,15 @@ async def test_reusing_an_authorization_code_is_rejected_with_invalid_grant( @requirement("hosting:auth:as:redirect-uri-binding") -async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_endpoint( +async def test_a_token_exchange_with_a_mismatched_redirect_uri_is_rejected_with_invalid_grant( as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], ) -> None: - """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected. + """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected with `invalid_grant`. This is the security-critical half of redirect-URI binding: a code intercepted via redirect substitution cannot be redeemed because the attacker cannot reproduce the original authorize - redirect URI at the token endpoint. RFC 6749 §5.2 specifies `invalid_grant` for this case; - the SDK returns `invalid_request` (see the divergence on the requirement). The rejection - itself is the security property and is correct. + redirect URI at the token endpoint. RFC 6749 §5.2 assigns the mismatch to `invalid_grant`, + matching the handler's other authorization-code failures. """ http, _ = as_app client_info, code, verifier = await _mint_code(http) @@ -192,7 +191,7 @@ async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_ assert response.status_code == 400 assert response.json() == snapshot( { - "error": "invalid_request", + "error": "invalid_grant", "error_description": "redirect_uri did not match the one used when creating auth code", } ) @@ -279,22 +278,93 @@ async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( @requirement("hosting:auth:as:redirect-uri-scheme") -async def test_a_non_loopback_http_redirect_uri_is_accepted_at_registration( - as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +@pytest.mark.parametrize( + "redirect_uri", + [ + "http://evil.example/callback", + "http://localhost.evil.example/callback", + "javascript:alert(1)", + "com.example.app:/oauth/cb", + ], +) +async def test_a_redirect_uri_that_is_neither_https_nor_loopback_is_rejected_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], redirect_uri: str +) -> None: + """A registration whose redirect URI is neither HTTPS nor a loopback host is rejected with 400. + + The spec requires every redirect URI to be either HTTPS or a loopback host; the + registration request model enforces this at parse time so the provider never sees the + client. Loopback is matched on the whole host (`localhost.evil.example` is not loopback), + and a scheme with no authority — `javascript:`, or an RFC 8252 private-use scheme such as + `com.example.app:` — fails the same check. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = [redirect_uri] + + response = await http.post("/register", json=body) + + assert response.status_code == 400 + error = response.json() + assert error["error"] == "invalid_client_metadata" + # Pydantic frames the validator's message as `redirect_uris: Value error, ` (third-party + # text), so assert only the SDK-authored sentence to pin which validation fired. + assert "redirect_uri must use https or target a loopback host" in error["error_description"] + assert provider.clients == {} + + +@requirement("hosting:auth:as:redirect-uri-scheme") +@pytest.mark.parametrize( + "redirect_uri", + [ + "https://app.example.com/callback", + "http://localhost:3030/callback", + "http://127.0.0.1:8000/callback", + "http://[::1]:8000/callback", + ], +) +async def test_an_https_or_loopback_redirect_uri_is_accepted_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], redirect_uri: str ) -> None: - """A registration carrying a non-HTTPS, non-loopback redirect URI is accepted. + """A registration whose redirect URI uses HTTPS or targets a loopback host is accepted and stored. - The spec requires every redirect URI to be either HTTPS or a loopback host; the bundled - registration handler does not enforce this and registers `http://evil.example/callback` - successfully. See the divergence on the requirement. + Loopback covers exactly the three forms OAuth 2.1 names: the hostname `localhost` and the + loopback IP literals `127.0.0.1` and `[::1]`, on any port, over plain HTTP. """ http, provider = as_app body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) - body["redirect_uris"] = ["http://evil.example/callback"] + body["redirect_uris"] = [redirect_uri] response = await http.post("/register", json=body) assert response.status_code == 201 info = OAuthClientInformationFull.model_validate_json(response.content) - assert [str(u) for u in (info.redirect_uris or [])] == ["http://evil.example/callback"] + assert [str(u) for u in (info.redirect_uris or [])] == [redirect_uri] assert info.client_id in provider.clients + + +@requirement("hosting:auth:as:redirect-uri-scheme") +@pytest.mark.parametrize( + "redirect_uri", ["https://app.example.com/callback#", "https://app.example.com/callback#nonce"] +) +async def test_a_redirect_uri_carrying_a_fragment_is_rejected_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], redirect_uri: str +) -> None: + """A registration whose redirect URI carries a fragment component is rejected with 400. + + OAuth 2.1 section 2.3: a redirect URI MUST NOT include a fragment component. The bare + trailing `#` parses to an empty-string fragment and is rejected the same as a named one. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = [redirect_uri] + + response = await http.post("/register", json=body) + + assert response.status_code == 400 + error = response.json() + assert error["error"] == "invalid_client_metadata" + # Pydantic frames the validator's message as `redirect_uris: Value error, ` (third-party + # text), so assert only the SDK-authored sentence to pin which validation fired. + assert "redirect_uri must not include a fragment" in error["error_description"] + assert provider.clients == {} diff --git a/tests/server/mcpserver/auth/test_auth_integration.py b/tests/server/mcpserver/auth/test_auth_integration.py index 35fec1c57..7cda36fed 100644 --- a/tests/server/mcpserver/auth/test_auth_integration.py +++ b/tests/server/mcpserver/auth/test_auth_integration.py @@ -501,7 +501,7 @@ async def test_token_redirect_uri_mismatch( ) assert response.status_code == 400 error_response = response.json() - assert error_response["error"] == "invalid_request" + assert error_response["error"] == "invalid_grant" assert "redirect_uri did not match" in error_response["error_description"] @pytest.mark.anyio From 3eb352c8b05e91182533ccf6859c321d1578990f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 18:13:18 +0000 Subject: [PATCH 07/14] Align OAuth client with spec on PKCE verification and scope selection Two changes to how the OAuth client uses discovered authorization-server metadata, both required by the MCP authorization specification. Verify PKCE support before the authorization-code grant. The spec's Authorization Code Protection section requires clients to verify PKCE support from the authorization server's metadata and to refuse to proceed when code_challenge_methods_supported is absent. The new validate_pkce_support() also refuses a method list that omits S256, since that is the only method this client sends. The check sits at the top of _perform_authorization_code_grant, so it covers both the initial 401 flow and the 403 insufficient_scope step-up, while grants that never issue an authorization code (client credentials, private key JWT) are unaffected. When no metadata document was discovered at all the flow proceeds as before: absence of a document is not evidence of non-support. Stop reading scopes_supported from authorization-server metadata when selecting a scope. The spec's scope-selection chain is the WWW-Authenticate scope parameter, then the protected-resource metadata's scopes_supported, otherwise omit the scope parameter. The SDK inserted an extra fallback to the authorization server's scopes_supported, which over-requests (an authorization server may serve many resource servers, so its list is a superset of any one resource's) and causes access_denied failures against servers that reject unknown scopes. With the fallback removed, clients that relied on it should pass an explicit scope on their OAuthClientMetadata. Closes #1307 --- docs/advanced/oauth-clients.md | 2 +- docs/migration.md | 12 ++ src/mcp/client/auth/oauth2.py | 10 +- src/mcp/client/auth/utils.py | 34 ++++- tests/client/test_auth.py | 137 +++++++++++++++++- tests/interaction/_requirements.py | 20 +-- .../interaction/auth/test_authorize_token.py | 92 ++++++++++-- 7 files changed, 266 insertions(+), 41 deletions(-) diff --git a/docs/advanced/oauth-clients.md b/docs/advanced/oauth-clients.md index 3407f0266..9c8b73df5 100644 --- a/docs/advanced/oauth-clients.md +++ b/docs/advanced/oauth-clients.md @@ -78,7 +78,7 @@ The first time `Client` sends a request, the server answers `401`. The provider 1. **Discovery.** It reads the `WWW-Authenticate` header, fetches the server's Protected Resource Metadata from `/.well-known/oauth-protected-resource`, learns which authorization server protects this resource, and fetches *that* server's metadata. 2. **Registration.** Nothing in storage? It registers you dynamically with your `OAuthClientMetadata` and stores the result. -3. **Authorization.** It generates the PKCE pair and a `state`, builds the authorization URL, awaits your `redirect_handler`, then awaits your `callback_handler` for the code. +3. **Authorization.** It checks that the discovered authorization-server metadata advertises `S256` PKCE support (and stops with `OAuthFlowError` if it does not), generates the PKCE pair and a `state`, builds the authorization URL, awaits your `redirect_handler`, then awaits your `callback_handler` for the code. 4. **Exchange.** It trades the code for an `OAuthToken`, stores it, and replays your original request with `Authorization: Bearer ...`. After that it is quiet. Tokens come out of storage, an expired access token is refreshed with the refresh token, and only when none of that works does it run the flow again. diff --git a/docs/migration.md b/docs/migration.md index 8c5bc1e0f..305fb31a8 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -121,6 +121,18 @@ async def callback_handler() -> AuthorizationCodeResult: Forward the `iss` query parameter from the redirect so the validation can run: omitting it makes the flow fail with `OAuthFlowError` against servers that advertise `authorization_response_iss_parameter_supported`, and silently skips the check for servers that send `iss` without advertising it. +### OAuth client refuses to authorize when AS metadata does not advertise S256 PKCE + +The OAuth client now verifies PKCE support before starting the authorization-code grant, as the MCP authorization specification's Authorization Code Protection section requires. When an authorization-server metadata document was discovered and its `code_challenge_methods_supported` is absent (which the spec defines as "the authorization server does not support PKCE") or does not list `S256` (the only method the SDK sends), the flow raises `OAuthFlowError` instead of redirecting to the authorization endpoint. The symptom is `OAuthFlowError: Authorization server metadata does not include code_challenge_methods_supported; PKCE support cannot be verified`. Previously the SDK never inspected the field and proceeded with an S256 challenge regardless. + +When no authorization-server metadata document could be discovered at all, the flow proceeds as before — absence of a document is not evidence of non-support. Grants that never issue an authorization code (`ClientCredentialsOAuthProvider`, `PrivateKeyJWTOAuthProvider`) are unaffected. There is no SDK-side opt-out: an authorization server that supports S256 but omits the field from its published metadata needs that metadata fixed (RFC 8414 §2 defines `code_challenge_methods_supported`). + +### OAuth client no longer reads `scopes_supported` from AS metadata to choose a scope + +The specification's scope-selection chain is two steps: the `scope` parameter from the `WWW-Authenticate` challenge, then `scopes_supported` from the Protected Resource Metadata document, *otherwise the `scope` parameter is omitted*. The SDK inserted an extra fallback between those two steps — the **authorization-server** metadata's `scopes_supported` — which over-requests (an authorization server may serve many resource servers, so its list is a superset of any one resource's) and caused real `access_denied` failures ([#1307](https://github.com/modelcontextprotocol/python-sdk/issues/1307)). That fallback is removed: when neither the challenge nor the PRM names scopes, the client now omits the `scope` parameter and lets the authorization server apply its defaults. + +This also affects the SEP-2207 `offline_access` augmentation, which only fires once a base scope was selected: if the authorization server's `scopes_supported` was your only scope source, the client now sends no `scope` at all (not even `offline_access`) and the authorization server's defaults decide whether a refresh token is issued. In either case, if you relied on the removed fallback, pass an explicit `scope` on the `OAuthClientMetadata` you give to `OAuthClientProvider`. + ### `get_session_id` callback removed from `streamable_http_client` The `get_session_id` callback (third element of the returned tuple) has been removed from `streamable_http_client`. The function now returns a 2-tuple `(read_stream, write_stream)` instead of a 3-tuple. diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 711848d72..f929decc4 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -21,6 +21,7 @@ from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError from mcp.client.auth.utils import ( + CODE_CHALLENGE_METHOD, build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, create_client_info_from_metadata_url, @@ -40,6 +41,7 @@ union_scopes, validate_authorization_response_iss, validate_metadata_issuer, + validate_pkce_support, ) from mcp.shared.auth import ( AuthorizationCodeResult, @@ -323,6 +325,12 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if not self.context.callback_handler: raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover + # Authorization Code Protection: a discovered metadata document that does not advertise + # S256 PKCE support must stop the flow before any authorize redirect is built. When no + # document was discovered at all there is nothing to verify against, so the flow proceeds. + if self.context.oauth_metadata is not None: + validate_pkce_support(self.context.oauth_metadata) + if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: @@ -342,7 +350,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), "state": state, "code_challenge": pkce_params.code_challenge, - "code_challenge_method": "S256", + "code_challenge_method": CODE_CHALLENGE_METHOD, } # Only include resource param if conditions are met diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index d6b05e066..848ef7360 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -1,4 +1,5 @@ import re +from typing import Final from urllib.parse import urljoin, urlparse from httpx import Request, Response @@ -15,6 +16,10 @@ ) from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +# The only PKCE code challenge method this client sends (OAuth 2.1 section 4.1.1 mandates S256). +# `validate_pkce_support` checks the authorization server's metadata against the same name. +CODE_CHALLENGE_METHOD: Final = "S256" + def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: """Extract field from WWW-Authenticate header. @@ -107,14 +112,11 @@ def get_client_metadata_scopes( # MCP spec scope selection priority: # 1. WWW-Authenticate header scope # 2. PRM scopes_supported - # 3. AS scopes_supported (SDK fallback) - # 4. Omit scope parameter + # 3. Omit scope parameter if www_authenticate_scope is not None: selected_scope = www_authenticate_scope elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: selected_scope = " ".join(protected_resource_metadata.scopes_supported) - elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: - selected_scope = " ".join(authorization_server_metadata.scopes_supported) # SEP-2207: append offline_access when the AS supports it and the client can use refresh tokens if ( @@ -272,6 +274,30 @@ def validate_metadata_issuer(oauth_metadata: OAuthMetadata, expected_issuer: str ) +def validate_pkce_support(oauth_metadata: OAuthMetadata) -> None: + """Verify that authorization-server metadata advertises support for S256 PKCE. + + Per the MCP authorization specification's Authorization Code Protection requirements, a + client must verify PKCE support from the authorization server's metadata before proceeding + with authorization: an absent `code_challenge_methods_supported` means the server does not + support PKCE. The SDK only ever sends the mandatory `S256` method, so a method list that + omits `S256` is equally unusable. + + Raises: + OAuthFlowError: If `code_challenge_methods_supported` is absent or does not list `S256`. + """ + methods = oauth_metadata.code_challenge_methods_supported + if methods is None: + raise OAuthFlowError( + "Authorization server metadata does not include code_challenge_methods_supported; " + "PKCE support cannot be verified" + ) + if CODE_CHALLENGE_METHOD not in methods: + raise OAuthFlowError( + f"Authorization server does not support the {CODE_CHALLENGE_METHOD} PKCE code challenge method: {methods}" + ) + + def create_oauth_metadata_request(url: str) -> Request: return Request("GET", url, headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_PROTOCOL_VERSION}) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 1ec38ccf6..322d02d22 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -30,6 +30,7 @@ union_scopes, validate_authorization_response_iss, validate_metadata_issuer, + validate_pkce_support, ) from mcp.server.auth.routes import build_metadata from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions @@ -2447,7 +2448,11 @@ def test_offline_access_not_duplicated_when_already_present(self): assert scopes == "read offline_access write" def test_offline_access_not_added_when_no_scopes_selected(self): - """offline_access is not added when no base scopes are available (None).""" + """offline_access is not appended when the challenge and PRM yield no base scope. + + The AS metadata advertising offline_access is not a scope source, so there is no + selected scope to augment and the result stays None. + """ asm = self._make_as_metadata(scopes_supported=["offline_access"]) scopes = get_client_metadata_scopes( @@ -2456,9 +2461,7 @@ def test_offline_access_not_added_when_no_scopes_selected(self): authorization_server_metadata=asm, client_grant_types=["authorization_code", "refresh_token"], ) - # When AS scopes are the only source and include offline_access, - # the base scope is "offline_access" and no duplication happens - assert scopes == "offline_access" + assert scopes is None def test_offline_access_not_added_when_as_scopes_supported_is_none(self): """offline_access is not added when AS scopes_supported is None.""" @@ -2571,6 +2574,7 @@ async def callback_handler() -> AuthorizationCodeResult: b'{"issuer": "https://auth.example.com",' b' "authorization_endpoint": "https://auth.example.com/authorize",' b' "token_endpoint": "https://auth.example.com/token",' + b' "code_challenge_methods_supported": ["S256"],' b' "scopes_supported": ["read", "write", "offline_access"]}' ), request=oauth_request, @@ -2679,6 +2683,7 @@ async def callback_handler() -> AuthorizationCodeResult: b'{"issuer": "https://auth.example.com",' b' "authorization_endpoint": "https://auth.example.com/authorize",' b' "token_endpoint": "https://auth.example.com/token",' + b' "code_challenge_methods_supported": ["S256"],' b' "scopes_supported": ["read", "write"]}' ), request=oauth_request, @@ -2780,6 +2785,48 @@ def test_validate_metadata_issuer_rejects_mismatch(): validate_metadata_issuer(_issuer_metadata(issuer="https://attacker.example.com"), _ISSUER) +def test_as_metadata_scopes_supported_is_never_used_as_a_scope_source(): + """When neither the challenge nor the PRM names scopes, scope is omitted even though the AS advertises some. + + The spec's scope-selection chain is WWW-Authenticate scope, then PRM scopes_supported, then + omit; authorization-server metadata is not a step in it. + """ + asm = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + scopes_supported=["as-only:read", "as-only:write"], + ) + assert get_client_metadata_scopes(None, None, asm) is None + + +def test_pkce_validation_passes_when_as_metadata_lists_s256(): + """Metadata whose code_challenge_methods_supported includes S256 is accepted, whatever else it lists.""" + validate_pkce_support(_issuer_metadata().model_copy(update={"code_challenge_methods_supported": ["plain", "S256"]})) + + +def test_pkce_validation_refuses_when_code_challenge_methods_supported_is_absent(): + """An AS metadata document without code_challenge_methods_supported is refused: PKCE support is unverified.""" + metadata = _issuer_metadata() + assert metadata.code_challenge_methods_supported is None + with pytest.raises(OAuthFlowError) as exc_info: + validate_pkce_support(metadata) + assert str(exc_info.value) == snapshot( + "Authorization server metadata does not include code_challenge_methods_supported; " + "PKCE support cannot be verified" + ) + + +def test_pkce_validation_refuses_when_s256_is_not_among_the_advertised_methods(): + """A code_challenge_methods_supported list without S256 is refused: the SDK only sends S256.""" + metadata = _issuer_metadata().model_copy(update={"code_challenge_methods_supported": ["plain"]}) + with pytest.raises(OAuthFlowError) as exc_info: + validate_pkce_support(metadata) + assert str(exc_info.value) == snapshot( + "Authorization server does not support the S256 PKCE code challenge method: ['plain']" + ) + + @pytest.mark.parametrize( ("previous", "new", "expected"), [ @@ -2991,7 +3038,8 @@ async def test_issuer_binding_re_evaluated_after_asm_when_prm_discovery_failed( content=( b'{"issuer": "https://new-as.example.com", ' b'"authorization_endpoint": "https://new-as.example.com/authorize", ' - b'"token_endpoint": "https://new-as.example.com/token"}' + b'"token_endpoint": "https://new-as.example.com/token", ' + b'"code_challenge_methods_supported": ["S256"]}' ), ) ], @@ -3138,7 +3186,8 @@ async def echo_callback() -> AuthorizationCodeResult: content=( b'{"issuer": "https://api.example.com", ' b'"authorization_endpoint": "https://api.example.com/authorize", ' - b'"token_endpoint": "https://api.example.com/token"}' + b'"token_endpoint": "https://api.example.com/token", ' + b'"code_challenge_methods_supported": ["S256"]}' ), request=asm_req, ) @@ -3169,3 +3218,79 @@ async def echo_callback() -> AuthorizationCodeResult: await auth_flow.asend(httpx.Response(200, request=final_req)) except StopAsyncIteration: pass + + +@pytest.mark.anyio +async def test_pkce_is_still_sent_when_no_authorization_server_metadata_document_is_discovered( + oauth_provider: OAuthClientProvider, +): + """With no discoverable AS metadata, the client still proceeds and sends an S256 PKCE challenge. + + The refuse-if-unsupported gate is conditioned on a discovered authorization-server metadata + *document*: failing to obtain one is not evidence of non-support, and the SDK deliberately + keeps a legacy no-metadata fallback path (the `client-auth:prm-discovery:no-prm-fallback` + requirement), so the flow falls back to the resource origin's `/authorize` with PKCE intact. + This has to drive the httpx `async_auth_flow` generator directly: the in-process interaction + harness cannot express a server with no metadata endpoint, because its real authorize route + always returns an RFC 9207 `iss` that the client rejects when it has no metadata issuer to + compare against. + + Steps: + 1. 401 with no challenge -> path-based then root PRM discovery, both 404. + 2. One root AS-metadata probe (the only URL built when no AS is known), 404 -> + `oauth_metadata` stays None. + 3. DCR against the resource origin's `/register` fallback. + 4. The authorize redirect is built and carries `code_challenge_method=S256`. + """ + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + captured_url: str | None = None + + async def capture_redirect(url: str) -> None: + nonlocal captured_url + captured_url = url + + async def echo_callback() -> AuthorizationCodeResult: + assert captured_url is not None + params = parse_qs(urlparse(captured_url).query) + return AuthorizationCodeResult(code="auth_code", state=params["state"][0]) + + oauth_provider.context.redirect_handler = capture_redirect + oauth_provider.context.callback_handler = echo_callback + + auth_flow = oauth_provider.async_auth_flow(httpx.Request("GET", "https://api.example.com/v1/mcp")) + request = await auth_flow.__anext__() + + prm_req = await auth_flow.asend(httpx.Response(401, request=request)) + assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + prm_req = await auth_flow.asend(httpx.Response(404, request=prm_req)) + assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + asm_req = await auth_flow.asend(httpx.Response(404, request=prm_req)) + assert str(asm_req.url) == "https://api.example.com/.well-known/oauth-authorization-server" + dcr_req = await auth_flow.asend(httpx.Response(404, request=asm_req)) + assert oauth_provider.context.oauth_metadata is None + assert str(dcr_req.url) == "https://api.example.com/register" + + token_req = await auth_flow.asend( + httpx.Response( + 201, + json={"client_id": "registered", "redirect_uris": ["http://localhost:3030/callback"]}, + request=dcr_req, + ) + ) + + assert captured_url is not None + params = parse_qs(urlparse(captured_url).query) + assert params["code_challenge_method"] == ["S256"] + assert params["code_challenge"] != [""] + + final_req = await auth_flow.asend( + httpx.Response(200, json={"access_token": "t", "token_type": "Bearer", "expires_in": 3600}, request=token_req) + ) + try: + await auth_flow.asend(httpx.Response(200, request=final_req)) + except StopAsyncIteration: + pass diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 31a2af8b7..5d02d7d40 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -3434,15 +3434,14 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", behavior=( "The client refuses to proceed when the authorization server's metadata does not include " - "code_challenge_methods_supported, since PKCE support cannot be verified." + "code_challenge_methods_supported, or includes it without S256 (the only method the client " + "sends), since a compliant PKCE flow cannot be completed." ), transports=("streamable-http",), - note="OAuth is HTTP-only.", - divergence=Divergence( - note=( - "The client never inspects code_challenge_methods_supported and proceeds with PKCE S256 " - "regardless; the spec MUST is not enforced." - ), + note=( + "OAuth is HTTP-only. The check fires only when an authorization-server metadata document was " + "discovered; the legacy no-metadata fallback (client-auth:prm-discovery:no-prm-fallback) " + "deliberately proceeds, since the absence of a document is not evidence of non-support." ), ), "client-auth:pkce:s256": Requirement( @@ -3523,13 +3522,6 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), note="OAuth is HTTP-only.", - divergence=Divergence( - note=( - "The SDK inserts an extra fallback step between PRM and omit: if the authorization " - "server metadata advertises scopes_supported, that list is used (client/auth/utils.py). " - "This is beyond the spec's two-step chain." - ), - ), ), "client-auth:state:verify": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py index 08c4c9142..fd645a37e 100644 --- a/tests/interaction/auth/test_authorize_token.py +++ b/tests/interaction/auth/test_authorize_token.py @@ -28,7 +28,7 @@ from mcp.client.auth import OAuthFlowError from mcp.server import Server, ServerRequestContext -from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata, ProtectedResourceMetadata from tests.interaction._connect import BASE_URL from tests.interaction._requirements import requirement from tests.interaction.auth._harness import ( @@ -359,13 +359,63 @@ async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_me assert json.loads(register.content)["scope"] == "from-header" +@requirement("client-auth:scope-selection:priority") +async def test_scope_is_omitted_when_neither_the_challenge_nor_prm_supply_scopes() -> None: + """When the 401 challenge carries no `scope=` and the PRM has no `scopes_supported`, no scope is requested. + + The served AS metadata advertises `scopes_supported`, which the spec's two-step chain never + consults, so the assertion fails if the SDK falls back to it instead of omitting the + parameter. The challenge has to be hand-supplied via `first_challenge_shim` because the + SDK's own bearer middleware always emits `scope=` from its configured `required_scopes`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + challenge = f'Bearer resource_metadata="{BASE_URL}{PRM_PATH}"' + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + asm = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp", "as-advertised"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + serve = { + PRM_PATH: prm.model_dump_json(exclude_none=True).encode(), + ASM_PATH: asm.model_dump_json(exclude_none=True).encode(), + } + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + verify_tokens=False, + app_shim=lambda app: first_challenge_shim(challenge)(shimmed_app(app, serve=serve)), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert "scope" not in authorize_params(headless.authorize_url) + + [register] = find(recorded, "POST", "/register") + assert "scope" not in json.loads(register.content) + + @requirement("client-auth:pkce:refuse-if-unsupported") -async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: - """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE. +async def test_the_flow_aborts_before_any_authorize_redirect_when_as_metadata_omits_code_challenge_methods() -> None: + """AS metadata without `code_challenge_methods_supported` aborts the flow before any authorize redirect. - The spec says the client MUST refuse to proceed in this case; the SDK proceeds and the flow - completes. See the divergence on the requirement. + Authorization Code Protection: the client must verify PKCE support from discovered + authorization-server metadata and refuse to proceed with authorization when the field is + absent. The registration step is not authorization, so the dynamic-registration request + still goes out; what never happens is an authorize redirect or a token request. """ + recorded, on_request = record_requests() override = OAuthMetadata( issuer=AnyHttpUrl(f"{BASE_URL}/"), authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), @@ -379,18 +429,30 @@ async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_ provider = InMemoryAuthorizationServerProvider() server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth() with anyio.fail_after(5): - async with connect_with_oauth( - server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve) - ) as (client, headless): - result = await client.list_tools() - - assert headless.authorize_url is not None - params = authorize_params(headless.authorize_url) - assert params["code_challenge_method"] == "S256" - assert params["code_challenge"] != "" - assert result.tools[0].name == "echo" + # Exact-message equality: `OAuthFlowError` has several pre-redirect raise sites on this + # path (issuer mismatch, missing client info, ...); only this message is the PKCE refusal. + with pytest.RaisesGroup( + pytest.RaisesExc( + OAuthFlowError, + check=lambda e: str(e) + == "Authorization server metadata does not include code_challenge_methods_supported; " + "PKCE support cannot be verified", + ), + flatten_subgroups=True, + ): + await connect_with_oauth( + server, + provider=provider, + headless=headless, + app_shim=lambda app: shimmed_app(app, serve=serve), + on_request=on_request, + ).__aenter__() + + assert headless.authorize_url is None + assert find(recorded, "POST", "/token") == [] @requirement("client-auth:authorize:error-surfaces") From 1e2bd9befcaf57ed177fc4bffa44f5e2634d83dc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 18:18:30 +0000 Subject: [PATCH 08/14] Reject a second initialize on an already-initialized session A server that had already completed the initialize handshake on a connection answered a repeated initialize request on that same connection as a fresh handshake, silently overwriting the session's recorded client_params and negotiated protocol version. A check_capability call made after that point then answered against the second client's declared capabilities. The handshake now commits at most once per connection: a repeated initialize is answered with JSON-RPC error -32600 (INVALID_REQUEST, "Session already initialized") and the established session keeps serving. The check lives at the runner's request-dispatch boundary, right next to its mirror (the request-before-initialize gate), so every transport gets it without any transport-level body sniffing. The discriminator is client_params rather than initialize_accepted: the legacy stateless path builds a per-request connection that is born past the gate but has no peer info yet, and its one initialize must still be accepted. No compliant client is affected. The spec makes initialization the first interaction of a session, and ClientSession.initialize() is already idempotent (a repeat call returns the first result without sending anything). This only applies to the legacy (2025-11-25 and earlier) handshake; the 2026-07-28 protocol removes initialize entirely. Closes #2605 --- docs/migration.md | 17 +++++++++ src/mcp/server/runner.py | 10 ++++++ tests/interaction/_requirements.py | 6 ---- .../transports/test_hosting_session.py | 36 ++++++++++++++----- tests/server/test_runner.py | 25 +++++++++++++ 5 files changed, 79 insertions(+), 15 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 305fb31a8..751061e99 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -49,6 +49,23 @@ per the spec's SHOULD. v1 answered the cancelled request with an error with anyio cancellation when its scope is cancelled, so no reply is needed to unblock it. +### A second `initialize` on an already-initialized session is rejected + +A server that has already completed the `initialize` handshake on a connection now answers a +repeated `initialize` request on that same connection with JSON-RPC error `-32600` +(`INVALID_REQUEST`, message `"Session already initialized"`) instead of re-running the handshake. +Previously the repeat was answered as a fresh handshake and silently overwrote the session's +recorded `client_params` and negotiated protocol version, so a `check_capability` call made after +that point answered against the second client's declared capabilities +([#2605](https://github.com/modelcontextprotocol/python-sdk/issues/2605)). + +No compliant client is affected: the spec makes initialization the first interaction of a session, +and the SDK's own `ClientSession.initialize()` is idempotent (a repeat call returns the first +result without sending anything). A peer that needs a fresh handshake should open a new +connection — on streamable HTTP, a `POST` without an `Mcp-Session-Id` header. This applies only to +the legacy (2025-11-25 and earlier) handshake; the 2026-07-28 protocol removes `initialize` +entirely. + ### `streamablehttp_client` removed The deprecated `streamablehttp_client` function has been removed. Use `streamable_http_client` instead. diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 0b57c3e5c..1149bf2be 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -26,6 +26,7 @@ CLIENT_INFO_META_KEY, INTERNAL_ERROR, INVALID_PARAMS, + INVALID_REQUEST, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, ErrorData, @@ -177,6 +178,15 @@ async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult: # the gate become a per-version legacy path then. Initialize runs inline # (read loop parked), so awaiting the peer anywhere on this path deadlocks. if method == "initialize": + # The handshake commits at most once per connection: a repeat would + # silently overwrite the negotiated `client_params` and + # `protocol_version` at the commit at the bottom of `_on_request` + # (#2605). The discriminator is `client_params`, not + # `initialize_accepted`: a born-ready `from_envelope` connection + # (the legacy stateless path) is already past the gate but has no + # peer info yet, and its one `initialize` must still be accepted. + if self.connection.client_params is not None: + raise MCPError(code=INVALID_REQUEST, message="Session already initialized") return self._serialize(method, version, self._handle_initialize(params)) # Methods without a handler are METHOD_NOT_FOUND regardless of # initialization state: JSON-RPC 2.0 reserves -32601 for "not diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 5d02d7d40..2b6838681 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2527,12 +2527,6 @@ def __post_init__(self) -> None: source="sdk", behavior="A second initialize on an already-initialized session transport is rejected.", transports=("streamable-http",), - divergence=Divergence( - note=( - "The transport forwards a second initialize carrying the existing session ID to the running " - "server, which answers it as a fresh handshake; nothing rejects re-initialization." - ), - ), removed_in="2026-07-28", note=( "removed in 2026-07-28 (SEP-2567); per-session initialize guard retired with Mcp-Session-Id, no " diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py index 1b41a4bee..aa20037e4 100644 --- a/tests/interaction/transports/test_hosting_session.py +++ b/tests/interaction/transports/test_hosting_session.py @@ -12,7 +12,7 @@ import httpx import pytest from inline_snapshot import snapshot -from mcp_types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from mcp_types import INVALID_REQUEST, JSONRPCError, JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool from mcp.server import Server, ServerRequestContext from tests.interaction._connect import ( @@ -142,20 +142,38 @@ async def test_terminating_one_session_leaves_others_working() -> None: @requirement("hosting:session:reinitialize") -async def test_second_initialize_on_an_existing_session_is_accepted() -> None: - """A second initialize POST carrying an existing session ID is processed rather than rejected. - - See the divergence on the requirement: the entry expects a rejection, but the SDK forwards the - second initialize to the running server, which answers it as a fresh handshake. +async def test_second_initialize_on_an_existing_session_is_rejected_and_the_session_survives() -> None: + """A second initialize POST on an existing session is answered with INVALID_REQUEST, + and the established session keeps serving. + + SDK-defined, no spec MUST: closes the gap where the server answered a repeated initialize as a + fresh handshake and silently overwrote the session's committed client_params (python-sdk#2605; + the kernel-level proof that the committed state survives is in tests/server/test_runner.py). + Raw HTTP because the SDK Client performs exactly one handshake and cannot be made to repeat it. + + Steps: + 1. A real handshake establishes the session. + 2. A second initialize on that session ID routes to the existing transport (the instance + count stays 1) and is rejected at the JSON-RPC layer. The HTTP status is still 200: in + legacy SSE mode the response stream is committed before dispatch, so the rejection + arrives as a JSONRPCError event on it. + 3. tools/list on the same session still answers, proving the rejection tore nothing down. """ async with mounted_app(_server()) as (http, manager): session_id = await initialize_via_http(http) + response, messages = await post_jsonrpc(http, initialize_body(request_id=2), session_id=session_id) assert len(manager._server_instances) == 1 + assert response.status_code == snapshot(200) + assert isinstance(messages[0], JSONRPCError) + assert messages[0].id == 2 + assert messages[0].error.code == INVALID_REQUEST + assert messages[0].error.message == snapshot("Session already initialized") - assert response.status_code == snapshot(200) - assert isinstance(messages[0], JSONRPCResponse) - assert messages[0].id == 2 + _, after = await post_jsonrpc(http, {"jsonrpc": "2.0", "id": 3, "method": "tools/list"}, session_id=session_id) + + assert isinstance(after[0], JSONRPCResponse) + assert after[0].result["tools"][0]["name"] == "noop" @requirement("hosting:stateless:no-session-id") diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index 52187f262..777734e2b 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -19,6 +19,7 @@ from mcp_types import ( INTERNAL_ERROR, INVALID_PARAMS, + INVALID_REQUEST, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, ClientCapabilities, @@ -175,6 +176,30 @@ async def test_runner_initialize_opens_gate_but_event_fires_only_after_initializ await runner.connection.initialized.wait() +@pytest.mark.anyio +async def test_runner_rejects_a_second_initialize_and_preserves_the_committed_handshake(server: SrvT): + """A second `initialize` on an already-initialized connection is rejected with + INVALID_REQUEST and the first handshake's committed `client_params` and + `protocol_version` survive unchanged. + + SDK-defined (no spec MUST mandates the rejection). Regression lock for python-sdk#2605, + where the repeat was answered as a fresh handshake and silently overwrote both.""" + impostor = InitializeRequestParams( + protocol_version=OLDEST_SUPPORTED_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="impostor", version="9.9"), + ).model_dump(by_alias=True, exclude_none=True) + # `connected_runner(server)` already performed the real initialize (client name + # "test-client", protocol version LATEST_HANDSHAKE_VERSION) before yielding. + async with connected_runner(server) as (client, runner): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("initialize", impostor) + assert exc.value.error == ErrorData(code=INVALID_REQUEST, message="Session already initialized") + assert runner.connection.client_params is not None + assert runner.connection.client_params.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_HANDSHAKE_VERSION + + @pytest.mark.anyio async def test_runner_gates_requests_before_initialize(server: SrvT): async with connected_runner(server, initialized=False) as (client, _): From 8f60ebe5359fd7649f854930145b653daaccf41a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 18:20:48 +0000 Subject: [PATCH 09/14] Make the prompt test's opaque coverage workaround explicit The never-called prompt body in the wrong-type-argument test proves it never ran by appending to a closure-captured list, but a plain append on an unreachable line would fail the 100% coverage gate. The append therefore rides on a `raise NotImplementedError` line, which coverage's `exclude_also` strips. That shape read as an accident; add a comment spelling out why it is written that way. --- tests/interaction/mcpserver/test_prompts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 155236232..bccace724 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -130,6 +130,8 @@ async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_func @mcp.prompt() def repeat(phrase: str, count: int) -> str: """A registered prompt; type validation rejects the call before the function runs.""" + # Never runs: validate_call rejects the bad count first; `called == []` below is the + # proof. The append rides on the coverage-excluded `raise NotImplementedError` line. raise NotImplementedError(called.append((phrase, count))) async with connect(mcp) as client: From 28f500c71ab92fb9b20fd4c2fc9b5ab152de9d1a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 22:38:25 +0000 Subject: [PATCH 10/14] Add an opt-in strict_capabilities flag to Client and ClientSession Client(..., strict_capabilities=True) rejects, before any request reaches the transport, a call to a method whose required server capability the connected server did not advertise -- for example list_resources() against a server that only advertised tools, or subscribe_resource() when the server's resources capability does not set subscribe. The rejection is an MCPError with code -32601 (METHOD_NOT_FOUND) and data set to the method, the same shape a compliant server returns for an unadvertised capability, so opting in changes where the rejection happens, not what callers catch. The default is False and unchanged: every request is sent and the server's answer is surfaced. This mirrors the TypeScript SDK's enforceStrictCapabilities option (also default-off). The same keyword-only parameter exists on ClientSession for low-level users; Client forwards it. The method-to-capability table lives in mcp_types.methods.SERVER_CAPABILITY_REQUIREMENTS next to the other per-method maps, with missing_server_capability() as its only evaluator, so the check in ClientSession.send_request is a single data-driven gate rather than a per-method condition, and the relationship is the same at every protocol version. Because the gate reads server_capabilities, a bare version pin (mode="2026-07-28" with no prior_discover=) would reject every gated method; that combination is refused at Client construction with a ValueError that names the fix. The interaction-requirements entry for the lifecycle capability rule is no longer marked untested: the new tests pin both the opt-in pre-wire rejection and the default send-and-surface behaviour. --- docs/client/index.md | 4 +- docs/migration.md | 20 +++ src/mcp-types/mcp_types/methods.py | 56 +++++++- src/mcp/client/client.py | 19 +++ src/mcp/client/session.py | 15 +- tests/client/test_client.py | 8 ++ tests/interaction/_requirements.py | 13 +- .../lowlevel/test_client_connect.py | 129 ++++++++++++++++++ tests/types/test_methods.py | 43 ++++++ 9 files changed, 297 insertions(+), 10 deletions(-) diff --git a/docs/client/index.md b/docs/client/index.md index 38efa72b6..2e74e900d 100644 --- a/docs/client/index.md +++ b/docs/client/index.md @@ -31,7 +31,7 @@ Everything else on this page is identical across all three. Headers, subprocesse Four read-only properties, populated the moment you enter the block: * `client.server_info`: the server's identity. `server_info.name` here is `"Bookshop"`, `server_info.version` is whatever the server reports. -* `client.server_capabilities`: what the server can do (`tools`, `resources`, `prompts`, `completions`, ...). A capability the server doesn't have is `None`. +* `client.server_capabilities`: what the server can do (`tools`, `resources`, `prompts`, `completions`, ...). A capability the server doesn't have is `None`. Pass `Client(..., strict_capabilities=True)` and the client uses this to refuse, client-side, a call whose capability the server didn't advertise: it raises `MCPError` with code `-32601` without sending anything. By default the request is sent and the server's answer comes back. * `client.protocol_version`: the protocol version the two sides agreed on. Here it is `"2026-07-28"`. * `client.instructions`: the server's `instructions=` string, or `None` if it didn't set one. @@ -145,7 +145,7 @@ The resource verbs come in pairs: two ways to list, one way to read. `read_resource` returns `contents`, a list of `TextResourceContents` or `BlobResourceContents`. Same idea as tool content: narrow with `isinstance`, then read `.text` (or `.blob`). -A client can also **subscribe** to a resource and be told when it changes: `subscribe_resource(uri)` and `unsubscribe_resource(uri)`, same shape as everything else here. `MCPServer` doesn't implement that half. It says so up front (`server_capabilities.resources.subscribe` is `False`) and answers the request with an `MCPError`: `-32601`, *Method not found*. A server that does support subscriptions is built on the low-level `Server` (**The low-level Server**). +A client can also **subscribe** to a resource and be told when it changes: `subscribe_resource(uri)` and `unsubscribe_resource(uri)`, same shape as everything else here. `MCPServer` doesn't implement that half. It says so up front (`server_capabilities.resources.subscribe` is `False`) and answers the request with an `MCPError`: `-32601`, *Method not found*. With `strict_capabilities=True` you get the same `-32601` without the round trip: the client sees `server_capabilities.resources.subscribe` is falsy and never sends the request. A server that does support subscriptions is built on the low-level `Server` (**The low-level Server**). ## Prompts diff --git a/docs/migration.md b/docs/migration.md index 751061e99..392150df1 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -439,6 +439,26 @@ For an in-process `Client(server)` (where `server` is a `Server` or `MCPServer` `Client.send_ping()` is deprecated (ping is removed in 2026-07-28); pin `mode='legacy'` if you need it. +### `Client` gains an opt-in `strict_capabilities` flag + +`Client(..., strict_capabilities=True)` makes the client reject, before any request reaches +the transport, a call to a method whose required server capability the connected server did +not advertise -- for example `list_resources()` against a server that only advertised +`tools`, or `subscribe_resource()` against a server whose `resources` capability does not set +`subscribe`. The rejection is an `MCPError` with code `-32601` (`METHOD_NOT_FOUND`), the same +code a compliant server returns for an unadvertised capability, so existing error handling is +unaffected. + +The default is `False` and is unchanged from v1: every request is sent and the server's +answer is surfaced. This mirrors the TypeScript SDK's `enforceStrictCapabilities` option, and +the same keyword-only parameter exists directly on `ClientSession(..., strict_capabilities=)` +for low-level users -- `Client` just forwards it. The check reads +`client.server_capabilities`, so a bare version pin (`mode="2026-07-28"` with no +`prior_discover=`) -- where the client never asks the server what it supports and so every +capability-gated method would be rejected -- is refused at construction with a `ValueError` +that names the fix: supply `prior_discover=` or use `mode="auto"`. Which method needs which +capability is exported as `mcp_types.methods.SERVER_CAPABILITY_REQUIREMENTS`. + ### Unhandled `elicitation/create` returns `-32602`; unhandled `roots/list` returns `-32601` When a server sends `elicitation/create` to a client that registered no `elicitation_callback`, or `roots/list` to a client that registered no `list_roots_callback`, the SDK still answers on the client's behalf with a JSON-RPC error. In v1 both answers used code `-32600` (`INVALID_REQUEST`). They now use the code the spec assigns to each case: `elicitation/create` is answered with `-32602` (`INVALID_PARAMS`), per the [elicitation error-handling section](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation#error-handling) (a client with no callback declared no elicitation modes, and a request for an undeclared mode MUST be answered with `-32602`), and `roots/list` is answered with `-32601` (`METHOD_NOT_FOUND`), per the [roots error-handling section](https://modelcontextprotocol.io/specification/2025-11-25/client/roots#error-handling). The error messages (`Elicitation not supported`, `List roots not supported`) are unchanged, and `sampling/createMessage` without a `sampling_callback` still answers `-32600` — the spec assigns no code to that case. diff --git a/src/mcp-types/mcp_types/methods.py b/src/mcp-types/mcp_types/methods.py index 824dcfdfe..30b779692 100644 --- a/src/mcp-types/mcp_types/methods.py +++ b/src/mcp-types/mcp_types/methods.py @@ -6,7 +6,10 @@ Surface maps key `(method, version)` to per-version wire types (key absence is the version gate; shape validation is per schema era, i.e. 2025-11-25 for every pre-2026 version and 2026-07-28 for 2026). Monolith maps key `method` to the -version-free `mcp_types` models user code receives.""" +version-free `mcp_types` models user code receives. +`SERVER_CAPABILITY_REQUIREMENTS` keys `method` to the `ServerCapabilities` +attribute path it requires (version-invariant); `missing_server_capability` +evaluates it.""" from __future__ import annotations @@ -29,11 +32,13 @@ "MONOLITH_NOTIFICATIONS", "MONOLITH_REQUESTS", "MONOLITH_RESULTS", + "SERVER_CAPABILITY_REQUIREMENTS", "SERVER_NOTIFICATIONS", "SERVER_REQUESTS", "SERVER_RESULTS", "SPEC_CLIENT_METHODS", "SPEC_CLIENT_NOTIFICATION_METHODS", + "missing_server_capability", "parse_client_notification", "parse_client_request", "parse_client_result", @@ -404,6 +409,55 @@ """Monolith result model (or two-arm union) per request method.""" +# --- Server capability requirements --- + +SERVER_CAPABILITY_REQUIREMENTS: Final[Mapping[str, tuple[str, ...]]] = MappingProxyType( + { + "completion/complete": ("completions",), + "logging/setLevel": ("logging",), + "prompts/get": ("prompts",), + "prompts/list": ("prompts",), + "resources/list": ("resources",), + "resources/read": ("resources",), + "resources/subscribe": ("resources", "subscribe"), + "resources/templates/list": ("resources",), + "resources/unsubscribe": ("resources",), + "tools/call": ("tools",), + "tools/list": ("tools",), + } +) +"""The server capability each client request method requires, as an attribute path into +`ServerCapabilities`. Methods with no entry (`ping`, `initialize`, `server/discover`, +`subscriptions/listen`) require no server capability. `subscriptions/listen` stays ungated on +purpose: at 2026-07-28 the `resources.subscribe` capability licenses that request's +`resourceSubscriptions` FIELD, a params-level fact a method-keyed table cannot express, and +no spec MUST ties the method itself to a capability -- gating it here would wrongly reject a +listen that asks only for `toolsListChanged`. The relationship is the same at every protocol +version, so the map is keyed by method alone.""" + + +def missing_server_capability(method: str, capabilities: types.ServerCapabilities | None) -> str | None: + """The server capability `method` requires but `capabilities` does not advertise. + + Returns the dotted name of the first unadvertised step on the required capability path -- + `resources` when no resources capability is advertised at all, `resources.subscribe` when + `resources` is advertised but `subscribe` is not -- or None when `method` requires no + server capability or the required one is advertised. Naming the first missing step rather + than the whole path matches the TypeScript SDK's assertCapabilityForMethod and is the + actionable thing to tell the caller. `capabilities=None` (nothing negotiated yet) + advertises nothing. + """ + path = SERVER_CAPABILITY_REQUIREMENTS.get(method) + if path is None: + return None + node: Any = capabilities + for index, attr in enumerate(path): + node = node and getattr(node, attr) + if not node: + return ".".join(path[: index + 1]) + return None + + # --- Parse functions --- # Envelope stubs merged into bodies for surface validation (surface classes are full frames). diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index d6a6e4caa..c4c8bc755 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -209,6 +209,17 @@ async def main(): """A previously-obtained DiscoverResult to install via .adopt() when mode is a version pin. Ignored when mode='legacy'.""" + strict_capabilities: bool = False + """Reject calls to methods whose required server capability the server did not advertise. + + Opt-in (default False: every request is sent and the server's answer is surfaced, matching + the pre-2026 client). When True, such a call raises `MCPError` with code `METHOD_NOT_FOUND` + before any request reaches the transport -- the same code a compliant server returns for an + unadvertised capability. The check reads `server_capabilities`, so a bare version pin + (`mode='2026-07-28'` with no `prior_discover=`) -- which never asks the server what it + supports -- is rejected at construction with a `ValueError`; supply `prior_discover=` or + use `mode='auto'`. Mirrors the TypeScript SDK's `enforceStrictCapabilities` option.""" + elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" @@ -233,6 +244,13 @@ def __post_init__(self) -> None: f"mode must be 'legacy', 'auto', or one of {list(MODERN_PROTOCOL_VERSIONS)}; got {self.mode!r}{hint}" ) + if self.strict_capabilities and self.mode in MODERN_PROTOCOL_VERSIONS and self.prior_discover is None: + raise ValueError( + "strict_capabilities=True with a version pin needs prior_discover=: a bare pin " + "never asks the server what it supports, so every capability-gated method would " + "be rejected. Supply prior_discover= or use mode='auto'." + ) + srv = self.server if isinstance(srv, MCPServer): srv = srv._lowlevel_server # pyright: ignore[reportPrivateUsage] @@ -255,6 +273,7 @@ async def _build_session(self, exit_stack: AsyncExitStack) -> ClientSession: message_handler=self.message_handler, client_info=self.client_info, elicitation_callback=self.elicitation_callback, + strict_capabilities=self.strict_capabilities, ) async def __aenter__(self) -> Client: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 29e83902c..64968856d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -226,6 +226,7 @@ def __init__( client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, + strict_capabilities: bool = False, dispatcher: Dispatcher[Any] | None = None, ) -> None: self._session_read_timeout_seconds = read_timeout_seconds @@ -236,6 +237,7 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler + self._strict_capabilities = strict_capabilities self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._x_mcp_header_maps: dict[str, dict[tuple[str, ...], str]] = {} self._initialize_result: types.InitializeResult | None = None @@ -309,11 +311,22 @@ async def send_request( metadata: Streamable HTTP resumption hints. Raises: - MCPError: Error response, read timeout, or connection closed. + MCPError: Error response, read timeout, or connection closed. Also raised + before any send when `strict_capabilities` is set and the server did not + advertise the capability `request.method` requires (code + `METHOD_NOT_FOUND`). RuntimeError: Called before entering the context manager. """ data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] + if self._strict_capabilities and ( + missing := _methods.missing_server_capability(method, self.server_capabilities) + ): + raise MCPError( + code=METHOD_NOT_FOUND, + message=f"Server does not advertise the {missing} capability (required for {method})", + data=method, + ) opts: CallOptions = {} self._stamp(data, opts) timeout = ( diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 568ee2c62..06a7f5e81 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -517,6 +517,14 @@ def test_client_rejects_handshake_era_mode_at_construction() -> None: Client(server, mode="not-a-version") +def test_client_rejects_strict_capabilities_on_a_bare_version_pin_at_construction() -> None: + """`strict_capabilities=True` with a version pin and no `prior_discover=` is rejected by + `__post_init__`: a bare pin never asks the server what it supports, so every + capability-gated method would be rejected before doing anything useful.""" + with pytest.raises(ValueError, match=r"prior_discover= or use mode='auto'"): + Client(MCPServer("test"), mode="2026-07-28", strict_capabilities=True) + + # ── SEP-2322 multi-round-trip auto-loop ──────────────────────────────────────── diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 2b6838681..6cf52dc53 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -196,14 +196,15 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "The client sends any request regardless of the server's advertised capabilities and " - "surfaces whatever the server answers; the spec's MUST is not enforced." + "Not the default: by default the client sends any request regardless of the " + "server's advertised capabilities and surfaces whatever the server answers. The " + "client-side pre-check is opt-in via Client(strict_capabilities=True), which " + "rejects with METHOD_NOT_FOUND before any wire traffic -- the same shape as the " + "TypeScript SDK's enforceStrictCapabilities (also default-off). The 2026-07-28 " + "revision removes the lifecycle page that carries this MUST; there the server's " + "-32601 is the authoritative signal and the SDK's server already returns it." ), ), - deferred=( - "Not implemented in the SDK: the client sends any request regardless of the server's " - "advertised capabilities and surfaces whatever the server answers." - ), ), "lifecycle:initialize:basic": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", diff --git a/tests/interaction/lowlevel/test_client_connect.py b/tests/interaction/lowlevel/test_client_connect.py index 69fd5c4e8..83f1b8a47 100644 --- a/tests/interaction/lowlevel/test_client_connect.py +++ b/tests/interaction/lowlevel/test_client_connect.py @@ -9,6 +9,9 @@ The fallback test alone hand-plays the server's side of the wire, because no real `Server` answers `server/discover` with -32601. + +The strict-capability tests at the bottom pin the opt-in `strict_capabilities=` gate: the same +recording seams prove which of the caller's requests reached the transport. """ import json @@ -19,6 +22,7 @@ import httpx import mcp_types as types import pytest +from inline_snapshot import snapshot from mcp_types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, @@ -33,6 +37,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ResourcesCapability, ServerCapabilities, ToolsCapability, ) @@ -64,6 +69,18 @@ async def list_tools( return Server(name, on_list_tools=list_tools) +def _resources_server(name: str = "library") -> Server: + """A low-level server whose only handler is list-resources, so it advertises `resources` + with `subscribe=False` and no other capability.""" + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + return types.ListResourcesResult(resources=[]) + + return Server(name, on_list_resources=list_resources) + + def _request_recorder() -> tuple[list[httpx.Request], Callable[[httpx.Request], Awaitable[None]]]: """Return a list and an `on_request` hook that appends each outgoing httpx request to it.""" captured: list[httpx.Request] = [] @@ -368,3 +385,115 @@ async def test_http_protocol_version_header_matches_meta_protocol_version_on_eve body = json.loads(request.content) assert request.headers["mcp-protocol-version"] == body["params"]["_meta"][PROTOCOL_VERSION_META_KEY] assert request.headers["mcp-protocol-version"] == LATEST_MODERN_VERSION + + +@requirement("lifecycle:capability:server-not-advertised") +async def test_strict_capabilities_rejects_an_unadvertised_method_before_any_request_is_sent() -> None: + """`Client(..., strict_capabilities=True)` raises METHOD_NOT_FOUND for a method whose required + server capability the server did not advertise, and the request never reaches the transport. + + Requirement `lifecycle:capability:server-not-advertised` (spec basic/lifecycle#operation, + "Both parties MUST ... only use capabilities that were successfully negotiated"). The flag is + opt-in; the recorded wire log proves the rejection happened before the dispatcher. + """ + recording = RecordingTransport(InMemoryTransport(_tools_server())) + + with anyio.fail_after(5): + async with Client(recording, mode="legacy", strict_capabilities=True) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_resources() + assert exc_info.value.code == METHOD_NOT_FOUND + assert exc_info.value.message == snapshot( + "Server does not advertise the resources capability (required for resources/list)" + ) + assert exc_info.value.data == "resources/list" + + sent = [m.message for m in recording.sent] + methods = [m.method for m in sent if isinstance(m, JSONRPCRequest | JSONRPCNotification)] + assert methods == ["initialize", "notifications/initialized"] + + +@requirement("lifecycle:capability:server-not-advertised") +async def test_default_client_sends_an_unadvertised_method_and_surfaces_the_server_error() -> None: + """Without `strict_capabilities`, the client sends `resources/list` to a server that never + advertised `resources` and surfaces the server's METHOD_NOT_FOUND. + + Pins the recorded divergence on `lifecycle:capability:server-not-advertised`: the spec's + client-side MUST is not enforced by default; the pre-check is opt-in. + """ + recording = RecordingTransport(InMemoryTransport(_tools_server())) + + with anyio.fail_after(5): + async with Client(recording, mode="legacy") as client: + with pytest.raises(MCPError) as exc_info: + await client.list_resources() + assert exc_info.value.code == METHOD_NOT_FOUND + assert exc_info.value.message == snapshot("Method not found") + + sent = [m.message for m in recording.sent] + methods = [m.method for m in sent if isinstance(m, JSONRPCRequest | JSONRPCNotification)] + assert methods == ["initialize", "notifications/initialized", "resources/list"] + + +@requirement("lifecycle:capability:server-not-advertised") +async def test_strict_capabilities_distinguishes_a_sub_capability_from_its_parent() -> None: + """A strict client allows `resources/list` against a server advertising `resources` but + rejects `resources/subscribe` when `resources.subscribe` is not advertised. + + Steps: + 1. Connect to a server with only a list-resources handler: it advertises + `resources` with `subscribe=False` and nothing else. + 2. `list_resources()` succeeds (base `resources` is advertised). + 3. `subscribe_resource()` raises METHOD_NOT_FOUND client-side: the `resources.subscribe` + sub-capability is falsy. + 4. The wire log shows `resources/list` went out and `resources/subscribe` never did. + """ + recording = RecordingTransport(InMemoryTransport(_resources_server())) + + with anyio.fail_after(5): + async with Client(recording, mode="legacy", strict_capabilities=True) as client: + assert client.server_capabilities.resources == ResourcesCapability(subscribe=False, list_changed=False) + await client.list_resources() + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("res://x") + assert exc_info.value.code == METHOD_NOT_FOUND + assert exc_info.value.message == snapshot( + "Server does not advertise the resources.subscribe capability (required for resources/subscribe)" + ) + + sent = [m.message for m in recording.sent] + methods = [m.method for m in sent if isinstance(m, JSONRPCRequest | JSONRPCNotification)] + assert methods == ["initialize", "notifications/initialized", "resources/list"] + + +@requirement("lifecycle:capability:server-not-advertised") +async def test_strict_capabilities_uses_discover_era_capabilities_and_sends_nothing() -> None: + """On a modern (version-pinned + prior_discover) connection, the strict gate reads the + discover-era capabilities and rejects an unadvertised method with zero HTTP traffic. + + Same gate as the initialize-era tests: `server_capabilities` is the era-neutral accessor, + so no version branch exists to test separately. Asserted at the in-process streamable-HTTP + seam via the httpx event hook (the strongest "nothing was sent" available on this path). + """ + prior = DiscoverResult( + supported_versions=[LATEST_MODERN_VERSION], + capabilities=ServerCapabilities(tools=ToolsCapability(list_changed=False)), + server_info=Implementation(name="cached-server", version="9.9.9"), + ) + requests, on_request = _request_recorder() + + with anyio.fail_after(5): + async with ( + mounted_app(_tools_server(), on_request=on_request) as (http, _), + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http), + mode=LATEST_MODERN_VERSION, + prior_discover=prior, + strict_capabilities=True, + ) as client, + ): + with pytest.raises(MCPError) as exc_info: + await client.list_resources() + assert exc_info.value.code == METHOD_NOT_FOUND + + assert requests == [] diff --git a/tests/types/test_methods.py b/tests/types/test_methods.py index 79ea067c6..15f5ba79e 100644 --- a/tests/types/test_methods.py +++ b/tests/types/test_methods.py @@ -442,6 +442,48 @@ def test_spec_client_method_sets_are_the_client_direction_projection_of_the_surf assert "notifications/message" not in methods.SPEC_CLIENT_NOTIFICATION_METHODS +def test_server_capability_requirements_name_only_real_methods_and_attributes() -> None: + """Every gated method is one a client can send, and every path resolves on + `ServerCapabilities`; the ungated complement is exactly the methods that need no + server capability.""" + assert set(methods.SERVER_CAPABILITY_REQUIREMENTS) <= methods.SPEC_CLIENT_METHODS + assert methods.SPEC_CLIENT_METHODS - set(methods.SERVER_CAPABILITY_REQUIREMENTS) == { + "initialize", + "ping", + "server/discover", + "subscriptions/listen", + } + everything = types.ServerCapabilities( + completions=types.CompletionsCapability(), + logging=types.LoggingCapability(), + prompts=types.PromptsCapability(), + resources=types.ResourcesCapability(subscribe=True), + tools=types.ToolsCapability(), + ) + for method in methods.SERVER_CAPABILITY_REQUIREMENTS: + # `everything` is truthy at every step, so the walker's bare `getattr` runs for every + # attribute in the path with no short-circuit: a typo'd capability name fails loudly here. + assert methods.missing_server_capability(method, everything) is None, method + + +def test_missing_server_capability_names_the_first_unadvertised_step() -> None: + """An ungated method needs nothing; a gated method reports the FIRST step on its required + capability path that is unadvertised, not always the full path. + + The sub-capability case is the discriminating pair: `resources/subscribe` against a server + advertising no `resources` at all is told `resources`; against one advertising `resources` + but not `subscribe` it is told `resources.subscribe` -- mirroring the TypeScript SDK's two + distinct assertCapabilityForMethod messages. + """ + assert methods.missing_server_capability("ping", None) is None + assert methods.missing_server_capability("tools/list", None) == "tools" + assert methods.missing_server_capability("tools/list", types.ServerCapabilities()) == "tools" + assert methods.missing_server_capability("resources/subscribe", types.ServerCapabilities()) == "resources" + only_resources = types.ServerCapabilities(resources=types.ResourcesCapability()) + assert methods.missing_server_capability("resources/read", only_resources) is None + assert methods.missing_server_capability("resources/subscribe", only_resources) == "resources.subscribe" + + def test_elicit_result_surface_accepts_null_content_values_at_every_version_that_defines_it(): """Monolith superset leniency: hosts may answer optional form fields with null.""" for (method, _), surface in methods.CLIENT_RESULTS.items(): @@ -537,6 +579,7 @@ def test_built_in_maps_are_immutable(): "MONOLITH_NOTIFICATIONS", "MONOLITH_REQUESTS", "MONOLITH_RESULTS", + "SERVER_CAPABILITY_REQUIREMENTS", "SERVER_NOTIFICATIONS", "SERVER_REQUESTS", "SERVER_RESULTS", From eddfa29d9de745457f9dcd14fe3b3f8e1570de6a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 23:07:57 +0000 Subject: [PATCH 11/14] Deprecate ServerSession.send_progress_notification `ServerSession.send_progress_notification` takes an explicit progress token decoupled from the request it belongs to, so it can keep emitting progress for a request that has already completed -- which the spec forbids ("Progress notifications MUST stop after completion"). The request-scoped `report_progress` (and `Context.report_progress`) is the supported path: it reports against the inbound request's own token, no-ops when the caller did not ask for progress, and stops when the request completes. The deprecated method keeps working and emits `MCPDeprecationWarning`. The warning message deliberately departs from the " is deprecated as of " pattern used by the spec-driven deprecations: this one is an SDK API decision, not a spec retirement (2026-07-28 does not retire server-to-client progress). For "stops when the request completes" to hold on every dispatcher, `_DirectDispatchContext` now closes with its request the way `_JSONRPCDispatchContext` already did: `close()` runs in the dispatch handler's `finally`, after which `progress`/`notify` deliver nothing, `can_send_request` is False, and `send_raw_request` raises `NoBackChannelError` -- the closed state the `DispatchContext` protocol documents. Two pre-existing tests that asserted `can_send_request` on a context captured after its handler returned now sample it in-handler, and the closed-state contract tests are parametrized over both dispatchers. The interaction test that covered both the server and client side of late progress is split in two. The server-side property is proved positively on the wire through `report_progress`; it no longer relies on a session-bound standalone stream, so it also runs on the stateless streamable-http arm. The client-side late-drop test keeps using the deprecated explicit-token method -- the only API that can still produce a late notification -- under `pytest.warns`, and its arms are unchanged. The migration guide stops recommending the deprecated method anywhere and documents the replacement. --- docs/advanced/deprecated.md | 2 +- docs/migration.md | 32 ++++--- src/mcp/server/session.py | 5 + src/mcp/shared/direct_dispatcher.py | 22 ++++- tests/interaction/_requirements.py | 30 ++++-- tests/interaction/lowlevel/test_progress.py | 101 +++++++++++++++++--- tests/server/test_server_context.py | 6 +- tests/server/test_session.py | 27 +++++- tests/shared/test_context.py | 18 +++- tests/shared/test_dispatcher.py | 78 ++++++++++++++- 10 files changed, 273 insertions(+), 48 deletions(-) diff --git a/docs/advanced/deprecated.md b/docs/advanced/deprecated.md index 5bff0e955..1542e1b26 100644 --- a/docs/advanced/deprecated.md +++ b/docs/advanced/deprecated.md @@ -12,7 +12,7 @@ The table below names each deprecated feature, why it is going away, and the rep | **Server-initiated sampling**: `ctx.session.create_message()`, the `sampling_callback=` you pass to `Client(...)` | SEP-2577 retires the capability. | Return `InputRequiredResult` and let the client retry the call (see **Multi-round-trip requests**). | | **Protocol logging**: `ctx.log()`, `ctx.debug()`, `ctx.info()`, `ctx.warning()`, `ctx.error()`, `ctx.session.send_log_message()`, `client.set_logging_level()` | SEP-2577 retires the capability. Nothing in-protocol replaces it. | Ordinary `import logging` to stderr (see **Logging**). | | **`ping`**: `client.send_ping()` | **Removed** from the protocol, not merely deprecated. There is no `ping` method in 2026-07-28. | Nothing. It only works against a `mode="legacy"` connection. | -| **Client->server progress**: `client.send_progress_notification()` | 2026-07-28 makes progress server->client only. | Nothing to send. Your *server* reports progress with `ctx.report_progress()` (see **Progress**). | +| **Client->server progress**: `client.send_progress_notification()` | 2026-07-28 makes progress server->client only. | Nothing to send. Your *server* reports progress with `ctx.report_progress()` (see **Progress**). The SDK separately deprecates the server-side explicit-token `ctx.session.send_progress_notification()` in favour of `report_progress` (same warning category, same replacement). | Three things fall out of that table: diff --git a/docs/migration.md b/docs/migration.md index 392150df1..40f0c0b74 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -905,7 +905,7 @@ async def handle_tool(name: str, arguments: dict) -> list[TextContent]: # After (v2) async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: if ctx.meta and "progress_token" in ctx.meta: - await ctx.session.send_progress_notification(ctx.meta["progress_token"], 0.5, 100) + await ctx.session.report_progress(0.5, 100) ... server = Server("my-server", on_call_tool=handle_call_tool) @@ -967,7 +967,7 @@ async def my_tool(ctx: Context[MyLifespanState]) -> str: ... ### `ProgressContext` and `progress()` context manager removed -The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` or `session.send_progress_notification()` directly. +The `mcp.shared.progress` module (`ProgressContext`, `Progress`, and the `progress()` context manager) has been removed. This module had no real-world adoption — all users send progress notifications via `Context.report_progress()` directly. **Before (v1):** @@ -987,21 +987,11 @@ async def my_tool(x: int, ctx: Context) -> str: return "done" ``` -**After — use `session.send_progress_notification()` (low-level):** - -```python -await session.send_progress_notification( - progress_token=progress_token, - progress=25, - total=100, -) -``` - ### Handler progress reporting: prefer `ctx.report_progress()` over manual `progress_token` Reading `ctx.meta["progress_token"]` and calling `session.send_progress_notification(token, ...)` is specific to the JSON-RPC transport path. On the in-process modern path (`DirectDispatcher` / `Client(server)`), there is no wire token in `_meta`, so handlers that gate progress on the token's presence go silent. -`ctx.report_progress(progress, total, message)` works on every dispatcher: it sends a progress notification when a token is present and routes the update through the dispatcher's progress channel otherwise, no-opping only when the caller did not request progress at all. `session.send_progress_notification(progress_token, ...)` is unchanged and still works on JSON-RPC transports for code that already holds a token. +`ctx.report_progress(progress, total, message)` works on every dispatcher: it sends a progress notification when a token is present and routes the update through the dispatcher's progress channel otherwise, no-opping only when the caller did not request progress at all. `ServerSession.send_progress_notification(progress_token, ...)` is now deprecated; see **Progress API deprecations** below. ### `create_connected_server_and_client_session` removed @@ -1498,11 +1488,23 @@ warnings.filterwarnings("ignore", category=MCPDeprecationWarning) No migration is required during the deprecation window. New code should avoid building on these features, since they may be removed in a future spec version. -### Client-to-server progress deprecated (2026-07-28) +### Progress API deprecations (2026-07-28) The 2026-07-28 spec restricts `notifications/progress` to the server-to-client direction only — `ProgressNotification` is no longer in `ClientNotification`. `Client.send_progress_notification()` and `ClientSession.send_progress_notification()` now carry `typing_extensions.deprecated` and emit `mcp.MCPDeprecationWarning` at runtime. They continue to work against servers negotiating 2025-11-25 or earlier. -On the server side, prefer the new dispatcher-agnostic `ServerSession.report_progress(progress, total, message)` (and `Context.report_progress()` on `MCPServer`) over the raw `ServerSession.send_progress_notification(progress_token, …)`. `report_progress` encapsulates the "no-op when the caller did not request progress" rule and works on every dispatcher; the raw token-taking form remains for handlers that read `_meta.progressToken` directly. +On the server side, `ServerSession.send_progress_notification(progress_token, ...)` is also deprecated. It takes an explicit progress token decoupled from any request's lifetime, so it can emit progress for a request that has already completed -- which the spec forbids ("Progress notifications MUST stop after completion"). Use `Context.report_progress(progress, total, message)` (or `ServerSession.report_progress(...)` from a low-level handler): it reports against the inbound request's own progress token, works on every dispatcher, no-ops when the caller did not request progress, and is closed with the request, so a late call after the handler has returned sends nothing. + +```python +# Before +token = ctx.meta.get("progress_token") +if token is not None: + await ctx.session.send_progress_notification(token, 0.5, total=1.0) + +# After +await ctx.report_progress(0.5, total=1.0) +``` + +The deprecated method still works during the advisory window and emits `mcp.MCPDeprecationWarning`. ## Bug Fixes diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ca62fb9c8..01adb4c52 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -363,6 +363,11 @@ async def report_progress(self, progress: float, total: float | None = None, mes """ await self._request_outbound.progress(progress, total, message) + @deprecated( + "send_progress_notification is deprecated; use report_progress, which is scoped to " + "the inbound request's progress token and stops when that request completes.", + category=MCPDeprecationWarning, + ) async def send_progress_notification( self, progress_token: str | int, diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 55ee1dc40..ec67248c5 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -61,12 +61,16 @@ class _DirectDispatchContext: """Always `None`: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) + _closed: bool = False @property def can_send_request(self) -> bool: - return self.transport.can_send_request + return self.transport.can_send_request and not self._closed async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + if self._closed: + logger.debug("dropped %s: dispatch context closed", method) + return await self._back_notify(method, params) async def send_raw_request( @@ -80,8 +84,14 @@ async def send_raw_request( return await self._back_request(method, params, opts) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - if self._on_progress is not None: - await self._on_progress(progress, total, message) + # Gated here, not via notify(): in-process progress never routes through + # notify() - it awaits the caller's callback inline (a pinned behaviour). + if self._closed or self._on_progress is None: + return + await self._on_progress(progress, total, message) + + def close(self) -> None: + self._closed = True class DirectDispatcher: @@ -241,6 +251,12 @@ async def _dispatch_request( if unexpected: logger.exception("request handler raised") raise MCPError(code=error.code, message=error.message, data=error.data) from None + finally: + # Close the back-channel: after the handler returns, a captured + # context's `progress`/`notify` deliver nothing and its + # `send_raw_request` raises `NoBackChannelError`, matching + # JSONRPCDispatcher's handler-finally. + dctx.close() except TimeoutError: raise MCPError( code=REQUEST_TIMEOUT, diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 6cf52dc53..ad492f095 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -655,8 +655,11 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "The spec MUST is not enforced: progress values are not validated on either side, so a " - "handler that emits non-increasing values has them forwarded to the callback unchanged." + "Intentional, not a gap to close: no MCP SDK (typescript, go, csharp) validates " + "sender-side progress monotonicity, and this one does not either. The spec MUST " + "binds the handler author, not the transport; non-increasing values are forwarded " + "to the callback unchanged so the receiving application sees what the sender sent. " + "docs/tutorial/progress.md states the author's obligation." ), ), ), @@ -665,13 +668,28 @@ def __post_init__(self) -> None: behavior="Progress notifications for a token stop once the associated request completes.", divergence=Divergence( note=( - "send_progress_notification does not check whether the token's request has already " - "completed; the late notification is sent and reaches the client." + "Holds on the supported path: report_progress is scoped to the inbound request's " + "dispatch context, which closes when the request completes, so a late report is a " + "no-op (proven by test_report_progress_after_the_request_completes_sends_nothing). " + "The deprecated explicit-token ServerSession.send_progress_notification has no " + "such gate and still delivers post-completion progress to the client (proven by " + "the test that pins the client-side late-drop). The gap closes when that method " + "is removed." ), ), arm_exclusions=( - ArmExclusion(reason="requires-session", transport="streamable-http-stateless"), - ArmExclusion(reason="requires-session", spec_version="2026-07-28"), + ArmExclusion( + reason="requires-session", + spec_version="2026-07-28", + note=( + "The wire proof observes notifications/progress via the message handler; " + "neither 2026-07-28 cell can produce one. The in-memory DirectDispatcher " + "delivers progress as an in-process callback and never constructs the " + "notification; the modern streamable-http dispatch context no-ops notify(). " + "The DirectDispatcher half of the property is covered by " + "tests/shared/test_dispatcher.py instead." + ), + ), ), ), "protocol:progress:late-dropped-by-client": Requirement( diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 7f75e18ee..42ea2a686 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -3,10 +3,10 @@ Server-to-client progress emitted during a request follows the same ordering guarantee as logging notifications (see test_logging.py) -- on the in-memory transport unconditionally, and over streamable HTTP only when sent with ``related_request_id`` so the notification rides the -originating request's POST stream rather than the standalone GET stream. These tests pass -``related_request_id`` so no synchronisation is needed. The client-to-server direction is a -standalone notification with no response to await, so that test waits on an event set by the -server's handler. +originating request's POST stream rather than the standalone GET stream. These tests report +through the request-scoped `report_progress`, which routes onto that stream, so no +synchronisation is needed. The client-to-server direction is a standalone notification with no +response to await, so that test waits on an event set by the server's handler. """ import anyio @@ -15,6 +15,7 @@ from inline_snapshot import snapshot from mcp_types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent +from mcp import MCPDeprecationWarning from mcp.server import Server, ServerRequestContext from mcp.server.session import ServerSession from mcp.shared.session import ProgressFnT @@ -197,15 +198,86 @@ async def call(label: str, collect: ProgressFnT) -> None: @requirement("protocol:progress:stops-after-completion") +async def test_report_progress_after_the_request_completes_sends_nothing(connect: Connect) -> None: + """Progress reported through `report_progress` after the request has completed never reaches + the client. + + The handler captures its `ServerSession`; once `call_tool` has returned, the request's + dispatch context is closed, so the late `report_progress` is dropped before any I/O. The + message handler is teed every inbound progress notification (matched-to-callback or not), so + `seen_on_wire == [0.5]` is a positive full-equality proof: the in-handler report arrived and + nothing else ever did. The `list_tools` round-trip after the late report flushes anything a + regression would have put in flight, so the negative is not racy. + + The arms here are all `JSONRPCDispatcher` (the two 2026-07-28 cells are excluded: neither can + put a `notifications/progress` message on the wire); the `DirectDispatcher` half of the same + close is proven by + `tests/shared/test_dispatcher.py::test_ctx_progress_and_notify_after_the_inbound_request_returns_are_dropped`, + which is parametrized over both dispatchers. + + Steps: + 1. The tool reports 0.5 through `report_progress` and captures `ctx.session`. + 2. After `call_tool` returns, wait until 0.5 has been observed on the wire. + 3. Call `report_progress(1.0)` on the captured session -- a no-op on the closed context. + 4. A second `list_tools` round-trip flushes the single ordered stream. + 5. Tear the connection down; assert the wire saw exactly [0.5]. + """ + captured: list[ServerSession] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + captured.append(ctx.session) + await ctx.session.report_progress(0.5) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[float] = [] + seen_on_wire: list[float] = [] + first_seen = anyio.Event() + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def message_handler(message: IncomingMessage) -> None: + assert isinstance(message, ProgressNotification) + seen_on_wire.append(message.params.progress) + first_seen.set() + + async with connect(server, message_handler=message_handler) as client: + with anyio.fail_after(5): + await client.call_tool("report", {}, progress_callback=collect) + await first_seen.wait() + await captured[0].report_progress(1.0) + await client.list_tools() + + assert received == [0.5] + assert seen_on_wire == [0.5] + + @requirement("protocol:progress:late-dropped-by-client") -async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None: - """A progress notification sent after the response is emitted, and the client drops it from the callback. - - This single body proves both halves: the server's `send_progress_notification` happily sends for - a token whose request has already completed (the spec MUST that progress stops is not enforced; - see the divergence on `stops-after-completion`), and the client, having removed the callback when - the call returned, does not deliver the late notification to it. The message handler observes the - late notification arriving so the test knows when to assert without polling. +@requirement("protocol:progress:stops-after-completion") +async def test_a_progress_notification_arriving_after_the_response_is_dropped_from_the_callback( + connect: Connect, +) -> None: + """A progress notification that arrives after its request has completed is not delivered to + the original progress callback. + + SDK-defined: the client removes the per-call progress callback when `call_tool` returns. + Producing a genuinely late notification requires the deprecated explicit-token + `ServerSession.send_progress_notification` -- the only API not tied to the request's + dispatch context -- used deliberately here under `pytest.warns`. The message handler + observes the late notification arriving so the test knows when to assert without polling. + + The arrival itself pins the server-side half: the deprecated explicit-token method has no + completion gate, so the late notification still reaches the wire. That is the known + stops-after-completion divergence; the spec-conforming `report_progress` path is proven + separately by `test_report_progress_after_the_request_completes_sends_nothing`. """ captured: list[tuple[ServerSession, ProgressToken]] = [] @@ -220,7 +292,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara token = ctx.meta.get("progress_token") assert token is not None captured.append((ctx.session, token)) - await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + await ctx.session.report_progress(0.5) return CallToolResult(content=[TextContent(text="done")]) server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) @@ -241,7 +313,8 @@ async def message_handler(message: IncomingMessage) -> None: assert received == [0.5] server_session, token = captured[0] - await server_session.send_progress_notification(token, 1.0) + with pytest.warns(MCPDeprecationWarning, match=r"^send_progress_notification is deprecated"): + await server_session.send_progress_notification(token, 1.0) # pyright: ignore[reportDeprecated] await late_progress_arrived.wait() assert received == [0.5] diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index 9a9eaa3d9..852b686fd 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -32,10 +32,14 @@ class _Lifespan: async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): captured: list[Context[_Lifespan]] = [] conn_holder: list[Connection] = [] + open_while_handling: list[bool] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn_holder[0]) captured.append(ctx) + # `can_send_request` is sampled in-handler: the dispatch context closes when the + # request returns, after which it is False on every dispatcher. + open_while_handling.append(ctx.can_send_request) return {} async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): @@ -46,7 +50,7 @@ async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | assert ctx.lifespan.name == "app" assert ctx.connection is conn_holder[0] assert ctx.transport.kind == "direct" - assert ctx.can_send_request is True + assert open_while_handling == [True] assert ctx.session_id == "sess-1" assert ctx.headers is None diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 25b8257eb..a1bc555d5 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -11,6 +11,7 @@ import mcp_types as types import pytest +from inline_snapshot import snapshot from mcp_types import ( ClientCapabilities, Implementation, @@ -20,6 +21,7 @@ from mcp_types.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION from pydantic import ValidationError +from mcp import MCPDeprecationWarning from mcp.server.connection import Connection from mcp.server.session import ServerSession from mcp.shared.dispatcher import CallOptions @@ -152,9 +154,30 @@ async def test_send_notification_routes_by_related_request_id(): standalone_ch = StubOutbound() session = _two_channel_session(request_ch, standalone_ch) await session.send_tool_list_changed() - await session.send_progress_notification("tok", 0.5, related_request_id="req-1") + await session.send_notification( + types.ResourceUpdatedNotification(params=types.ResourceUpdatedNotificationParams(uri="x://r")), + related_request_id="req-1", + ) assert [m for m, _ in standalone_ch.notifications] == ["notifications/tools/list_changed"] - assert [m for m, _ in request_ch.notifications] == ["notifications/progress"] + assert [m for m, _ in request_ch.notifications] == ["notifications/resources/updated"] + + +@pytest.mark.anyio +async def test_send_progress_notification_warns_but_still_sends_during_the_advisory_window(): + """SDK-defined: `send_progress_notification` is deprecated in favour of the request-scoped + `report_progress` (an explicit token is not tied to its request's lifetime). During the + advisory window it still sends.""" + request_ch = StubOutbound() + standalone_ch = StubOutbound() + session = _two_channel_session(request_ch, standalone_ch) + with pytest.warns(MCPDeprecationWarning) as caught: + await session.send_progress_notification("tok", 0.5) # pyright: ignore[reportDeprecated] + assert str(caught.list[0].message) == snapshot( + "send_progress_notification is deprecated; use report_progress, which is scoped to " + "the inbound request's progress token and stops when that request completes." + ) + assert [m for m, _ in standalone_ch.notifications] == ["notifications/progress"] + assert request_ch.notifications == [] @pytest.mark.anyio diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py index 25ebf8e3d..7ce8de31a 100644 --- a/tests/shared/test_context.py +++ b/tests/shared/test_context.py @@ -16,7 +16,7 @@ from mcp.shared.peer import ClientPeer from mcp.shared.transport_context import TransportContext -from .conftest import direct_pair, jsonrpc_pair +from .conftest import PairFactory, direct_pair from .test_dispatcher import Recorder, echo_handlers, running_pair DCtx = DispatchContext[TransportContext] @@ -25,10 +25,14 @@ @pytest.mark.anyio async def test_base_context_forwards_transport_and_cancel_requested(): captured: list[BaseContext[TransportContext]] = [] + open_while_handling: list[bool] = [] async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: bctx = BaseContext(ctx) captured.append(bctx) + # `can_send_request` is sampled in-handler: once the request returns the + # dispatch context closes and it becomes False (covered by the sibling below). + open_while_handling.append(bctx.can_send_request) return {} async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): @@ -37,21 +41,25 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | bctx = captured[0] assert bctx.transport.kind == "direct" assert isinstance(bctx.cancel_requested, anyio.Event) - assert bctx.can_send_request is True + assert open_while_handling == [True] assert bctx.meta is None @pytest.mark.anyio -async def test_base_context_can_send_request_reflects_dispatch_context_closed_state(): +async def test_base_context_can_send_request_reflects_dispatch_context_closed_state(pair_factory: PairFactory): """`can_send_request` must track the dctx, not the static transport flag, - so it agrees with whether `send_raw_request` would raise.""" + so it agrees with whether `send_raw_request` would raise. + + Parametrized over both dispatchers: the `DispatchContext` Protocol promises this for + every implementation, so the two must not be allowed to drift apart on it. + """ captured: list[BaseContext[TransportContext]] = [] async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: captured.append(BaseContext(ctx)) return {} - async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("t", None) bctx = captured[0] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 1f8208337..23c588bdc 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -25,7 +25,7 @@ from mcp.shared._compat import resync_tracer from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound -from mcp.shared.exceptions import MCPError +from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from .conftest import PairFactory, direct_pair @@ -226,6 +226,82 @@ async def server_on_request( assert result == {"ok": True} +@pytest.mark.anyio +async def test_ctx_progress_and_notify_after_the_inbound_request_returns_are_dropped( + pair_factory: PairFactory, +) -> None: + """A dispatch context is closed once its inbound request finishes, so a captured context's + `progress` and `notify` after the handler has returned deliver nothing. + + This is the `DispatchContext` contract (the `can_send_request` docstring in + `mcp/shared/dispatcher.py` names the closed state). The in-handler 0.5 report is the + positive control proving the progress channel works; the second round-trip after the + late calls flushes anything they would have put in flight, so the negative is not racy. + + `received == [(0.5, ...)]` is the load-bearing arm of the proof on `DirectDispatcher`, + which delivers progress straight to the caller's callback (no client-side late-drop + exists there). The `notifications/message` absence is the load-bearing arm on + `JSONRPCDispatcher`, whose receiving side tees every inbound notification to `on_notify` + regardless of callback registration. + """ + received: list[tuple[float, float | None, str | None]] = [] + contexts: list[DispatchContext[TransportContext]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + contexts.append(ctx) + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) + late_ctx = contexts[0] + await late_ctx.progress(1.0) + await late_ctx.notify("notifications/message", {"level": "late"}) + await client.send_raw_request("tools/call", None) + assert received == [(0.5, 1.0, "halfway")] + assert ("notifications/message", {"level": "late"}) not in crec.notifications + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_after_the_inbound_request_returns_raises_no_back_channel( + pair_factory: PairFactory, +) -> None: + """A dispatch context's back-channel closes with its inbound request: once the handler has + returned, `can_send_request` is `False` and `send_raw_request` raises `NoBackChannelError`. + + This is the `DispatchContext` contract: `can_send_request` is `False` once the context has + been closed, and `send_raw_request` raises exactly then. The in-handler `True` is the + positive control -- the same context's back-channel was open while the request was in + flight, and the dispatcher pair is still running, so the rejection is the per-request + close, not a missing transport back-channel or a torn-down connection. + """ + contexts: list[DispatchContext[TransportContext]] = [] + open_while_handling: list[bool] = [] + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + contexts.append(ctx) + open_while_handling.append(ctx.can_send_request) + return {} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + late_ctx = contexts[0] + assert open_while_handling == [True] + assert late_ctx.can_send_request is False + with pytest.raises(NoBackChannelError) as exc: + await late_ctx.send_raw_request("ping", None) + assert exc.value.code == INVALID_REQUEST + + @pytest.mark.anyio async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair_factory: PairFactory): """Plain requests carry no transport metadata, so handlers see `None`.""" From 3605ec09b48f5fe245c911fb8e6f6c911ecb707c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 23:18:22 +0000 Subject: [PATCH 12/14] Type the elicitation requested schema on the send side ElicitRequestedSchema was a TypeAlias for dict[str, Any]; it is now a Pydantic model of the spec's restricted requested-schema subset, backed by a new PrimitiveSchemaDefinition union (StringSchema, NumberSchema, BooleanSchema, and the enum schemas). ServerSession.elicit_form (and the deprecated elicit alias) and ClientPeer.elicit_form accept only this model, so a nested-object property, an array-of-objects property, or an anyOf union is unconstructible at the only place a server author supplies a schema, rather than silently forwarded to the client. The spec restricts form-mode requested schemas to flat objects with primitive-typed properties only ("complex nested structures, arrays of objects ... are intentionally not supported"). The high-level Context.elicit / elicit_with_validation path is unchanged in behaviour: it converts the rendered JSON Schema into the typed model, keeping its existing per-field TypeError contract and producing value-identical wire output. Inbound is deliberately untouched. The wire field ElicitRequestFormParams.requested_schema stays a plain dict[str, Any], so older servers that emit anyOf for Optional form fields still reach the client's elicitation callback. The typed-model-to-wire-dict conversion lives in one place, ElicitRequestedSchema.to_wire(), which both send sites call. The schema family is extra="allow": keys schema.ts does not name (a top-level title, pattern, exclusiveMinimum, json_schema_extra keys) still round-trip, because the primitives-only restriction is carried by the union members' required type literals, not by extra-key rejection. --- docs/migration.md | 36 ++++ .../legacy_elicitation/server_lowlevel.py | 13 +- examples/stories/mrtr/server.py | 20 +- examples/stories/mrtr/server_lowlevel.py | 13 +- .../stories/stickynotes/server_lowlevel.py | 9 +- scripts/gen_surface_types.py | 2 +- src/mcp-types/mcp_types/__init__.py | 33 +++- src/mcp-types/mcp_types/_types.py | 177 ++++++++++++++++- src/mcp/server/elicitation.py | 28 +-- src/mcp/server/session.py | 8 +- src/mcp/shared/peer.py | 2 +- tests/interaction/_requirements.py | 20 +- .../interaction/lowlevel/test_elicitation.py | 185 +++++++++++++----- tests/interaction/lowlevel/test_flows.py | 7 +- .../server/mcpserver/test_url_elicitation.py | 4 +- tests/server/test_stateless_mode.py | 6 +- tests/shared/test_peer.py | 11 +- tests/types/test_parity.py | 36 +--- 18 files changed, 465 insertions(+), 145 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 40f0c0b74..c530682db 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -812,6 +812,42 @@ Positional calls (`await ctx.info("hello")`) are unaffected. `Context.elicit()` (and `elicit_with_validation()`) now render the schema first and validate each property against the spec's `PrimitiveSchemaDefinition`, raising `TypeError` at the call site for anything outside it. `Optional[T]` fields render as `{"type": ...}` with the field omitted from `required` (previously the non-spec `anyOf` shape). A bare `list[str]` field is rejected because it renders without the required enum items; use `list[Literal[...]]` or `list[str]` with `json_schema_extra` supplying the items. Unions of multiple primitives (e.g. `int | str`) and nested models are rejected. +### `ServerSession.elicit_form()` takes a typed `ElicitRequestedSchema` + +`ServerSession.elicit_form()` (and the deprecated `elicit()` alias, and `ClientPeer.elicit_form()`) +now take an `mcp_types.ElicitRequestedSchema` -- a Pydantic model of the spec's restricted +requested-schema subset -- instead of an arbitrary `dict[str, Any]`. `ElicitRequestedSchema` was +previously a `TypeAlias` for `dict[str, Any]`; it is now that model. A schema with a nested-object +property, an array-of-objects property, or an `anyOf` union is rejected at construction. + +**Why:** the spec restricts form-mode requested schemas to flat objects with primitive-typed +properties only ("complex nested structures, arrays of objects ... are intentionally not +supported"). Typing the send side makes a non-conforming schema impossible to construct rather +than silently forwarded. + +**How to migrate:** build the model in place of the dict, or validate an existing JSON Schema dict: + +```python +from mcp_types import BooleanSchema, ElicitRequestedSchema, StringSchema + +await ctx.session.elicit_form( + "Choose a username.", + ElicitRequestedSchema( + properties={"username": StringSchema(type="string"), "newsletter": BooleanSchema(type="boolean")}, + required=["username"], + ), +) + +# Or, if you already have a JSON Schema dict: +await ctx.session.elicit_form("Choose a username.", ElicitRequestedSchema.model_validate(my_schema)) +``` + +The high-level `Context.elicit()` / `elicit_with_validation()` path, which generates the schema +from a Pydantic model class, is unchanged. The wire type `ElicitRequestFormParams.requested_schema` +is still a plain `dict[str, Any]`: the client's inbound parsing deliberately tolerates +non-conforming schemas so older servers (which emit `anyOf` for `Optional` form fields) still +reach the elicitation callback. + ### Replace `RootModel` by union types with `TypeAdapter` validation The following union types are no longer `RootModel` subclasses: diff --git a/examples/stories/legacy_elicitation/server_lowlevel.py b/examples/stories/legacy_elicitation/server_lowlevel.py index 08c7c3a76..2570f3e81 100644 --- a/examples/stories/legacy_elicitation/server_lowlevel.py +++ b/examples/stories/legacy_elicitation/server_lowlevel.py @@ -8,14 +8,13 @@ from mcp.server.lowlevel import Server from stories._hosting import run_server_from_args -REGISTRATION_SCHEMA: types.ElicitRequestedSchema = { - "type": "object", - "properties": { - "username": {"type": "string"}, - "plan": {"type": "string", "enum": ["free", "pro", "team"]}, +REGISTRATION_SCHEMA = types.ElicitRequestedSchema( + properties={ + "username": types.StringSchema(type="string"), + "plan": types.UntitledSingleSelectEnumSchema(type="string", enum=["free", "pro", "team"]), }, - "required": ["username"], -} + required=["username"], +) LINK_INPUT_SCHEMA: dict[str, Any] = { "type": "object", "properties": {"provider": {"type": "string"}}, diff --git a/examples/stories/mrtr/server.py b/examples/stories/mrtr/server.py index d83c2e983..ca4623515 100644 --- a/examples/stories/mrtr/server.py +++ b/examples/stories/mrtr/server.py @@ -1,15 +1,21 @@ """Multi-round tool result (2026 era): a tool returns input_required and resumes from echoed state.""" -from mcp_types import ElicitRequest, ElicitRequestedSchema, ElicitRequestFormParams, ElicitResult, InputRequiredResult +from mcp_types import ( + BooleanSchema, + ElicitRequest, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitResult, + InputRequiredResult, +) from mcp.server.mcpserver import Context, MCPServer from stories._hosting import run_server_from_args -CONFIRM_SCHEMA: ElicitRequestedSchema = { - "type": "object", - "properties": {"confirm": {"type": "boolean", "description": "Proceed with the deployment?"}}, - "required": ["confirm"], -} +CONFIRM_SCHEMA = ElicitRequestedSchema( + properties={"confirm": BooleanSchema(type="boolean", description="Proceed with the deployment?")}, + required=["confirm"], +) def build_server() -> MCPServer: @@ -22,7 +28,7 @@ async def deploy(env: str, ctx: Context) -> str | InputRequiredResult: # First round: ask the client to elicit confirmation. request_state is opaque # to the client; here it carries the step name so the retry can verify the echo. ask = ElicitRequest( - params=ElicitRequestFormParams(message=f"Deploy to {env}?", requested_schema=CONFIRM_SCHEMA) + params=ElicitRequestFormParams(message=f"Deploy to {env}?", requested_schema=CONFIRM_SCHEMA.to_wire()) ) return InputRequiredResult(input_requests={"confirm": ask}, request_state="awaiting-confirm") # Retry round: the client echoed request_state byte-exact and supplied the answer. diff --git a/examples/stories/mrtr/server_lowlevel.py b/examples/stories/mrtr/server_lowlevel.py index 0ed13cea4..17b2356ae 100644 --- a/examples/stories/mrtr/server_lowlevel.py +++ b/examples/stories/mrtr/server_lowlevel.py @@ -8,11 +8,10 @@ from mcp.server.lowlevel import Server from stories._hosting import run_server_from_args -CONFIRM_SCHEMA: types.ElicitRequestedSchema = { - "type": "object", - "properties": {"confirm": {"type": "boolean", "description": "Proceed with the deployment?"}}, - "required": ["confirm"], -} +CONFIRM_SCHEMA = types.ElicitRequestedSchema( + properties={"confirm": types.BooleanSchema(type="boolean", description="Proceed with the deployment?")}, + required=["confirm"], +) DEPLOY_INPUT_SCHEMA: dict[str, Any] = { "type": "object", "properties": {"env": {"type": "string"}}, @@ -42,7 +41,9 @@ async def call_tool( responses = params.input_responses if responses is None or "confirm" not in responses: ask = types.ElicitRequest( - params=types.ElicitRequestFormParams(message=f"Deploy to {env}?", requested_schema=CONFIRM_SCHEMA) + params=types.ElicitRequestFormParams( + message=f"Deploy to {env}?", requested_schema=CONFIRM_SCHEMA.to_wire() + ) ) return types.InputRequiredResult(input_requests={"confirm": ask}, request_state="awaiting-confirm") assert params.request_state == "awaiting-confirm", params.request_state diff --git a/examples/stories/stickynotes/server_lowlevel.py b/examples/stories/stickynotes/server_lowlevel.py index 15a20a797..146ea6ed8 100644 --- a/examples/stories/stickynotes/server_lowlevel.py +++ b/examples/stories/stickynotes/server_lowlevel.py @@ -22,11 +22,10 @@ def claim_id(self) -> str: return nid -CONFIRM_SCHEMA: dict[str, Any] = { - "type": "object", - "properties": {"confirm": {"type": "boolean", "title": "Yes, permanently delete every sticky note"}}, - "required": ["confirm"], -} +CONFIRM_SCHEMA = types.ElicitRequestedSchema( + properties={"confirm": types.BooleanSchema(type="boolean", title="Yes, permanently delete every sticky note")}, + required=["confirm"], +) TOOLS = [ types.Tool( diff --git a/scripts/gen_surface_types.py b/scripts/gen_surface_types.py index f33862909..b754899aa 100644 --- a/scripts/gen_surface_types.py +++ b/scripts/gen_surface_types.py @@ -44,7 +44,7 @@ # Older python-sdk releases emit `anyOf` for Optional fields; the callback's # own schema validation is the real gate, so accept any property shape inbound. # PrimitiveSchemaDefinition becomes an orphan $def after this patch but - # datamodel-codegen still emits it; elicitation.py imports it as the gate type. + # datamodel-codegen still emits it; the monolith carries the user-facing family. ( "$defs/ElicitRequestFormParams/properties/requestedSchema/properties/properties/additionalProperties", {"$ref": "#/$defs/PrimitiveSchemaDefinition"}, diff --git a/src/mcp-types/mcp_types/__init__.py b/src/mcp-types/mcp_types/__init__.py index 2ed97cba3..4538bf221 100644 --- a/src/mcp-types/mcp_types/__init__.py +++ b/src/mcp-types/mcp_types/__init__.py @@ -15,6 +15,7 @@ AudioContent, BaseMetadata, BlobResourceContents, + BooleanSchema, CacheableResult, CallToolRequest, CallToolRequestParams, @@ -57,6 +58,8 @@ ElicitResult, EmbeddedResource, EmptyResult, + EnumOption, + EnumSchema, FormElicitationCapability, GetPromptRequest, GetPromptRequestParams, @@ -82,6 +85,7 @@ InputResponse, InputResponseRequestParams, InputResponses, + LegacyTitledEnumSchema, ListPromptsRequest, ListPromptsResult, ListResourcesRequest, @@ -101,12 +105,15 @@ MissingRequiredClientCapabilityErrorData, ModelHint, ModelPreferences, + MultiSelectEnumSchema, Notification, NotificationParams, + NumberSchema, PaginatedRequest, PaginatedRequestParams, PaginatedResult, PingRequest, + PrimitiveSchemaDefinition, ProgressNotification, ProgressNotificationParams, ProgressToken, @@ -152,7 +159,9 @@ ServerTasksRequestsCapability, SetLevelRequest, SetLevelRequestParams, + SingleSelectEnumSchema, StopReason, + StringSchema, SubscribeRequest, SubscribeRequestParams, SubscriptionFilter, @@ -176,6 +185,9 @@ TasksToolsCapability, TextContent, TextResourceContents, + TitledMultiSelectEnumItems, + TitledMultiSelectEnumSchema, + TitledSingleSelectEnumSchema, Tool, ToolAnnotations, ToolChoice, @@ -187,6 +199,9 @@ UnsubscribeRequest, UnsubscribeRequestParams, UnsupportedProtocolVersionErrorData, + UntitledMultiSelectEnumItems, + UntitledMultiSelectEnumSchema, + UntitledSingleSelectEnumSchema, UrlElicitationCapability, client_notification_adapter, client_request_adapter, @@ -232,21 +247,37 @@ "LOG_LEVEL_META_KEY", # Type aliases and variables "ContentBlock", - "ElicitRequestedSchema", "ElicitRequestParams", + "EnumSchema", "IncludeContext", "InputRequest", "InputRequests", "InputResponse", "InputResponses", "LoggingLevel", + "MultiSelectEnumSchema", + "PrimitiveSchemaDefinition", "ProgressToken", "ResultType", "Role", "SamplingContent", "SamplingMessageContentBlock", + "SingleSelectEnumSchema", "StopReason", "TaskStatus", + # Elicitation requested-schema models (form-mode property schemas; primitives only) + "BooleanSchema", + "ElicitRequestedSchema", + "EnumOption", + "LegacyTitledEnumSchema", + "NumberSchema", + "StringSchema", + "TitledMultiSelectEnumItems", + "TitledMultiSelectEnumSchema", + "TitledSingleSelectEnumSchema", + "UntitledMultiSelectEnumItems", + "UntitledMultiSelectEnumSchema", + "UntitledSingleSelectEnumSchema", # Base classes "BaseMetadata", "Request", diff --git a/src/mcp-types/mcp_types/_types.py b/src/mcp-types/mcp_types/_types.py index 34dc10083..efac132df 100644 --- a/src/mcp-types/mcp_types/_types.py +++ b/src/mcp-types/mcp_types/_types.py @@ -1926,9 +1926,172 @@ class ElicitCompleteNotification( params: ElicitCompleteNotificationParams -# Kept as a raw JSON Schema dict so callers can hand it straight to a validator; -# the per-version packages model RequestedSchema/PrimitiveSchemaDefinition strictly. -ElicitRequestedSchema: TypeAlias = dict[str, Any] +class _OpenSchemaModel(MCPModel): + """Internal base for the elicitation requested-schema family: an open key set. + + Keys schema.ts does not name still round-trip to the wire. The high-level renderer + (`model_json_schema`) emits non-spec JSON Schema keys -- a top-level `title`, + `exclusiveMinimum` / `pattern` from `Field(gt=...)` / `Field(pattern=...)`, and anything + in `json_schema_extra` -- and `extra="allow"` forwards them unchanged. The + primitives-only restriction is carried by the union members' required `type` literals + and required fields, NOT by extra-key rejection: `forbid` would break the high-level + path, and `ignore` would silently rewrite the user's schema. + """ + + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, extra="allow") + + +class StringSchema(_OpenSchemaModel): + """A string-typed form field in an elicitation requested schema.""" + + type: Literal["string"] + title: str | None = None + description: str | None = None + min_length: int | None = None + max_length: int | None = None + format: Literal["email", "uri", "date", "date-time"] | None = None + default: str | None = None + + +class NumberSchema(_OpenSchemaModel): + """A number- or integer-typed form field in an elicitation requested schema.""" + + type: Literal["number", "integer"] + title: str | None = None + description: str | None = None + minimum: int | float | None = None + maximum: int | float | None = None + default: int | float | None = None + + +class BooleanSchema(_OpenSchemaModel): + """A boolean-typed form field in an elicitation requested schema.""" + + type: Literal["boolean"] + title: str | None = None + description: str | None = None + default: bool | None = None + + +class EnumOption(_OpenSchemaModel): + """An enum value paired with the display title clients show for it.""" + + const: str + title: str + + +class UntitledSingleSelectEnumSchema(_OpenSchemaModel): + """Single-selection enum form field whose options have no display titles.""" + + type: Literal["string"] + title: str | None = None + description: str | None = None + enum: list[str] + default: str | None = None + + +class TitledSingleSelectEnumSchema(_OpenSchemaModel): + """Single-selection enum form field with a display title for each option.""" + + type: Literal["string"] + title: str | None = None + description: str | None = None + one_of: list[EnumOption] + default: str | None = None + + +SingleSelectEnumSchema: TypeAlias = UntitledSingleSelectEnumSchema | TitledSingleSelectEnumSchema +"""Single-selection enum form field, with or without display titles for its options.""" + + +class UntitledMultiSelectEnumItems(_OpenSchemaModel): + """The `items` schema of a multi-select enum whose options have no display titles.""" + + type: Literal["string"] + enum: list[str] + + +class UntitledMultiSelectEnumSchema(_OpenSchemaModel): + """Multiple-selection enum form field whose options have no display titles.""" + + type: Literal["array"] + title: str | None = None + description: str | None = None + min_items: int | None = None + max_items: int | None = None + items: UntitledMultiSelectEnumItems + default: list[str] | None = None + + +class TitledMultiSelectEnumItems(_OpenSchemaModel): + """The `items` schema of a multi-select enum with a display title for each option.""" + + any_of: list[EnumOption] + + +class TitledMultiSelectEnumSchema(_OpenSchemaModel): + """Multiple-selection enum form field with a display title for each option.""" + + type: Literal["array"] + title: str | None = None + description: str | None = None + min_items: int | None = None + max_items: int | None = None + items: TitledMultiSelectEnumItems + default: list[str] | None = None + + +MultiSelectEnumSchema: TypeAlias = UntitledMultiSelectEnumSchema | TitledMultiSelectEnumSchema +"""Multiple-selection enum form field, with or without display titles for its options.""" + + +class LegacyTitledEnumSchema(_OpenSchemaModel): + """Enum form field using `enum` / `enumNames`; the spec will remove this in a future version. + + Use `TitledSingleSelectEnumSchema` instead. + """ + + type: Literal["string"] + title: str | None = None + description: str | None = None + enum: list[str] + enum_names: list[str] | None = None + default: str | None = None + + +EnumSchema: TypeAlias = SingleSelectEnumSchema | MultiSelectEnumSchema | LegacyTitledEnumSchema +"""Any of the elicitation enum form-field schemas.""" + +PrimitiveSchemaDefinition: TypeAlias = StringSchema | NumberSchema | BooleanSchema | EnumSchema +"""Restricted schema definitions allowed as elicitation form fields: primitives only, no nesting.""" + + +class ElicitRequestedSchema(_OpenSchemaModel): + """A restricted subset of JSON Schema: a flat object whose properties are all primitive-typed. + + This is the only schema type the typed send methods (`ServerSession.elicit_form`, + `ClientPeer.elicit_form`) accept, so a nested-object or array-of-objects property is + unconstructible on the send side. The wire field `ElicitRequestFormParams.requested_schema` + stays a plain dict so the inbound parse remains lenient (see that field's docstring). + """ + + schema_: Annotated[str | None, Field(alias="$schema")] = None + type: Literal["object"] = "object" + properties: dict[str, PrimitiveSchemaDefinition] + required: list[str] | None = None + + def to_wire(self) -> dict[str, Any]: + """Serialize to the plain dict that `ElicitRequestFormParams.requested_schema` carries. + + This type never appears on the wire itself: the wire field is a deliberately + lenient `dict[str, Any]` so the client's inbound parse tolerates non-conforming + schemas from older servers. This method is the one place a typed schema becomes + that wire dict, and the typed send methods (`ServerSession.elicit_form`, + `ClientPeer.elicit_form`) both call it. + """ + # exclude_none, not exclude_unset: a keyword-constructed instance must keep its + # defaulted `type` literals on the wire; only the None-valued optionals are dropped. + return self.model_dump(by_alias=True, exclude_none=True) class ElicitRequestFormParams(RequestParams): @@ -1944,10 +2107,16 @@ class ElicitRequestFormParams(RequestParams): message: str """The message to present to the user describing what information is being requested.""" - requested_schema: ElicitRequestedSchema + requested_schema: dict[str, Any] """ A restricted subset of JSON Schema defining the structure of the expected response. Only top-level properties are allowed, without nesting. + + Deliberately a plain dict, not `ElicitRequestedSchema`: this field is also the client's + inbound parse, and older servers emit `anyOf` for `Optional` form fields, which must + still reach the user's elicitation callback. The spec's primitives-only restriction is + enforced on the send side, where `ServerSession.elicit_form` and `ClientPeer.elicit_form` + only accept `ElicitRequestedSchema`. """ task: TaskMetadata | None = None diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index dc0e669c8..07976e921 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -4,11 +4,8 @@ from typing import Any, Generic, Literal, TypeVar -from mcp_types import RequestId - -# Internal surface package; imported as the gate's source of truth for spec-valid property schemas. -from mcp_types.v2025_11_25 import PrimitiveSchemaDefinition -from pydantic import BaseModel, ValidationError +from mcp_types import ElicitRequestedSchema, PrimitiveSchemaDefinition, RequestId +from pydantic import BaseModel, TypeAdapter, ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import core_schema @@ -66,20 +63,27 @@ def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaVal return result -def _validate_rendered_properties(json_schema: dict[str, Any]) -> None: - """Reject any `properties` entry the spec's `PrimitiveSchemaDefinition` won't accept. +# `PrimitiveSchemaDefinition` is a flat union alias; validate one rendered property at a time through an adapter. +_PRIMITIVE_SCHEMA_ADAPTER = TypeAdapter[PrimitiveSchemaDefinition](PrimitiveSchemaDefinition) + - Catches whatever the renderer let through that isn't spec-valid: bare - `list[str]` (no enum), multi-primitive unions, nested models. +def _rendered_requested_schema(json_schema: dict[str, Any]) -> ElicitRequestedSchema: + """Convert a rendered JSON Schema dict into the typed `ElicitRequestedSchema`. + + The per-property loop rejects whatever the renderer let through that isn't + spec-valid: bare `list[str]` (no enum), multi-primitive unions, nested models. """ for field_name, prop in json_schema.get("properties", {}).items(): try: - PrimitiveSchemaDefinition.model_validate(prop) + _PRIMITIVE_SCHEMA_ADAPTER.validate_python(prop) except ValidationError: raise TypeError( f"Elicitation schema field {field_name!r} rendered as {prop!r}, " f"which is not a valid PrimitiveSchemaDefinition" ) from None + # The loop exists only to name the offending field; the renderer always produces a + # top-level object schema, so once every property passed this validate succeeds. + return ElicitRequestedSchema.model_validate(json_schema) async def elicit_with_validation( @@ -99,11 +103,9 @@ async def elicit_with_validation( For sensitive data like credentials or OAuth flows, use elicit_url() instead. """ json_schema = schema.model_json_schema(schema_generator=_ElicitationJsonSchema) - _validate_rendered_properties(json_schema) - result = await session.elicit_form( message=message, - requested_schema=json_schema, + requested_schema=_rendered_requested_schema(json_schema), related_request_id=related_request_id, ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 01adb4c52..eaece202b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -266,7 +266,8 @@ async def elicit( Args: message: The message to present to the user. - requested_schema: Schema defining the expected response structure. + requested_schema: Typed `ElicitRequestedSchema` defining the expected response + structure; non-primitive and nested properties are rejected at construction. related_request_id: Optional ID of the request that triggered this elicitation. Returns: @@ -288,7 +289,8 @@ async def elicit_form( Args: message: The message to present to the user. - requested_schema: Schema defining the expected response structure. + requested_schema: Typed `ElicitRequestedSchema` defining the expected response + structure; non-primitive and nested properties are rejected at construction. related_request_id: Optional ID of the request that triggered this elicitation. Returns: @@ -302,7 +304,7 @@ async def elicit_form( types.ElicitRequest( params=types.ElicitRequestFormParams( message=message, - requested_schema=requested_schema, + requested_schema=requested_schema.to_wire(), ), ), types.ElicitResult, diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py index ca59b56af..a92931fff 100644 --- a/src/mcp/shared/peer.py +++ b/src/mcp/shared/peer.py @@ -176,7 +176,7 @@ async def elicit_form( NoBackChannelError: No back-channel for server-initiated requests. pydantic.ValidationError: The peer's result does not match the expected result type. """ - params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) + params = ElicitRequestFormParams(message=message, requested_schema=requested_schema.to_wire()) result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) return ElicitResult.model_validate(result, by_name=False) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index ad492f095..28a1b27b9 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1775,18 +1775,18 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "ServerSession.elicit_form forwards an arbitrary dict[str, Any] schema unchanged; no shape " - "validation at the low-level session layer (the high-level Context.elicit / " - "elicit_with_validation helper enforces primitive-only fields before generating the schema). " - "ClientSession likewise does not enforce it: the inbound surface gate is relaxed for " - "requestedSchema.properties so older servers that emit anyOf for Optional fields still reach " - "the elicitation callback." + "Enforced on the send side only: ServerSession.elicit_form / ClientPeer.elicit_form " + "take a typed ElicitRequestedSchema, so a non-primitive or nested property is " + "unconstructible on send. Inbound is deliberately lenient: " + "ElicitRequestFormParams.requested_schema stays a plain dict and the inbound surface " + "gate is relaxed for requestedSchema.properties, so older servers that emit anyOf " + "for Optional fields (and any other non-conforming schema) still reach the " + "elicitation callback. Interop beats purity on receive. The outbound closure is " + "scoped to property shapes, matching this requirement's behavior text: top-level " + "JSON Schema composition keys (allOf, $defs) are not rejected, because the top " + "level must stay open for the SDK's own renderer's non-spec title key." ), ), - arm_exclusions=( - ArmExclusion(reason="server-initiated-request", transport="streamable-http-stateless"), - ArmExclusion(reason="server-initiated-request", spec_version="2026-07-28"), - ), ), "elicitation:form:response-validation": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index 0b687a7b0..306bfec87 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -1,15 +1,19 @@ """Form- and URL-mode elicitation against the low-level Server, driven through the public Client API. -The final test plays the server's side of the wire by hand to issue an elicitation request with no -mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. +Two tests play the server's side of the wire by hand: one to issue an elicitation request with no +mode field (the typed server API always serializes one), and one to deliver a non-conforming +requested schema (the typed server API can no longer construct one). """ +from typing import Any + import anyio import mcp_types as types import pytest from inline_snapshot import snapshot from mcp_types import ( INVALID_PARAMS, + BooleanSchema, CallToolResult, ElicitCompleteNotification, ElicitCompleteNotificationParams, @@ -25,8 +29,10 @@ JSONRPCRequest, JSONRPCResponse, ServerCapabilities, + StringSchema, TextContent, ) +from pydantic import ValidationError from mcp import MCPError, UrlElicitationRequiredError from mcp.client import ClientRequestContext, ClientSession @@ -39,14 +45,13 @@ pytestmark = pytest.mark.anyio -REQUESTED_SCHEMA: dict[str, object] = { - "type": "object", - "properties": { - "username": {"type": "string"}, - "newsletter": {"type": "boolean"}, +REQUESTED_SCHEMA = ElicitRequestedSchema( + properties={ + "username": StringSchema(type="string"), + "newsletter": BooleanSchema(type="boolean"), }, - "required": ["username"], -} + required=["username"], +) @requirement("elicitation:form:action:accept") @@ -118,7 +123,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "confirm" - answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + answer = await ctx.session.elicit_form("Proceed?", ElicitRequestedSchema(properties={})) return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -145,7 +150,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "confirm" - answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + answer = await ctx.session.elicit_form("Proceed?", ElicitRequestedSchema(properties={})) return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -182,7 +187,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "ask" try: - await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + await ctx.session.elicit_form("Anyone there?", ElicitRequestedSchema(properties={})) except MCPError as exc: errors.append(exc.error) return CallToolResult(content=[TextContent(text=exc.error.message)]) @@ -407,15 +412,18 @@ async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the ) -> None: """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. - One schema with one property per kind: a formatted string, an integer with bounds, a number, - a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The - callback observing the same schema as the handler sent proves both the primitive coverage and - the enum-variant coverage in one snapshot. + One schema with one property per kind: a formatted string, a constrained string, an integer with + bounds, a number, a boolean, a plain enum, a oneOf-const titled enum, and a multi-select + array-of-enum. The callback observing the exact dict `ElicitRequestedSchema` was built from + proves the primitive and enum-variant coverage AND that the typed model round-trips losslessly: + `username` carries `pattern`, a key the spec's prose documents on `StringSchema` but schema.ts + omits, so it must survive as an extra rather than be silently dropped from the wire. """ - schema: ElicitRequestedSchema = { + wire_schema: dict[str, Any] = { "type": "object", "properties": { "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "username": {"type": "string", "minLength": 3, "pattern": "^[A-Za-z]+$"}, "age": {"type": "integer", "minimum": 0, "maximum": 150}, "score": {"type": "number"}, "subscribe": {"type": "boolean", "default": False}, @@ -441,7 +449,9 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "onboard" - answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + answer = await ctx.session.elicit_form( + "Tell us about yourself.", ElicitRequestedSchema.model_validate(wire_schema) + ) return CallToolResult(content=[TextContent(text=answer.action)]) server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) @@ -457,58 +467,127 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest assert len(received) == 1 assert isinstance(received[0], ElicitRequestFormParams) - assert received[0].requested_schema == schema + assert received[0].requested_schema == wire_schema @requirement("elicitation:form:schema:restricted-subset") -async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: - """A requested schema with nested-object and array-of-object properties passes through unchanged. +def test_elicit_form_requested_schema_rejects_nested_object_and_array_of_object_properties() -> None: + """Spec-mandated: form-mode requested schemas are flat objects with primitive-typed properties only. + + `ElicitRequestedSchema` is the only schema type the typed send methods (`ServerSession.elicit_form`, + `ClientPeer.elicit_form`) accept, so a non-conforming schema is unconstructible on the send side. + The closure is scoped to property shapes: top-level keys outside schema.ts (`title`, `allOf`, + `$defs`, ...) still pass, because the top level is deliberately an open bag -- the SDK's own + schema renderer emits a top-level `title`. + """ + # No `match=`: pydantic-authored message text, not ours to pin. + with pytest.raises(ValidationError): + ElicitRequestedSchema.model_validate( + { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + }, + } + ) + with pytest.raises(ValidationError): + ElicitRequestedSchema.model_validate( + { + "type": "object", + "properties": { + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + ) + - The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; - this test pins that the SDK does not enforce that restriction on either side (see the - divergence on the requirement). The inbound surface gate is deliberately relaxed here so older - servers that emit `anyOf` for `Optional` form fields still reach the elicitation callback. +@requirement("elicitation:form:schema:restricted-subset") +async def test_an_inbound_elicitation_with_a_non_conforming_schema_still_reaches_the_callback() -> None: + """An inbound elicitation whose requested schema is non-conforming reaches the callback unchanged. + + Pins a deliberate, documented divergence: interop beats purity on receive. Older python-sdk + servers emit `anyOf` for `Optional` form fields, and the inbound surface gate is relaxed for + `requestedSchema.properties`, so those (and any other non-conforming property) still reach the + user's elicitation callback. The server's side of the wire is scripted by hand because the typed + server API can no longer produce a non-conforming schema. """ - schema: ElicitRequestedSchema = { + wire_schema: dict[str, Any] = { "type": "object", "properties": { - "address": { - "type": "object", - "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, - }, - "contacts": { - "type": "array", - "items": {"type": "object", "properties": {"name": {"type": "string"}}}, - }, + "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + "address": {"type": "object", "properties": {"street": {"type": "string"}}}, }, } - - async def list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] - ) - - async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: - assert params.name == "profile" - answer = await ctx.session.elicit_form("Profile details.", schema) - return CallToolResult(content=[TextContent(text=answer.action)]) - - server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) - received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: received.append(params) return ElicitResult(action="decline") - async with connect(server, elicitation_callback=answer_form) as client: - await client.call_tool("profile", {}) + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="older-sdk", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"mode": "form", "message": "Profile details.", "requestedSchema": wire_schema}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write, elicitation_callback=answer_form) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + await answered.wait() assert len(received) == 1 assert isinstance(received[0], ElicitRequestFormParams) - assert received[0].requested_schema == schema + assert received[0].requested_schema == wire_schema + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 @requirement("elicitation:form:response-validation") @@ -533,7 +612,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert params.name == "signup" answer = await ctx.session.elicit_form( "Choose a name.", - {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ElicitRequestedSchema(properties={"name": StringSchema(type="string")}, required=["name"]), ) return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py index 19788db4a..bfd3888a1 100644 --- a/tests/interaction/lowlevel/test_flows.py +++ b/tests/interaction/lowlevel/test_flows.py @@ -16,13 +16,16 @@ URL_ELICITATION_REQUIRED, CallToolResult, ElicitCompleteNotification, + ElicitRequestedSchema, ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, EmptyResult, ListToolsResult, + NumberSchema, ReadResourceResult, ResourceLink, + StringSchema, TextContent, TextResourceContents, Tool, @@ -101,12 +104,12 @@ async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forwa async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "onboard" first = await ctx.session.elicit_form( - "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + "Step 1: choose a username.", ElicitRequestedSchema(properties={"name": StringSchema(type="string")}) ) assert first.action == "accept" and first.content is not None second = await ctx.session.elicit_form( f"Step 2: confirm age for {first.content['name']}.", - {"type": "object", "properties": {"age": {"type": "integer"}}}, + ElicitRequestedSchema(properties={"age": NumberSchema(type="integer")}), ) assert second.action == "accept" and second.content is not None return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index f2ce5e013..30710ae10 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -3,7 +3,7 @@ import anyio import mcp_types as types import pytest -from mcp_types import ElicitRequestParams, ElicitResult, TextContent +from mcp_types import ElicitRequestedSchema, ElicitRequestParams, ElicitResult, TextContent from pydantic import BaseModel, Field from mcp import Client @@ -296,7 +296,7 @@ async def use_deprecated_elicit(ctx: Context) -> str: # Use the deprecated elicit() method which should call elicit_form() result = await ctx.session.elicit( message="Enter your email", - requested_schema=EmailSchema.model_json_schema(), + requested_schema=ElicitRequestedSchema.model_validate(EmailSchema.model_json_schema()), ) if result.action == "accept" and result.content: diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 1124d69b7..11ec5c4de 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -100,7 +100,7 @@ async def test_create_message_raises_no_back_channel_without_related_id(no_chann async def test_elicit_form_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): """SDK-defined: `elicit_form` without a related id rides the standalone channel and raises.""" with pytest.raises(NoBackChannelError) as exc: - await no_channel_session.elicit_form(message="m", requested_schema={"type": "object", "properties": {}}) + await no_channel_session.elicit_form(message="m", requested_schema=types.ElicitRequestedSchema(properties={})) assert exc.value.method == "elicitation/create" @@ -116,7 +116,7 @@ async def test_elicit_url_raises_no_back_channel_without_related_id(no_channel_s async def test_elicit_deprecated_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): """SDK-defined: the deprecated `elicit` alias routes the same as `elicit_form` and raises.""" with pytest.raises(NoBackChannelError) as exc: - await no_channel_session.elicit(message="m", requested_schema={"type": "object", "properties": {}}) + await no_channel_session.elicit(message="m", requested_schema=types.ElicitRequestedSchema(properties={})) assert exc.value.method == "elicitation/create" @@ -134,7 +134,7 @@ async def test_elicit_form_with_related_id_rides_the_request_channel(): channel, so the no-channel standalone is never touched and the call succeeds.""" session, request_ch = _no_channel_session(StubOutbound(result={"action": "cancel"})) result = await session.elicit_form( - message="m", requested_schema={"type": "object", "properties": {}}, related_request_id=3 + message="m", requested_schema=types.ElicitRequestedSchema(properties={}), related_request_id=3 ) assert isinstance(result, types.ElicitResult) assert request_ch.requests[0][0] == "elicitation/create" diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py index 2fc92e2c8..82fb6508d 100644 --- a/tests/shared/test_peer.py +++ b/tests/shared/test_peer.py @@ -13,6 +13,7 @@ from mcp_types import ( CreateMessageResult, CreateMessageResultWithTools, + ElicitRequestedSchema, ElicitResult, ListRootsResult, SamplingMessage, @@ -93,15 +94,23 @@ async def test_peer_sample_with_tools_returns_with_tools_result(): @pytest.mark.anyio async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): + """`elicit_form` puts the typed schema on the wire as the exact dict it was validated from. + + `minLength` is aliased (`min_length` on the model) and `required` / per-property optionals are + unset, so the equality fails if the typed-model-to-wire conversion drops the alias or starts + emitting null-valued keys. + """ rec = _Recorder({"action": "accept", "content": {"name": "Max"}}) + wire_schema: dict[str, Any] = {"type": "object", "properties": {"name": {"type": "string", "minLength": 2}}} async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): peer = ClientPeer(client) with anyio.fail_after(5): - result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) + result = await peer.elicit_form("Your name?", ElicitRequestedSchema.model_validate(wire_schema)) method, params = rec.seen[0] assert method == "elicitation/create" assert params is not None and params["mode"] == "form" assert params["message"] == "Your name?" + assert params["requestedSchema"] == wire_schema assert isinstance(result, ElicitResult) diff --git a/tests/types/test_parity.py b/tests/types/test_parity.py index 080f343c3..0d1639c7a 100644 --- a/tests/types/test_parity.py +++ b/tests/types/test_parity.py @@ -20,16 +20,21 @@ # Surface classes whose monolith counterpart has a different name (key: "."). NAME_MAP: dict[str, type[BaseModel]] = { # v2025_11_25 + "v2025_11_25.AnyOfItem": monolith.EnumOption, "v2025_11_25.Argument": monolith.CompletionArgument, "v2025_11_25.Context": monolith.CompletionContext, "v2025_11_25.Data": monolith.ElicitationRequiredErrorData, "v2025_11_25.Elicitation": monolith.ElicitationCapability, "v2025_11_25.Elicitation1": monolith.TasksElicitationCapability, "v2025_11_25.ElicitationCompleteNotification": monolith.ElicitCompleteNotification, + "v2025_11_25.Items": monolith.TitledMultiSelectEnumItems, + "v2025_11_25.Items1": monolith.UntitledMultiSelectEnumItems, + "v2025_11_25.OneOfItem": monolith.EnumOption, "v2025_11_25.Params": monolith.CancelTaskRequestParams, "v2025_11_25.Params1": monolith.ElicitCompleteNotificationParams, "v2025_11_25.Params2": monolith.GetTaskPayloadRequestParams, "v2025_11_25.Params3": monolith.GetTaskRequestParams, + "v2025_11_25.RequestedSchema": monolith.ElicitRequestedSchema, "v2025_11_25.Error": monolith.ErrorData, "v2025_11_25.JSONRPCErrorResponse": monolith.JSONRPCError, "v2025_11_25.JSONRPCResultResponse": monolith.JSONRPCResponse, @@ -45,15 +50,20 @@ "v2025_11_25.Tools": monolith.TasksToolsCapability, "v2025_11_25.Tools1": monolith.ToolsCapability, # v2026_07_28 + "v2026_07_28.AnyOfItem": monolith.EnumOption, "v2026_07_28.Argument": monolith.CompletionArgument, "v2026_07_28.Context": monolith.CompletionContext, "v2026_07_28.Data": monolith.MissingRequiredClientCapabilityErrorData, "v2026_07_28.Data1": monolith.UnsupportedProtocolVersionErrorData, "v2026_07_28.Elicitation": monolith.ElicitationCapability, "v2026_07_28.Error": monolith.ErrorData, + "v2026_07_28.Items": monolith.TitledMultiSelectEnumItems, + "v2026_07_28.Items1": monolith.UntitledMultiSelectEnumItems, "v2026_07_28.JSONRPCErrorResponse": monolith.JSONRPCError, "v2026_07_28.JSONRPCResultResponse": monolith.JSONRPCResponse, + "v2026_07_28.OneOfItem": monolith.EnumOption, "v2026_07_28.Prompts": monolith.PromptsCapability, + "v2026_07_28.RequestedSchema": monolith.ElicitRequestedSchema, "v2026_07_28.Resources": monolith.ResourcesCapability, "v2026_07_28.Sampling": monolith.SamplingCapability, "v2026_07_28.Tools": monolith.ToolsCapability, @@ -63,30 +73,15 @@ SKIP: frozenset[str] = frozenset( { # v2025_11_25 - "v2025_11_25.AnyOfItem", - "v2025_11_25.BooleanSchema", "v2025_11_25.Error1", "v2025_11_25.Icons", "v2025_11_25.InputSchema", - "v2025_11_25.Items", - "v2025_11_25.Items1", - "v2025_11_25.LegacyTitledEnumSchema", "v2025_11_25.Meta", - "v2025_11_25.NumberSchema", - "v2025_11_25.OneOfItem", "v2025_11_25.OutputSchema", - "v2025_11_25.RequestedSchema", "v2025_11_25.ResourceRequestParams", - "v2025_11_25.StringSchema", "v2025_11_25.TaskAugmentedRequestParams", - "v2025_11_25.TitledMultiSelectEnumSchema", - "v2025_11_25.TitledSingleSelectEnumSchema", "v2025_11_25.URLElicitationRequiredError", - "v2025_11_25.UntitledMultiSelectEnumSchema", - "v2025_11_25.UntitledSingleSelectEnumSchema", # v2026_07_28 - "v2026_07_28.AnyOfItem", - "v2026_07_28.BooleanSchema", "v2026_07_28.CallToolResultResponse", "v2026_07_28.ClientNotification", "v2026_07_28.CompleteResultResponse", @@ -101,9 +96,6 @@ "v2026_07_28.InternalError", "v2026_07_28.InvalidParamsError", "v2026_07_28.InvalidRequestError", - "v2026_07_28.Items", - "v2026_07_28.Items1", - "v2026_07_28.LegacyTitledEnumSchema", "v2026_07_28.ListPromptsResultResponse", "v2026_07_28.ListResourceTemplatesResultResponse", "v2026_07_28.ListResourcesResultResponse", @@ -112,22 +104,14 @@ "v2026_07_28.MethodNotFoundError", "v2026_07_28.MissingRequiredClientCapabilityError", "v2026_07_28.NotificationMetaObject", - "v2026_07_28.NumberSchema", - "v2026_07_28.OneOfItem", "v2026_07_28.OutputSchema", "v2026_07_28.Params", "v2026_07_28.ParseError", "v2026_07_28.ReadResourceResultResponse", "v2026_07_28.RequestMetaObject", - "v2026_07_28.RequestedSchema", "v2026_07_28.ResourceRequestParams", - "v2026_07_28.StringSchema", "v2026_07_28.SubscriptionsListenResultMeta", - "v2026_07_28.TitledMultiSelectEnumSchema", - "v2026_07_28.TitledSingleSelectEnumSchema", "v2026_07_28.UnsupportedProtocolVersionError", - "v2026_07_28.UntitledMultiSelectEnumSchema", - "v2026_07_28.UntitledSingleSelectEnumSchema", } ) From c2b3e8ee170e73a852f6d380d1e1c4a3580f6198 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 26 Jun 2026 23:23:48 +0000 Subject: [PATCH 13/14] Record two divergences as intentional, ecosystem-consistent choices Two interaction-suite divergence notes described their gaps as unenforced spec MUSTs without saying whether closing them was planned. Both are deliberate, and the notes now say so. transport:stdio:stream-purity -- stdio_server does not redirect sys.stdout, so a handler print() corrupts the protocol stream. No MCP SDK redirects stdout, and a redirect would only catch print(), not os.write(1, ...) or C extensions writing to file descriptor 1, so it would be a partial guard rather than a structural fix. docs/tutorial/logging.md already tells server authors to log to stderr and never print() in a stdio server, so no docs change was needed. protocol:progress:monotonic -- no MCP SDK validates sender-side progress monotonicity; the spec MUST is a contract on the handler author, not on the transport, and the test pins the unvalidated pass-through. No code or test behaviour changes. --- tests/interaction/_requirements.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 28a1b27b9..c8986f28d 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -656,10 +656,11 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "Intentional, not a gap to close: no MCP SDK (typescript, go, csharp) validates " - "sender-side progress monotonicity, and this one does not either. The spec MUST " - "binds the handler author, not the transport; non-increasing values are forwarded " - "to the callback unchanged so the receiving application sees what the sender sent. " - "docs/tutorial/progress.md states the author's obligation." + "sender-side progress monotonicity, and this one does not either. The spec MUST is " + "a contract on the handler author, not on the transport; non-increasing values are " + "forwarded to the callback unchanged so the receiving application sees what the " + "sender sent, and the test pins that pass-through. docs/tutorial/progress.md " + "states the author's obligation." ), ), ), @@ -3651,9 +3652,13 @@ def __post_init__(self) -> None: note="Only observable over stdio: stdin/stdout purity is stdio-specific.", divergence=Divergence( note=( - "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " - "handler code that calls print() writes directly to the protocol stream and corrupts the " - "framing. The spec MUST is satisfied only as long as application code behaves." + "Intentional, not a gap to close: the SDK's own writes are pure, but sys.stdout is not " + "redirected, so handler code that calls print() writes into the protocol stream and " + "corrupts the framing. This matches every other MCP SDK -- none redirects stdout. A " + "redirect would only catch print(): os.write(1, ...) and C extensions that write to file " + "descriptor 1 bypass sys.stdout entirely, so it would be a partial guard rather than a " + "structural fix. docs/tutorial/logging.md tells server authors to log to stderr and " + "never print() in a stdio server." ), ), ), From c53aefd2931f85193c18e2487aed5212966ead98 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sun, 28 Jun 2026 11:30:26 +0000 Subject: [PATCH 14/14] Close cancelled HTTP exchanges and harden auth validation Review-feedback round on the conformance burn-down: - Cancelled requests no longer leave the legacy streamable-HTTP POST hanging. The dispatcher emits a RequestSettled marker when a handler is cancelled without producing a response; the transport consumes it by closing the per-request stream, so the POST's SSE stream terminates without a response frame and JSON-response mode completes with 204 No Content (the client treats 202/204 alike). Per-request streams are released instead of leaking until session teardown, and a handler that survives the cancellation still delivers its normal response. The marker is type-visible on the dispatcher write stream and is stripped by every serializing transport, so it can never appear on a wire. - A bearer token whose audience cannot be canonicalized (out-of-range or non-numeric port) is now rejected with the standard 401 invalid_token instead of raising through the auth middleware as a 500. - The bundled authorization server's /register now accepts only https redirect URIs or http on a loopback host; other schemes on loopback hosts (ftp, ws, javascript, custom) are rejected. - OAuth client scope selection falls back to the caller-configured OAuthClientMetadata.scope when neither the WWW-Authenticate challenge nor protected-resource metadata names scopes, matching the TypeScript SDK, so the documented migration path works as written. - The cross-dispatcher contract that handler-raised MCPError subclasses surface to callers as plain MCPError is now pinned by an explicit test and documented; rehydrate with from_error when the subclass matters. - Docs: migration notes for the bearer-challenge wire-shape changes and the cancellation wire spellings; story READMEs updated to the landed error contract; strict-capabilities doc corrected to state that resources/unsubscribe is gated by the base resources capability only. --- docs/client/index.md | 2 +- docs/migration.md | 18 +- .../mcp_simple_auth_client/main.py | 6 +- examples/stories/error_handling/README.md | 7 +- examples/stories/streaming/README.md | 8 +- src/mcp/client/__main__.py | 6 +- src/mcp/client/_transport.py | 7 +- src/mcp/client/auth/oauth2.py | 6 +- src/mcp/client/auth/utils.py | 15 +- src/mcp/client/session.py | 6 +- src/mcp/client/sse.py | 6 +- src/mcp/client/stdio.py | 6 +- src/mcp/client/streamable_http.py | 23 +- src/mcp/server/auth/handlers/register.py | 17 +- src/mcp/server/lowlevel/server.py | 6 +- src/mcp/server/runner.py | 6 +- src/mcp/server/sse.py | 6 +- src/mcp/server/stdio.py | 6 +- src/mcp/server/streamable_http.py | 38 +- src/mcp/shared/auth_utils.py | 22 +- src/mcp/shared/direct_dispatcher.py | 5 + src/mcp/shared/jsonrpc_dispatcher.py | 65 ++- src/mcp/shared/memory.py | 18 +- src/mcp/shared/message.py | 27 +- tests/client/test_auth.py | 245 +++++++++++ tests/client/test_notification_response.py | 12 +- tests/client/test_session.py | 114 +++-- tests/client/test_stdio.py | 4 +- tests/interaction/_helpers.py | 22 +- tests/interaction/_requirements.py | 55 +++ tests/interaction/auth/test_as_handlers.py | 59 ++- .../interaction/auth/test_authorize_token.py | 110 ++++- tests/interaction/auth/test_bearer.py | 37 ++ .../transports/test_hosting_http.py | 229 ++++++++++ .../transports/test_streamable_http.py | 59 +++ tests/issues/test_192_request_id.py | 4 +- tests/issues/test_88_random_error.py | 14 +- tests/server/test_cancel_handling.py | 6 +- tests/server/test_lifespan.py | 10 +- .../test_lowlevel_exception_handling.py | 4 +- tests/shared/conftest.py | 6 +- tests/shared/test_auth_utils.py | 44 ++ tests/shared/test_dispatcher.py | 42 +- tests/shared/test_jsonrpc_dispatcher.py | 392 ++++++++++++------ tests/shared/test_message.py | 24 ++ tests/shared/test_streamable_http.py | 92 +++- 46 files changed, 1611 insertions(+), 305 deletions(-) create mode 100644 tests/shared/test_message.py diff --git a/docs/client/index.md b/docs/client/index.md index 2e74e900d..634825aa3 100644 --- a/docs/client/index.md +++ b/docs/client/index.md @@ -145,7 +145,7 @@ The resource verbs come in pairs: two ways to list, one way to read. `read_resource` returns `contents`, a list of `TextResourceContents` or `BlobResourceContents`. Same idea as tool content: narrow with `isinstance`, then read `.text` (or `.blob`). -A client can also **subscribe** to a resource and be told when it changes: `subscribe_resource(uri)` and `unsubscribe_resource(uri)`, same shape as everything else here. `MCPServer` doesn't implement that half. It says so up front (`server_capabilities.resources.subscribe` is `False`) and answers the request with an `MCPError`: `-32601`, *Method not found*. With `strict_capabilities=True` you get the same `-32601` without the round trip: the client sees `server_capabilities.resources.subscribe` is falsy and never sends the request. A server that does support subscriptions is built on the low-level `Server` (**The low-level Server**). +A client can also **subscribe** to a resource and be told when it changes: `subscribe_resource(uri)` and `unsubscribe_resource(uri)`, same shape as everything else here. `MCPServer` doesn't implement that half. It says so up front (`server_capabilities.resources.subscribe` is `False`) and answers the request with an `MCPError`: `-32601`, *Method not found*. With `strict_capabilities=True` you get the same `-32601` for `subscribe_resource` without the round trip: the client sees `server_capabilities.resources.subscribe` is falsy and never sends the request. `unsubscribe_resource` is still sent — only the base `resources` capability gates it, matching the TypeScript SDK — so its `-32601` comes from the server. A server that does support subscriptions is built on the low-level `Server` (**The low-level Server**). ## Prompts diff --git a/docs/migration.md b/docs/migration.md index c530682db..bad79afb5 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -41,13 +41,19 @@ An unhandled exception in a request handler now produces JSON-RPC error `-32603` code `0` with `str(exc)` as the message, leaking handler internals to the peer; the exception is still logged server-side via `logger.exception`. To send a specific code and message, raise `MCPError` (unchanged); a pydantic `ValidationError` is -still mapped to `INVALID_PARAMS`. +still mapped to `INVALID_PARAMS`. An `MCPError` *subclass* raised by a handler +(e.g. `UrlElicitationRequiredError`) reaches the caller as plain `MCPError` with +the same `code`, `message`, and `data` on every dispatch path — including the +in-process `Client(server)` — so catch `MCPError` and match on `error.code` +rather than on the subclass type. A request cancelled via `notifications/cancelled` now receives no response at all, per the spec's SHOULD. v1 answered the cancelled request with an error (`code=0, message="Request cancelled"`). The sender's awaiting call already fails with anyio cancellation when its scope is cancelled, so no reply is needed to -unblock it. +unblock it. On legacy streamable HTTP this means the cancelled request's POST SSE +stream now terminates without a response frame, and in JSON-response mode the POST +completes with `204 No Content`. ### A second `initialize` on an already-initialized session is rejected @@ -148,7 +154,7 @@ When no authorization-server metadata document could be discovered at all, the f The specification's scope-selection chain is two steps: the `scope` parameter from the `WWW-Authenticate` challenge, then `scopes_supported` from the Protected Resource Metadata document, *otherwise the `scope` parameter is omitted*. The SDK inserted an extra fallback between those two steps — the **authorization-server** metadata's `scopes_supported` — which over-requests (an authorization server may serve many resource servers, so its list is a superset of any one resource's) and caused real `access_denied` failures ([#1307](https://github.com/modelcontextprotocol/python-sdk/issues/1307)). That fallback is removed: when neither the challenge nor the PRM names scopes, the client now omits the `scope` parameter and lets the authorization server apply its defaults. -This also affects the SEP-2207 `offline_access` augmentation, which only fires once a base scope was selected: if the authorization server's `scopes_supported` was your only scope source, the client now sends no `scope` at all (not even `offline_access`) and the authorization server's defaults decide whether a refresh token is issued. In either case, if you relied on the removed fallback, pass an explicit `scope` on the `OAuthClientMetadata` you give to `OAuthClientProvider`. +This also affects the SEP-2207 `offline_access` augmentation, which only fires once a base scope was selected: if the authorization server's `scopes_supported` was your only scope source, the client now sends no `scope` at all (not even `offline_access`) and the authorization server's defaults decide whether a refresh token is issued. In either case, if you relied on the removed fallback, pass an explicit `scope` on the `OAuthClientMetadata` you give to `OAuthClientProvider`. That explicit scope ranks *below* the spec's two sources (matching the TypeScript SDK): a `scope` on the `WWW-Authenticate` challenge or a PRM `scopes_supported` still wins, so the explicit value only takes effect when both are silent. ### `get_session_id` callback removed from `streamable_http_client` @@ -1583,6 +1589,12 @@ Leaving `resource_server_url=None` continues to disable the check entirely (ther `RefreshToken` gains an optional `resource` field so an `OAuthAuthorizationServerProvider` can propagate the original grant's audience binding through `exchange_refresh_token`; without it a refreshed access token would carry no audience and be rejected. `BearerAuthBackend.__init__` gains a keyword-only `resource_server_url: AnyHttpUrl | None = None`, wired automatically from `AuthSettings.enforced_audience`; `None` (the default, and what the SDK passes when `verifier_validates_audience` is set) means no audience is enforced. +The error responses the bearer gate sends changed shape in the same release, following RFC 6750 §3: + +- A request presenting **no credentials at all** (no `Authorization: Bearer` header) is now answered with a bare challenge — `WWW-Authenticate: Bearer scope="…", resource_metadata="…"` and an empty JSON body `{}` — instead of `error="invalid_token", error_description="Authentication required"` in both the header and the body. RFC 6750 §3.1 says the `error` attribute should only appear when the request actually carried a token, so "no token" and "rejected token" are now distinguishable; anything matching on the literal `Authentication required` string or expecting a non-empty 401 body must be updated. +- Every challenge — the `401`s and the `403` — now advertises the configured `required_scopes` in an RFC 6750 `scope="…"` parameter, which spec-conformant clients (including this SDK's OAuth client) read as the first step of scope selection and step-up authorization. +- The `error_description` strings changed. A rejected token now states the failure: `The access token is malformed or unknown`, `The access token has expired`, `The access token carries no audience claim`, or `The access token was issued for a different resource` (previously all of these — where they were rejected at all — said `Authentication required`). The `403 insufficient_scope` description is now the fixed string `The access token lacks a required scope` instead of `Required scope: {scope}`; the required scope moved to the machine-readable `scope=` challenge parameter. Treat `error_description` as human-readable prose, not a contract. + ### Bundled authorization server: RFC-correct redirect-URI handling Two fixes to the optional bundled OAuth authorization server (the `auth_server_provider=` path). diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 0d461d5d1..ce3906c4b 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -24,7 +24,7 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage class InMemoryTokenStorage(TokenStorage): @@ -248,8 +248,8 @@ async def _default_redirect_handler(authorization_url: str) -> None: async def _run_session( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception | RequestSettled], + write_stream: WriteStream[SessionMessage | RequestSettled], ): """Run the MCP session with the given streams.""" print("🤝 Initializing MCP session...") diff --git a/examples/stories/error_handling/README.md b/examples/stories/error_handling/README.md index 475a2a0b2..f7ce3313b 100644 --- a/examples/stories/error_handling/README.md +++ b/examples/stories/error_handling/README.md @@ -35,9 +35,10 @@ uv run python -m stories.error_handling.client --http --server server_lowlevel ## Caveats - The "any other exception → `is_error` result" contract on `MCPServer` and the - "uncaught exception → `code=0`" behaviour on `lowlevel.Server` are **not - shown** — the contract is under design and the legacy code is a known spec - divergence. This story will grow those cases once the contract lands. + "uncaught exception → `-32603` `Internal server error`" behaviour on + `lowlevel.Server` are **not shown** here. The lowlevel reply is deliberately + opaque — handler internals never reach the peer; the exception is logged + server-side. - `MCPServer` prefixes the execution-error message with `"Error executing tool {name}: "`; build a `CallToolResult` directly from a lowlevel handler if you need verbatim control. diff --git a/examples/stories/streaming/README.md b/examples/stories/streaming/README.md index e6bedb915..493a5f2fb 100644 --- a/examples/stories/streaming/README.md +++ b/examples/stories/streaming/README.md @@ -56,10 +56,10 @@ uv run python -m stories.streaming.client --http --server server_lowlevel OpenTelemetry instead of `notifications/message`. It is shown here because servers still need to support 2025-era clients during that window. Progress and cancellation are **not** deprecated. TODO(maxisbey): revisit before beta. -- When a request is cancelled the server currently replies with - `ErrorData(code=0, message="Request cancelled")`; the spec says it should not - reply at all. The client never observes it (its awaiting task is already - cancelled), so this story does not assert on the reply. +- When a request is cancelled the server sends no reply at all, per the spec's + SHOULD. The client's awaiting task is already cancelled, so there is nothing + to observe on that call; this story asserts only that the in-flight call was + cancelled and that the session survives for a follow-up call. ## Spec diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 5fa3ce109..81ddce9f8 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -12,7 +12,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage from mcp.shared.session import RequestResponder if not sys.warnoptions: @@ -33,8 +33,8 @@ async def message_handler( async def run_session( - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception | RequestSettled], + write_stream: WriteStream[SessionMessage | RequestSettled], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/_transport.py b/src/mcp/client/_transport.py index 0163fef95..9ffca455a 100644 --- a/src/mcp/client/_transport.py +++ b/src/mcp/client/_transport.py @@ -6,11 +6,14 @@ from typing import Protocol from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage __all__ = ["ReadStream", "WriteStream", "Transport", "TransportStreams"] -TransportStreams = tuple[ReadStream[SessionMessage | Exception], WriteStream[SessionMessage]] +TransportStreams = tuple[ + ReadStream[SessionMessage | Exception | RequestSettled], + WriteStream[SessionMessage | RequestSettled], +] class Transport(AbstractAsyncContextManager[TransportStreams], Protocol): diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index f929decc4..feeb8f8af 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -645,12 +645,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_info = None self.context.clear_tokens() - # Step 3: Apply scope selection strategy + # Step 3: Apply scope selection strategy. The configured client-metadata + # scope is the lowest-priority fallback, and the selection is written back + # so registration (Step 4) and authorization (Step 5) read it — a later + # re-auth therefore falls back to this selection, not the constructor value. self.context.client_metadata.scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata, self.context.oauth_metadata, self.context.client_metadata.grant_types, + self.context.client_metadata.scope, ) # Step 4: Register client or use URL-based client ID (CIMD) diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 848ef7360..c1becc33a 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -105,18 +105,25 @@ def get_client_metadata_scopes( protected_resource_metadata: ProtectedResourceMetadata | None, authorization_server_metadata: OAuthMetadata | None = None, client_grant_types: list[str] | None = None, + client_metadata_scope: str | None = None, ) -> str | None: - """Select effective scopes and augment for refresh token support.""" - selected_scope: str | None = None + """Select effective scopes and augment for refresh token support. - # MCP spec scope selection priority: + Follows the MCP spec's scope-selection strategy (challenge scope, then PRM + `scopes_supported`), with the caller's pre-configured `OAuthClientMetadata.scope` as a + final SDK-defined fallback (TypeScript-SDK parity) before omitting the parameter. + """ + # Scope selection priority (1-2 are the spec's chain; 3 is the SDK fallback): # 1. WWW-Authenticate header scope # 2. PRM scopes_supported - # 3. Omit scope parameter + # 3. Caller-supplied client metadata scope + # 4. Omit scope parameter if www_authenticate_scope is not None: selected_scope = www_authenticate_scope elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: selected_scope = " ".join(protected_resource_metadata.scopes_supported) + else: + selected_scope = client_metadata_scope # SEP-2207: append offline_access when the AS supports it and the client can use refresh tokens if ( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 64968856d..8e15eb70d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -45,7 +45,7 @@ x_mcp_header_map, ) from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, RequestSettled, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext @@ -215,8 +215,8 @@ class ClientSession: def __init__( self, - read_stream: ReadStream[SessionMessage | Exception] | None = None, - write_stream: WriteStream[SessionMessage] | None = None, + read_stream: ReadStream[SessionMessage | Exception | RequestSettled] | None = None, + write_stream: WriteStream[SessionMessage | RequestSettled] | None = None, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 8b482932a..3bdd5f0de 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -13,7 +13,7 @@ from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import create_context_streams from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage, wire_messages logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ async def sse_client( logger.debug("SSE connection established") read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) - write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage | RequestSettled](0) async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): try: @@ -132,7 +132,7 @@ async def _send_message(session_message: SessionMessage) -> None: response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") - async for session_message in write_stream_reader: + async for session_message in wire_messages(write_stream_reader): sender_ctx = write_stream_reader.last_context if sender_ctx is not None: async with anyio.create_task_group() as tg: diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 3e03eef9e..ed2ce8b69 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -32,7 +32,7 @@ get_windows_executable_command, terminate_windows_process_tree, ) -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage, wire_messages logger = logging.getLogger(__name__) @@ -133,7 +133,7 @@ async def stdio_client( # The spawn succeeded; no awaits until the task group is entered, or a # cancellation delivered in the gap would leak the live process. read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage | RequestSettled](0) shutting_down = False writer_done = anyio.Event() @@ -170,7 +170,7 @@ async def stdin_writer() -> None: try: async with write_stream_reader: - async for session_message in write_stream_reader: + async for session_message in wire_messages(write_stream_reader): json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) data = (json + "\n").encode(encoding=server.encoding, errors=server.encoding_error_handler) await process.stdin.send(data) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index f28eb7c7a..285e2bfd6 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -33,7 +33,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER -from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, RequestSettled, SessionMessage, wire_messages logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ # TODO(Marcelo): Put the TransportStreams in a module under shared, so we can import here. SessionMessageOrError = SessionMessage | Exception StreamWriter = ContextSendStream[SessionMessageOrError] -StreamReader = ContextReceiveStream[SessionMessage] +StreamReader = ContextReceiveStream[SessionMessage | RequestSettled] MCP_SESSION_ID = "mcp-session-id" LAST_EVENT_ID = "last-event-id" @@ -265,8 +265,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: json=message.model_dump(by_alias=True, mode="json", exclude_unset=True), headers=headers, ) as response: - if response.status_code == 202: - logger.debug("Received 202 Accepted") + if response.status_code in (202, 204): + # 202: notification/response accepted. 204: the server settled the + # request with no reply (peer-cancelled); nothing to deliver — the + # waiter was retired at cancel time. + logger.debug(f"Received {response.status_code}") return if response.status_code >= 400: @@ -377,8 +380,10 @@ async def _handle_sse_response( except Exception: logger.debug("SSE stream ended", exc_info=True) # pragma: lax no cover - # Stream ended without response - reconnect if we received an event with ID - if last_event_id is not None: # pragma: no branch + # Stream ended without response - reconnect if we received an event with ID. + # No priming event (no event store) means no id: the server settled the + # request without a reply and ended the stream, so the task just returns. + if last_event_id is not None: logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) @@ -445,7 +450,7 @@ async def post_writer( client: httpx.AsyncClient, write_stream_reader: StreamReader, read_stream_writer: StreamWriter, - write_stream: ContextSendStream[SessionMessage], + write_stream: ContextSendStream[SessionMessage | RequestSettled], start_get_stream: Callable[[], None], tg: TaskGroup, ) -> None: @@ -490,7 +495,7 @@ async def handle_request_async(): else: await handle_request_async() - async for session_message in write_stream_reader: + async for session_message in wire_messages(write_stream_reader): sender_ctx = write_stream_reader.last_context if sender_ctx is not None: async with anyio.create_task_group() as tg_local: @@ -568,7 +573,7 @@ async def streamable_http_client( await stack.enter_async_context(client) read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) - write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage | RequestSettled](0) async with ( read_stream_writer, diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 59ea456ef..b6f762694 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -23,20 +23,21 @@ class RegistrationRequest(OAuthClientMetadata): """The registration endpoint's inbound client metadata, with server-side redirect-URI policy. - The MCP authorization spec requires every redirect URI to use HTTPS or target a loopback - host, and OAuth 2.1 section 2.3 forbids a fragment component, so a request carrying a URI - that violates either fails validation and never reaches the provider. The base - `OAuthClientMetadata` stays permissive: the client also serializes it when registering - against third-party authorization servers whose redirect-URI policies the SDK does not own. + The MCP authorization spec and OAuth 2.1 require every redirect URI to use HTTPS, with plain + HTTP on a loopback host as the sole carve-out, and OAuth 2.1 section 2.3 forbids a fragment + component, so a request carrying a URI that violates either fails validation and never reaches + the provider. The base `OAuthClientMetadata` stays permissive: the client also serializes it + when registering against third-party authorization servers whose redirect-URI policies the SDK + does not own. """ @field_validator("redirect_uris", mode="after") @classmethod - def _https_or_loopback_without_fragment(cls, v: list[AnyUrl] | None) -> list[AnyUrl] | None: + def _https_or_loopback_http_without_fragment(cls, v: list[AnyUrl] | None) -> list[AnyUrl] | None: # None and an empty list both mean there is nothing to check. for uri in v or []: - if uri.scheme != "https" and uri.host not in LOOPBACK_HOSTS: - raise ValueError(f"redirect_uri must use https or target a loopback host: {uri}") + if not (uri.scheme == "https" or (uri.scheme == "http" and uri.host in LOOPBACK_HOSTS)): + raise ValueError(f"redirect_uri must use https or be a loopback http URI: {uri}") # `is not None`, not truthiness: a bare `https://x/cb#` parses with fragment == "". if uri.fragment is not None: raise ValueError(f"redirect_uri must not include a fragment: {uri}") diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dcfa2974a..6a6cb0bd1 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -67,7 +67,7 @@ async def main(): from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.exceptions import MCPDeprecationWarning -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage logger = logging.getLogger(__name__) @@ -633,8 +633,8 @@ def session_manager(self) -> StreamableHTTPSessionManager: async def run( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception | RequestSettled], + write_stream: WriteStream[SessionMessage | RequestSettled], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 1149bf2be..10bc380c3 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -49,7 +49,7 @@ from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import RequestSettled, ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext if TYPE_CHECKING: @@ -395,8 +395,8 @@ async def serve_connection( async def serve_loop( server: Server[LifespanT], - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception | RequestSettled], + write_stream: WriteStream[SessionMessage | RequestSettled], *, lifespan_state: LifespanT, session_id: str | None = None, diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 4d02fc4a7..8058d027f 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -56,7 +56,7 @@ async def handle_sse(request): TransportSecuritySettings, ) from mcp.shared._context_streams import ContextSendStream, create_context_streams -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import RequestSettled, ServerMessageMetadata, SessionMessage, wire_messages logger = logging.getLogger(__name__) @@ -136,7 +136,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): logger.debug("Setting up SSE connection") read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) - write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage | RequestSettled](0) session_id = uuid4() user = scope.get("user") @@ -168,7 +168,7 @@ async def sse_writer(): await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) logger.debug(f"Sent endpoint event: {client_post_uri_data}") - async for session_message in write_stream_reader: + async for session_message in wire_messages(write_stream_reader): logger.debug(f"Sending message via SSE: {session_message}") await sse_stream_writer.send( { diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 876d256dd..77a2c71f0 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -26,7 +26,7 @@ async def run_server(): import mcp_types as types from mcp.shared._context_streams import create_context_streams -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage, wire_messages @asynccontextmanager @@ -44,7 +44,7 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio. stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) - write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage | RequestSettled](0) async def stdin_reader(): try: @@ -64,7 +64,7 @@ async def stdin_reader(): async def stdout_writer(): try: async with write_stream_reader: - async for session_message in write_stream_reader: + async for session_message in wire_messages(write_stream_reader): json = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await stdout.write(json + "\n") await stdout.flush() diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index d316345c7..32505e4a3 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -44,7 +44,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import RequestSettled, ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -149,8 +149,8 @@ class StreamableHTTPServerTransport: # Server notification streams for POST requests as well as standalone SSE stream _read_stream_writer: ContextSendStream[SessionMessage | Exception] | None = None _read_stream: ContextReceiveStream[SessionMessage | Exception] | None = None - _write_stream: ContextSendStream[SessionMessage] | None = None - _write_stream_reader: ContextReceiveStream[SessionMessage] | None = None + _write_stream: ContextSendStream[SessionMessage | RequestSettled] | None = None + _write_stream_reader: ContextReceiveStream[SessionMessage | RequestSettled] | None = None _security: TransportSecurityMiddleware def __init__( @@ -564,6 +564,9 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re response_message = None # Use similar approach to SSE writer for consistency + # (no branch: the natural-exit arc IS taken — a settled request closes the + # stream and the loop ends without a response — but coverage.py loses the + # async-for exhaustion arc on Python 3.14.) async for event_message in request_stream_reader: # pragma: no branch # If it's a response, this is what we're waiting for if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): @@ -573,19 +576,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re else: # pragma: no cover logger.debug(f"received: {event_message.message.method}") - # At this point we should have a response if response_message: # Create JSON response response = self._create_json_response(response_message) - await response(scope, receive, send) - else: # pragma: no cover - # This shouldn't happen in normal operation - logger.error("No response message received before stream closed") - response = self._create_error_response( - "Error processing request: No response received", - HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) + else: + # The request ended with no JSON-RPC frame to carry (peer-cancelled + # and the handler wrote nothing, or the session tore down + # mid-request): end the exchange with 204 No Content. + response = self._create_json_response(None, HTTPStatus.NO_CONTENT) + await response(scope, receive, send) except Exception: # pragma: no cover logger.exception("Error processing JSON response") response = self._create_error_response( @@ -952,7 +951,7 @@ async def connect( ) -> AsyncGenerator[ tuple[ ReadStream[SessionMessage | Exception], - WriteStream[SessionMessage], + WriteStream[SessionMessage | RequestSettled], ], None, ]: @@ -965,7 +964,7 @@ async def connect( # Create the memory streams for this connection read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) - write_stream, write_stream_reader = create_context_streams[SessionMessage](0) + write_stream, write_stream_reader = create_context_streams[SessionMessage | RequestSettled](0) # Store the streams self._read_stream_writer = read_stream_writer @@ -979,6 +978,15 @@ async def connect( async def message_router(): try: async for session_message in write_stream_reader: # pragma: no branch + if isinstance(session_message, RequestSettled): + # The dispatcher ended this request with no reply (peer-cancelled). + # Close the per-POST stream's send side: the SSE writer / JSON wait + # loop finish their iteration and clean up in their own finally. + # Nothing is stored or sent on the wire. Keyed by str(id), matching + # the registration at POST time. + if (streams := self._request_streams.get(str(session_message.request_id))) is not None: + streams[0].close() + continue # Determine which request stream(s) should receive this message message = session_message.message target_request_id = None diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index d1e3b3c82..9ea1690e4 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -20,6 +20,11 @@ def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: Returns: Canonical resource URL string + + Raises: + ValueError: If the URL's port is non-numeric or out of range. RFC 3986's + grammar puts no upper bound on port digits, so such URLs can arrive + from outside; callers passing untrusted input must handle this. """ # Convert to string if needed url_str = str(url) @@ -80,9 +85,20 @@ def check_token_audience(token_resource: str, server_resource: str | HttpUrl | A origin is NOT for this server. Contrast check_resource_allowed, which is the client-side hierarchical question and intentionally more permissive. """ - return resource_url_from_server_url(token_resource).rstrip("/") == resource_url_from_server_url( - server_resource - ).rstrip("/") + try: + token_canonical = resource_url_from_server_url(token_resource) + except ValueError: + # An audience we cannot canonicalize does not identify this server. The + # server side stays unwrapped: it is AnyHttpUrl-validated at config time, + # and a garbage own-config URL should fail loudly, not silently 401. + return False + # The rstrip is deliberate trailing-slash tolerance, not 3986 equivalence: + # authorization.mdx's canonical-URI note expects both spellings of one resource + # to circulate (recommending the slashless form for interop), and pydantic's + # AnyHttpUrl forces a root slash (str(AnyHttpUrl("https://h")) == "https://h/") + # while the spec's own example token request sends resource=https://h — without + # this, every root-path deployment would 401 spec-conformant clients. + return token_canonical.rstrip("/") == resource_url_from_server_url(server_resource).rstrip("/") def calculate_token_expiry(expires_in: int | str | None) -> float | None: diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index ec67248c5..09d23e93f 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -107,6 +107,11 @@ class DirectDispatcher: inbound requests fail the peer's call the same way instead of invoking the handler. Notifications are fire-and-forget in both directions: after close they are silently dropped. + + Handler-raised `MCPError` subclasses flatten to plain `MCPError` with equal + `ErrorData` on every dispatch path, matching the wire, where subclass + identity cannot survive; callers needing the subclass rehydrate it from + `MCPError.error` (e.g. `UrlElicitationRequiredError.from_error`). """ def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions: bool = True): diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 37ef6721d..af63732d8 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -44,6 +44,7 @@ from mcp.shared.message import ( ClientMessageMetadata, MessageMetadata, + RequestSettled, ServerMessageMetadata, SessionMessage, ) @@ -259,8 +260,8 @@ class JSONRPCDispatcher(Dispatcher[TransportT]): def __init__( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception | RequestSettled], + write_stream: WriteStream[SessionMessage | RequestSettled], *, transport_builder: Callable[[MessageMetadata], TransportT] | None = None, peer_cancel_mode: PeerCancelMode = "interrupt", @@ -501,7 +502,7 @@ async def run( async def _dispatch( self, - item: SessionMessage | Exception, + item: SessionMessage | Exception | RequestSettled, on_request: OnRequest, on_notify: OnNotify, sender_ctx: contextvars.Context | None, @@ -511,6 +512,14 @@ async def _dispatch( Only `inline_methods` requests and the `on_stream_exception` observer are awaited; any other `await` would head-of-line block the read loop. """ + if isinstance(item, RequestSettled): + # Peer dispatcher ended a request without a reply. Reaches us only over + # the direct in-memory pair (wire transports strip it); nothing to do — + # the side that cancelled retired its waiter at cancel time, and a + # hand-built cancel that didn't gets the same silence as every wire + # transport. + logger.debug("peer settled request %r without a response", item.request_id) + return if isinstance(item, Exception): if self.on_stream_exception is None: logger.debug("transport yielded exception: %r", item) @@ -687,6 +696,10 @@ async def _handle_request( The single exception-to-wire boundary: handler exceptions become `JSONRPCError` here. """ answer_write_started = False + # Tuple-bound so the absorbed-cancel path, which continues past the + # `with scope:` with `result` unbound, has a sound "did the handler + # return?" signal. + produced: tuple[dict[str, Any]] | None = None try: with scope: try: @@ -699,21 +712,22 @@ async def _handle_request( key = _coerce_id(req.id) if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: del self._in_flight[key] - # A write interrupted by cancellation may still have delivered - # (a memory-stream send can hand its item to the receiver and - # still raise), so a started answer write counts as sent below: - # peers drop late responses, while a second answer for one id - # would break JSON-RPC. + produced = (result,) + # Past the scope: a still-pending peer cancel was revoked with the + # scope (anyio drops an undelivered cancellation at scope exit), so a + # handler that survived the cancel and returned gets its response + # delivered — the spec's MAY-ignore-and-respond arm. This is the one + # place a peer-cancelled request can still produce a reply; everywhere + # else the cancel interrupted the handler and nothing is written for + # the id. + if produced is not None: + # A teardown cancel can interrupt the write and may still have + # delivered (a memory-stream send can hand its item to the + # receiver and still raise), so a started answer write counts as + # sent below: peers drop late responses, while a second answer + # for one id would break JSON-RPC. answer_write_started = True - await self._write_result(req.id, result) - if scope.cancelled_caught: - # anyio absorbs the scope's own cancel at __exit__, and - # `cancelled_caught` (unlike `cancel_called`) guarantees the - # result write above did not happen. Spec: receivers SHOULD NOT - # respond to a cancelled request; the sender retired its own - # waiter when it cancelled (`send_raw_request` finally), so no - # reply is needed to unblock it. - return + await self._write_result(req.id, produced[0]) except anyio.get_cancelled_exc_class(): # Shutdown: answer the request so the peer isn't left waiting - unless # an answer write already started (it may have reached the transport; @@ -734,6 +748,23 @@ async def _handle_request( await self._write_error(req.id, error) if unexpected and self._raise_handler_exceptions: raise + if scope.cancelled_caught: + # The peer cancel interrupted the handler. With the result write + # outside the scope and no checkpoint between handler return and scope + # exit, `cancelled_caught` structurally means nothing was produced or + # written for this id. Spec: receivers SHOULD NOT respond to a + # cancelled request; the sender retired its own waiter when it + # cancelled (`send_raw_request` finally), so no reply is needed to + # unblock it. The settled marker (never serialized) lets transports + # release per-request resources — without it the legacy + # streamable-HTTP POST waits forever for a reply that never comes. + # Sits outside the try: the send parks on a 0-buffer stream, and a + # teardown cancel landing mid-send must propagate instead of being + # misread by the cancelled arm as an unanswered request. + try: + await self._write_stream.send(RequestSettled(request_id=req.id)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped settled marker for %r: write stream closed", req.id) # No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id. def _allocate_id(self) -> int: diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 01cab77c8..93fd47ee0 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -7,9 +7,15 @@ from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage -MessageStream = tuple[ContextReceiveStream[SessionMessage | Exception], ContextSendStream[SessionMessage | Exception]] +# The full triple union is real only here: the streams are cross-connected +# dispatcher-to-dispatcher with no pump task, so the peer's read loop is the +# consumer of `RequestSettled` markers (it drops them). +MessageStream = tuple[ + ContextReceiveStream[SessionMessage | Exception | RequestSettled], + ContextSendStream[SessionMessage | Exception | RequestSettled], +] @asynccontextmanager @@ -21,8 +27,12 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = create_context_streams[SessionMessage | Exception](1) - client_to_server_send, client_to_server_receive = create_context_streams[SessionMessage | Exception](1) + server_to_client_send, server_to_client_receive = create_context_streams[ + SessionMessage | Exception | RequestSettled + ](1) + client_to_server_send, client_to_server_receive = create_context_streams[ + SessionMessage | Exception | RequestSettled + ](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 236569fac..11504729b 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -4,7 +4,7 @@ to support transport-specific features like resumability. """ -from collections.abc import Awaitable, Callable +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from dataclasses import dataclass from typing import Any @@ -52,3 +52,28 @@ class SessionMessage: message: JSONRPCMessage metadata: MessageMetadata = None + + +@dataclass(slots=True, frozen=True) +class RequestSettled: + """An inbound request finished without any JSON-RPC reply being written. + + Emitted by the dispatcher (only) when a peer cancellation interrupted the + handler — the spec says receivers SHOULD NOT respond to a cancelled + request. Transport-internal: transports with per-request resources (the + legacy streamable-HTTP per-POST stream) consume it to end the exchange; + serializing transports strip it via `wire_messages`. It is never put on + any wire. + """ + + request_id: RequestId + + +async def wire_messages( + stream: AsyncIterable[SessionMessage | RequestSettled], +) -> AsyncIterator[SessionMessage]: + """Yield only serializable frames, stripping dispatcher lifecycle markers.""" + async for item in stream: + if isinstance(item, RequestSettled): + continue + yield item diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 322d02d22..abd5433b7 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1510,6 +1510,251 @@ async def mock_callback() -> AuthorizationCodeResult: pass +def test_scope_selection_falls_back_to_the_client_metadata_scope_when_challenge_and_prm_are_silent(): + """SDK-defined fallback (TypeScript-SDK parity): when neither the WWW-Authenticate challenge + nor the PRM names scopes, the caller's pre-configured client-metadata scope is selected + instead of omitting the parameter. + """ + assert get_client_metadata_scopes(None, None, client_metadata_scope="custom:scope") == "custom:scope" + + +def test_scope_selection_ranks_the_client_metadata_scope_below_challenge_and_prm_scopes(): + """Spec-mandated priority: the challenge scope and PRM `scopes_supported` both outrank the + caller's pre-configured client-metadata scope, which is only the final fallback. + """ + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + scopes_supported=["resource:read", "resource:write"], + ) + assert get_client_metadata_scopes("from:header", prm, client_metadata_scope="custom:scope") == "from:header" + assert get_client_metadata_scopes(None, prm, client_metadata_scope="custom:scope") == "resource:read resource:write" + + +def test_the_offline_access_augmentation_applies_to_the_fallback_client_metadata_scope(): + """SEP-2207: `offline_access` is appended to whichever scope was selected, including the + SDK-fallback client-metadata scope. + """ + asm = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + scopes_supported=["offline_access"], + ) + assert ( + get_client_metadata_scopes(None, None, asm, ["authorization_code", "refresh_token"], "custom:scope") + == "custom:scope offline_access" + ) + + +@pytest.mark.anyio +async def test_an_explicit_client_metadata_scope_survives_a_scopeless_challenge_and_prm( + oauth_provider: OAuthClientProvider, +): + """An explicit `OAuthClientMetadata.scope` ("read write" on the fixture) reaches the + registration body and the authorize URL when neither the 401 challenge nor the PRM names + scopes (SDK fallback, TypeScript-SDK parity). The AS metadata advertises a different + `scopes_supported` so a regression to the removed AS-metadata fallback fails the assertions. + + Steps: + 1. 401 challenge without `scope` + 2. PRM discovery -> no `scopes_supported` + 3. ASM discovery -> `scopes_supported: ["as-advertised"]`, S256 PKCE + 4. DCR -> registration body carries the explicit scope + 5. authorize redirect -> URL carries the explicit scope + 6. token exchange, retried request, flow completes + """ + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + captured_state: str | None = None + authorize_scope: str | None = None + + async def capture_redirect(url: str) -> None: + nonlocal captured_state, authorize_scope + params = parse_qs(urlparse(url).query) + authorize_scope = params["scope"][0] + captured_state = params.get("state", [None])[0] + + async def callback() -> AuthorizationCodeResult: + return AuthorizationCodeResult(code="auth_code", state=captured_state) + + oauth_provider.context.redirect_handler = capture_redirect + oauth_provider.context.callback_handler = callback + + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + await auth_flow.__anext__() + + response_401 = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + + prm_request = await auth_flow.asend(response_401) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + + asm_request = await auth_flow.asend(prm_response) + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register", ' + b'"scopes_supported": ["as-advertised"], ' + b'"code_challenge_methods_supported": ["S256"]}' + ), + request=asm_request, + ) + + registration_request = await auth_flow.asend(asm_response) + assert json.loads(registration_request.content)["scope"] == "read write" + registration_response = httpx.Response( + 201, + content=( + b'{"client_id": "test_client_id", "client_secret": "test_client_secret", ' + b'"redirect_uris": ["http://localhost:3030/callback"]}' + ), + request=registration_request, + ) + + token_request = await auth_flow.asend(registration_response) + assert authorize_scope == "read write" + token_response = httpx.Response( + 200, + json={"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600}, + request=token_request, + ) + final_request = await auth_flow.asend(token_response) + try: + await auth_flow.asend(httpx.Response(200, request=final_request)) + except StopAsyncIteration: + pass + + assert oauth_provider.context.client_metadata.scope == "read write" + + +@pytest.mark.anyio +async def test_a_challenged_scope_replaces_the_explicit_scope_and_seeds_the_next_reauth( + oauth_provider: OAuthClientProvider, +): + """A 401 challenge's scope wins over the explicit `OAuthClientMetadata.scope`, and the + selection is written back into `client_metadata.scope`: a second 401 whose challenge and PRM + are scope-silent falls back to the previously challenged scope, not the constructor value. + The write-back retention is intended behaviour — it mirrors the SEP-2350 step-up union on + the 403 path, which likewise folds prior selections forward. + + Steps: + 1. 401 with scope="granted:scope" -> authorize URL carries it, not the fixture's "read write" + 2. flow completes -> the selection is written back into client_metadata.scope + 3. second 401 without scope, PRM still scope-less -> authorize URL carries "granted:scope" + """ + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + captured_state: str | None = None + authorize_scopes: list[str] = [] + + async def capture_redirect(url: str) -> None: + nonlocal captured_state + params = parse_qs(urlparse(url).query) + authorize_scopes.append(params["scope"][0]) + captured_state = params.get("state", [None])[0] + + async def callback() -> AuthorizationCodeResult: + return AuthorizationCodeResult(code="auth_code", state=captured_state) + + oauth_provider.context.redirect_handler = capture_redirect + oauth_provider.context.callback_handler = callback + + prm_content = ( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}' + ) + asm_content = ( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register", ' + b'"code_challenge_methods_supported": ["S256"]}' + ) + + # First flow: the challenge names a scope, overriding the fixture's "read write". + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + await auth_flow.__anext__() + response_401 = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer scope="granted:scope", ' + 'resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + prm_request = await auth_flow.asend(response_401) + asm_request = await auth_flow.asend(httpx.Response(200, content=prm_content, request=prm_request)) + registration_request = await auth_flow.asend(httpx.Response(200, content=asm_content, request=asm_request)) + registration_response = httpx.Response( + 201, + content=( + b'{"client_id": "test_client_id", "client_secret": "test_client_secret", ' + b'"redirect_uris": ["http://localhost:3030/callback"]}' + ), + request=registration_request, + ) + token_request = await auth_flow.asend(registration_response) + token_response = httpx.Response( + 200, + json={"access_token": "token_1", "token_type": "Bearer", "expires_in": 3600}, + request=token_request, + ) + final_request = await auth_flow.asend(token_response) + try: + await auth_flow.asend(httpx.Response(200, request=final_request)) + except StopAsyncIteration: + pass + + assert authorize_scopes == ["granted:scope"] + assert oauth_provider.context.client_metadata.scope == "granted:scope" + + # Second flow: challenge and PRM are scope-silent, so the fallback reads the prior selection. + test_request_2 = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow_2 = oauth_provider.async_auth_flow(test_request_2) + await auth_flow_2.__anext__() + response_401_2 = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request_2, + ) + prm_request_2 = await auth_flow_2.asend(response_401_2) + asm_request_2 = await auth_flow_2.asend(httpx.Response(200, content=prm_content, request=prm_request_2)) + # The client is already registered, so the flow goes straight to authorization. + token_request_2 = await auth_flow_2.asend(httpx.Response(200, content=asm_content, request=asm_request_2)) + token_response_2 = httpx.Response( + 200, + json={"access_token": "token_2", "token_type": "Bearer", "expires_in": 3600}, + request=token_request_2, + ) + final_request_2 = await auth_flow_2.asend(token_response_2) + try: + await auth_flow_2.asend(httpx.Response(200, request=final_request_2)) + except StopAsyncIteration: + pass + + assert authorize_scopes == ["granted:scope", "granted:scope"] + + @pytest.mark.parametrize( ( "issuer_url", diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 418a6bc54..e5236709b 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -42,9 +42,11 @@ async def handle_mcp_request(request: Request) -> Response: if data.get("method") == "initialize": return _init_json_response(data) - # For notifications, return 204 No Content (non-SDK behavior) + # For notifications, return a bare 200 (non-SDK behavior; 204 now has a defined + # meaning on this transport — "request settled with no reply" — so the unexpected + # status exercised here is one with no assigned semantics). if "id" not in data: - return Response(status_code=204, headers={"Content-Type": "application/json"}) + return Response(status_code=200, headers={"Content-Type": "application/json"}) return JSONResponse( # pragma: no cover {"jsonrpc": "2.0", "id": data.get("id"), "error": {"code": -32601, "message": "Method not found"}} @@ -77,8 +79,8 @@ async def test_non_compliant_notification_response() -> None: The spec states notifications should get either 202 + no response body, or 4xx + optional error body (https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server), - but some servers wrongly return other 2xx codes (e.g. 204). For now we simply ignore unexpected responses - (aligning behaviour w/ the TS SDK). + but some servers wrongly return other 2xx codes (e.g. a bare 200). For now we simply ignore + unexpected responses (aligning behaviour w/ the TS SDK). """ returned_exception = None @@ -94,7 +96,7 @@ async def message_handler( # pragma: no cover async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: await session.initialize() - # The test server returns a 204 instead of the expected 202 + # The test server returns a bare 200 instead of the expected 202 await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) if returned_exception: # pragma: no cover diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 5eaa5eff7..d4301f81b 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -42,12 +42,12 @@ from mcp.server import Server, ServerRequestContext from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext _SendToClient = anyio.streams.memory.MemoryObjectSendStream[SessionMessage | Exception] -_RecvFromClient = anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage] +_RecvFromClient = anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage | RequestSettled] @asynccontextmanager @@ -60,7 +60,7 @@ async def raw_client_session( transport-level exceptions. No initialize handshake is performed. """ s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | RequestSettled](32) async with ClientSession(s2c_recv, c2s_send, **kwargs) as session: try: with anyio.fail_after(5): @@ -72,8 +72,12 @@ async def raw_client_session( @pytest.mark.anyio async def test_client_session_initialize(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) initialized_notification = None result = None @@ -82,6 +86,7 @@ async def mock_server(): nonlocal initialized_notification session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -113,6 +118,7 @@ async def mock_server(): ) ) session_notification = await client_to_server_receive.receive() + assert isinstance(session_notification, SessionMessage) jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification, JSONRPCNotification) initialized_notification = client_notification_adapter.validate_python( @@ -155,8 +161,12 @@ async def message_handler( # pragma: no cover @pytest.mark.anyio async def test_client_session_custom_client_info(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) custom_client_info = Implementation(name="test-client", version="1.2.3") received_client_info = None @@ -165,6 +175,7 @@ async def mock_server(): nonlocal received_client_info session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -213,8 +224,12 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_default_client_info(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) received_client_info = None @@ -222,6 +237,7 @@ async def mock_server(): nonlocal received_client_info session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -267,12 +283,17 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_success(): """Test successful version negotiation with supported version""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) result = None async def mock_server(): session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -323,11 +344,16 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_failure(): """Test version negotiation failure with unsupported version""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) async def mock_server(): session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -371,8 +397,12 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_default(): """Test that client capabilities are properly set with default callbacks""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) received_capabilities = None @@ -380,6 +410,7 @@ async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -427,8 +458,12 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_with_custom_callbacks(): """Test that client capabilities are properly set with custom callbacks""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) received_capabilities = None @@ -451,6 +486,7 @@ async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -511,8 +547,12 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_with_sampling_tools(): """Test that sampling capabilities with tools are properly advertised""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) received_capabilities = None @@ -530,6 +570,7 @@ async def mock_server(): nonlocal received_capabilities session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -585,8 +626,12 @@ async def mock_server(): @pytest.mark.anyio async def test_initialize_result(): """Test that initialize_result is None before init and contains the full result after.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) expected_capabilities = ServerCapabilities( logging=types.LoggingCapability(), @@ -599,6 +644,7 @@ async def test_initialize_result(): async def mock_server(): session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -658,14 +704,19 @@ async def mock_server(): @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None): """Test that client tool call requests can include metadata""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage | RequestSettled + ](1) mocked_tool = types.Tool(name="sample_tool", input_schema={"type": "object"}) async def mock_server(): # Receive initialization request from client session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) request = client_request_adapter.validate_python( @@ -695,6 +746,7 @@ async def mock_server(): # Wait for the client to send a 'tools/call' request session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -721,6 +773,7 @@ async def mock_server(): # Wait for the tools/list request from the client # The client requires this step to validate the tool output schema session_message = await client_to_server_receive.receive() + assert isinstance(session_message, SessionMessage) jsonrpc_request = session_message.message assert isinstance(jsonrpc_request, JSONRPCRequest) @@ -763,6 +816,7 @@ async def test_receive_loop_answers_malformed_inbound_request_with_invalid_param SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="sampling/createMessage", params={"broken": 1})) ) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCError) assert out.message.id == 7 assert out.message.error.code == INVALID_PARAMS @@ -774,6 +828,7 @@ async def test_receive_loop_answers_unknown_request_method_with_method_not_found async with raw_client_session() as (_session, to_client, from_client): await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown"))) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCError) assert out.message.id == 7 assert out.message.error == types.ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="x/unknown") @@ -787,6 +842,7 @@ async def test_receive_loop_drops_unknown_notification_method_without_response() # The answered follow-up ping proves no response was emitted and the loop survived. await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCResponse) assert out.message.id == 1 @@ -812,6 +868,7 @@ async def test_on_request_rejects_a_server_request_absent_at_the_negotiated_vers SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="elicitation/create", params={"message": "hi"})) ) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCError) assert out.message.error.code == METHOD_NOT_FOUND assert out.message.error.data == "elicitation/create" @@ -835,6 +892,7 @@ async def sampling( SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=2, method="sampling/createMessage", params=request_params)) ) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCResponse) assert out.message.result == {"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"} @@ -853,6 +911,7 @@ async def list_roots(ctx: ClientRequestContext) -> types.ListRootsResult | types async with raw_client_session(list_roots_callback=list_roots) as (_session, to_client, from_client): await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=3, method="roots/list"))) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCError) assert out.message.error.code == INTERNAL_ERROR assert out.message.error.message == "Client callback returned an invalid result" @@ -913,6 +972,7 @@ async def elicitation(ctx: ClientRequestContext, params: types.ElicitRequestPara SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=4, method="elicitation/create", params=request_params)) ) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCResponse) assert out.message.result == {"action": "accept", "content": {"x": 1}} assert len(seen) == 1 @@ -931,6 +991,7 @@ async def call() -> None: tg.start_soon(call) request = await from_client.receive() + assert isinstance(request, SessionMessage) assert isinstance(request.message, JSONRPCRequest) await to_client.send( SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request.message.id, result={"tools": "nope"})) @@ -951,6 +1012,7 @@ async def call() -> None: tg.start_soon(call) request = await from_client.receive() + assert isinstance(request, SessionMessage) assert isinstance(request.message, JSONRPCRequest) await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request.message.id, result={}))) @@ -973,6 +1035,7 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult: SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=8, method="sampling/createMessage", params=params)) ) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCError) assert out.message.error == types.ErrorData(code=INTERNAL_ERROR, message="Internal server error") @@ -1023,6 +1086,7 @@ async def handler(msg: object) -> None: await delivered.wait() await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=9, method="ping"))) out = await from_client.receive() + assert isinstance(out, SessionMessage) assert seen == [exc] assert isinstance(out.message, JSONRPCResponse) assert out.message.id == 9 @@ -1046,6 +1110,7 @@ async def handler(msg: object) -> None: await to_client.send(ValueError("bad bytes")) # Serve the handler's ping like a transport would; inline delivery would deadlock here. out = await from_client.receive() + assert isinstance(out, SessionMessage) assert isinstance(out.message, JSONRPCRequest) assert out.message.method == "ping" await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={}))) @@ -1126,6 +1191,7 @@ async def call() -> None: tg.start_soon(call) request = await from_client.receive() + assert isinstance(request, SessionMessage) assert isinstance(request.message, JSONRPCRequest) request_id = request.message.id # The request id doubles as the progress token. @@ -1436,7 +1502,7 @@ async def test_send_notification_after_close_is_dropped_silently(): """Post-close `send_notification` is fire-and-forget: the notification is dropped, not surfaced as a raw transport error (v1 leaked `anyio.ClosedResourceError`).""" s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | RequestSettled](4) try: async with ClientSession(s2c_recv, c2s_send) as session: pass diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 0b0695378..33eddee55 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -42,7 +42,7 @@ from mcp.os.posix.utilities import terminate_posix_process_tree from mcp.os.win32.utilities import FallbackProcess from mcp.shared.exceptions import MCPError -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage # --------------------------------------------------------------------------- # In-process fake of the spawned server process @@ -219,7 +219,7 @@ def _line(message: JSONRPCMessage) -> bytes: return (message.model_dump_json(by_alias=True, exclude_unset=True) + "\n").encode() -async def _next_message(read_stream: ReadStream[SessionMessage | Exception]) -> JSONRPCMessage: +async def _next_message(read_stream: ReadStream[SessionMessage | Exception | RequestSettled]) -> JSONRPCMessage: received = await read_stream.receive() assert isinstance(received, SessionMessage) return received.message diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py index 0641aeab9..90b644d3e 100644 --- a/tests/interaction/_helpers.py +++ b/tests/interaction/_helpers.py @@ -13,7 +13,7 @@ from typing_extensions import Self from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage from mcp.shared.session import RequestResponder # TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), @@ -28,12 +28,19 @@ class _RecordingReadStream: """Delegates to a read stream, appending every received message to a log.""" - def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + def __init__( + self, + inner: ReadStream[SessionMessage | Exception | RequestSettled], + log: list[SessionMessage | Exception], + ) -> None: self._inner = inner self._log = log - async def receive(self) -> SessionMessage | Exception: + async def receive(self) -> SessionMessage | Exception | RequestSettled: item = await self._inner.receive() + # None of the recorded suites cancel, so no `RequestSettled` ever crosses this seam; + # if one does, the recording (which exists to pin wire payloads) needs a story for it. + assert not isinstance(item, RequestSettled) self._log.append(item) return item @@ -43,7 +50,7 @@ async def aclose(self) -> None: def __aiter__(self) -> Self: return self - async def __anext__(self) -> SessionMessage | Exception: + async def __anext__(self) -> SessionMessage | Exception | RequestSettled: try: return await self.receive() except anyio.EndOfStream: @@ -62,13 +69,16 @@ async def __aexit__( class _RecordingWriteStream: """Delegates to a write stream, appending every sent message to a log.""" - def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + def __init__(self, inner: WriteStream[SessionMessage | RequestSettled], log: list[SessionMessage]) -> None: self._inner = inner self._log = log - async def send(self, item: SessionMessage, /) -> None: + async def send(self, item: SessionMessage | RequestSettled, /) -> None: # Record only after the inner send returns: a failed or cancelled send never reached the transport. await self._inner.send(item) + # None of the recorded suites cancel, so no `RequestSettled` ever crosses this seam + # (see `_RecordingReadStream.receive`). + assert isinstance(item, SessionMessage) self._log.append(item) async def aclose(self) -> None: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index c8986f28d..c2de7668e 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -2819,6 +2819,51 @@ def __post_init__(self) -> None: transports=("streamable-http",), note="Only observable over HTTP: POST-body framing is HTTP-specific.", ), + "hosting:http:cancel-ends-post-sse-stream": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "After notifications/cancelled stops a request's handler, the original POST's SSE stream " + "terminates without ever carrying a response for the cancelled id." + ), + transports=("streamable-http",), + note=( + "Only observable over HTTP: SSE stream lifecycle is HTTP-specific. The no-response half is " + "spec-mandated (receivers SHOULD NOT respond to a cancelled request); terminating the " + "now-permanently-silent stream is SDK-defined — the spec has no spelling for a request that " + "settles without a response." + ), + ), + "hosting:http:cancel-json-mode-204": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "In JSON response mode, a POST whose request was cancelled mid-handler completes with " + "204 No Content and an empty body instead of holding the connection open forever." + ), + transports=("streamable-http",), + note="Only observable over HTTP: 204 is an HTTP status code.", + divergence=Divergence( + note=( + "The transports section's MUST offers only text/event-stream or a single JSON object as " + "the response to a request POST; the spec has no spelling for 'request settled with no " + "response' in JSON mode, so the SDK answers 204 with no body. Tracked for upstreaming." + ), + ), + ), + "hosting:http:cancel-receipt-keeps-stream-open": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Receiving notifications/cancelled does not by itself end the request's exchange: a handler " + "that ignores the cancellation (the spec's MAY arm) still streams related notifications on " + "the open POST stream and has its response delivered; the exchange ends only after that " + "response." + ), + transports=("streamable-http",), + note=( + "Only observable over HTTP: exchange lifecycle is HTTP-specific. Shielding the handler body " + "(anyio.CancelScope(shield=True)) is the SDK-defined way for a handler to take the spec's " + "MAY-ignore-and-respond arm." + ), + ), "hosting:http:content-type-415": Requirement( source="sdk", behavior="A POST with a Content-Type other than application/json returns 415.", @@ -3085,6 +3130,16 @@ def __post_init__(self) -> None: # ═══════════════════════════════════════════════════════════════════════════ # Client transport: streamable HTTP # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:204-settled-exchange": Requirement( + source="sdk", + behavior=( + "A 204 No Content response to a request POST is consumed as 'request settled with no reply': " + "the transport's request task completes without synthesizing a response or an error, and the " + "session continues to serve requests." + ), + transports=("streamable-http",), + note="Only observable over HTTP: 204 is an HTTP status code.", + ), "client-transport:http:404-surfaces": Requirement( source="sdk", behavior="A 404 (session expired) on a request surfaces as an error to the caller.", diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py index 23fae5eca..5d5daf7bb 100644 --- a/tests/interaction/auth/test_as_handlers.py +++ b/tests/interaction/auth/test_as_handlers.py @@ -279,24 +279,45 @@ async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( @requirement("hosting:auth:as:redirect-uri-scheme") @pytest.mark.parametrize( - "redirect_uri", + ("redirect_uri", "rejection"), [ - "http://evil.example/callback", - "http://localhost.evil.example/callback", - "javascript:alert(1)", - "com.example.app:/oauth/cb", + ( + "http://evil.example/callback", + snapshot("redirect_uri must use https or be a loopback http URI: http://evil.example/callback"), + ), + ( + "http://localhost.evil.example/callback", + snapshot("redirect_uri must use https or be a loopback http URI: http://localhost.evil.example/callback"), + ), + ("javascript:alert(1)", snapshot("redirect_uri must use https or be a loopback http URI: javascript:alert(1)")), + ( + "com.example.app:/oauth/cb", + snapshot("redirect_uri must use https or be a loopback http URI: com.example.app:/oauth/cb"), + ), + ("ftp://127.0.0.1/cb", snapshot("redirect_uri must use https or be a loopback http URI: ftp://127.0.0.1/cb")), + ("ws://localhost/cb", snapshot("redirect_uri must use https or be a loopback http URI: ws://localhost/cb")), + ( + "javascript://localhost/%0aalert(1)", + snapshot("redirect_uri must use https or be a loopback http URI: javascript://localhost/%0aalert(1)"), + ), + ( + "custom-scheme://localhost/cb", + snapshot("redirect_uri must use https or be a loopback http URI: custom-scheme://localhost/cb"), + ), + ("ftp://[::1]/cb", snapshot("redirect_uri must use https or be a loopback http URI: ftp://[::1]/cb")), ], ) -async def test_a_redirect_uri_that_is_neither_https_nor_loopback_is_rejected_at_registration( - as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], redirect_uri: str +async def test_a_redirect_uri_that_is_neither_https_nor_loopback_http_is_rejected_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], redirect_uri: str, rejection: str ) -> None: - """A registration whose redirect URI is neither HTTPS nor a loopback host is rejected with 400. - - The spec requires every redirect URI to be either HTTPS or a loopback host; the - registration request model enforces this at parse time so the provider never sees the - client. Loopback is matched on the whole host (`localhost.evil.example` is not loopback), - and a scheme with no authority — `javascript:`, or an RFC 8252 private-use scheme such as - `com.example.app:` — fails the same check. + """A registration whose redirect URI is neither HTTPS nor loopback HTTP is rejected with 400. + + OAuth 2.1's only carve-out from HTTPS redirect URIs is plain HTTP on a loopback host; the + registration request model enforces scheme and host together at parse time so the provider + never sees the client. Loopback is matched on the whole host (`localhost.evil.example` is + not loopback), a loopback host does not launder a non-HTTP scheme (`ftp://127.0.0.1`, + `javascript://localhost`), and a scheme with no authority — `javascript:`, or an RFC 8252 + private-use scheme such as `com.example.app:` — fails the same check. """ http, provider = as_app body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) @@ -307,9 +328,9 @@ async def test_a_redirect_uri_that_is_neither_https_nor_loopback_is_rejected_at_ assert response.status_code == 400 error = response.json() assert error["error"] == "invalid_client_metadata" - # Pydantic frames the validator's message as `redirect_uris: Value error, ` (third-party - # text), so assert only the SDK-authored sentence to pin which validation fired. - assert "redirect_uri must use https or target a loopback host" in error["error_description"] + # Substring: pydantic wraps the SDK validator's sentence in its own framing + # (`redirect_uris: Value error, …`), which is deliberately not pinned. + assert rejection in error["error_description"] assert provider.clients == {} @@ -323,10 +344,10 @@ async def test_a_redirect_uri_that_is_neither_https_nor_loopback_is_rejected_at_ "http://[::1]:8000/callback", ], ) -async def test_an_https_or_loopback_redirect_uri_is_accepted_at_registration( +async def test_an_https_or_loopback_http_redirect_uri_is_accepted_at_registration( as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], redirect_uri: str ) -> None: - """A registration whose redirect URI uses HTTPS or targets a loopback host is accepted and stored. + """A registration whose redirect URI uses HTTPS or plain HTTP on a loopback host is accepted and stored. Loopback covers exactly the three forms OAuth 2.1 names: the hostname `localhost` and the loopback IP literals `127.0.0.1` and `[::1]`, on any port, over plain HTTP. diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py index fd645a37e..f0e5b117c 100644 --- a/tests/interaction/auth/test_authorize_token.py +++ b/tests/interaction/auth/test_authorize_token.py @@ -28,7 +28,12 @@ from mcp.client.auth import OAuthFlowError from mcp.server import Server, ServerRequestContext -from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata, ProtectedResourceMetadata +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + ProtectedResourceMetadata, +) from tests.interaction._connect import BASE_URL from tests.interaction._requirements import requirement from tests.interaction.auth._harness import ( @@ -406,6 +411,109 @@ async def test_scope_is_omitted_when_neither_the_challenge_nor_prm_supply_scopes assert "scope" not in json.loads(register.content) +@requirement("client-auth:scope-selection:priority") +async def test_an_explicit_client_metadata_scope_is_requested_when_challenge_and_prm_supply_none() -> None: + """When the challenge and the PRM are both scope-silent, the configured `OAuthClientMetadata.scope` + is requested instead of omitting the parameter. + + SDK-defined fallback (TypeScript-SDK parity), ranked one step below the spec's two-step + chain. The served AS metadata advertises a different `scopes_supported` so a regression to + the removed AS-metadata fallback fails the assertions; `valid_scopes` includes the explicit + scope so the SDK's registration handler accepts it, and token verification is off so the + granted scope need not satisfy the bearer gate. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(valid_scopes=["mcp", "custom:explicit"]) + client_metadata = OAuthClientMetadata( + client_name="interaction-suite", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="custom:explicit", + ) + challenge = f'Bearer resource_metadata="{BASE_URL}{PRM_PATH}"' + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + asm = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp", "as-advertised"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + serve = { + PRM_PATH: prm.model_dump_json(exclude_none=True).encode(), + ASM_PATH: asm.model_dump_json(exclude_none=True).encode(), + } + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + client_metadata=client_metadata, + verify_tokens=False, + app_shim=lambda app: first_challenge_shim(challenge)(shimmed_app(app, serve=serve)), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "custom:explicit" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "custom:explicit" + + +@requirement("client-auth:scope-selection:priority") +async def test_prm_scopes_win_over_an_explicit_client_metadata_scope() -> None: + """`scopes_supported` from the PRM outranks the configured `OAuthClientMetadata.scope`. + + The spec's chain is consulted first; the explicit scope is only the final fallback. The + challenge is scope-less (hand-supplied via `first_challenge_shim` because the real bearer + middleware always emits `scope=`), making the PRM the highest non-silent source. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(valid_scopes=["mcp", "from-prm"]) + client_metadata = OAuthClientMetadata( + client_name="interaction-suite", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="custom:explicit", + ) + challenge = f'Bearer resource_metadata="{BASE_URL}{PRM_PATH}"' + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), + authorization_servers=[AnyHttpUrl(BASE_URL)], + scopes_supported=["from-prm"], + ) + serve = {PRM_PATH: prm.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + client_metadata=client_metadata, + verify_tokens=False, + app_shim=lambda app: first_challenge_shim(challenge)(shimmed_app(app, serve=serve)), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "from-prm" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "from-prm" + + @requirement("client-auth:pkce:refuse-if-unsupported") async def test_the_flow_aborts_before_any_authorize_redirect_when_as_metadata_omits_code_challenge_methods() -> None: """AS metadata without `code_challenge_methods_supported` aborts the flow before any authorize redirect. diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py index 483cd1087..88384c96c 100644 --- a/tests/interaction/auth/test_bearer.py +++ b/tests/interaction/auth/test_bearer.py @@ -57,6 +57,20 @@ expires_at=_FUTURE, resource="http://127.0.0.1:8000/", ), + "tok-overflow-port-aud": AccessToken( + token="tok-overflow-port-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="http://127.0.0.1:99999/mcp", + ), + "tok-nonnumeric-port-aud": AccessToken( + token="tok-nonnumeric-port-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="http://127.0.0.1:abc/mcp", + ), "tok-no-aud": AccessToken(token="tok-no-aud", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), } @@ -211,6 +225,29 @@ async def test_a_token_for_a_parent_path_on_the_same_origin_is_answered_401_inva } +@requirement("hosting:auth:aud-validation") +@pytest.mark.parametrize("bearer", ["tok-overflow-port-aud", "tok-nonnumeric-port-aud"]) +async def test_a_token_whose_audience_has_an_unparseable_port_is_answered_401_invalid_token( + protected: httpx.AsyncClient, bearer: str +) -> None: + """A token audience with an out-of-range or non-numeric port is a mismatch, not a server error. + + RFC 3986's grammar puts no upper bound on port digits, so an authorization server can + legitimately issue a token for `http://h:99999/mcp`; urllib refuses to parse such ports. + The gate must answer the same 401 `invalid_token` as any other audience mismatch — the + ValueError must never escape the middleware as a 500. + """ + response = await post_mcp(protected, bearer=bearer) + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "The access token was issued for a different resource", + "scope": REQUIRED_SCOPE, + "resource_metadata": RESOURCE_METADATA_URL, + } + + @requirement("hosting:auth:aud-validation") async def test_a_token_without_a_resource_claim_is_answered_401_invalid_token( protected: httpx.AsyncClient, diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 6331c2dae..21547279d 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -7,6 +7,7 @@ """ import anyio +import httpx import pytest from anyio.lowlevel import checkpoint from httpx_sse import ServerSentEvent, aconnect_sse @@ -20,6 +21,7 @@ UNSUPPORTED_PROTOCOL_VERSION, CallToolRequestParams, CallToolResult, + CancelledNotificationParams, EmptyResult, JSONRPCError, JSONRPCNotification, @@ -27,6 +29,8 @@ JSONRPCResponse, ListResourcesResult, ListToolsResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, PaginatedRequestParams, SetLevelRequestParams, SubscribeRequestParams, @@ -379,3 +383,228 @@ async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> No assert [event async for event in unguarded.aiter_sse()] assert status == 200 + + +def _tool_call_body(request_id: int, name: str) -> dict[str, object]: + """A wire-level tools/call JSON-RPC request body.""" + return {"jsonrpc": "2.0", "id": request_id, "method": "tools/call", "params": {"name": name, "arguments": {}}} + + +def _cancel_body(request_id: int) -> dict[str, object]: + """A wire-level notifications/cancelled body, exactly as an SDK client would send it.""" + return {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": request_id}} + + +def _blocking_server() -> tuple[Server, anyio.Event, anyio.Event]: + """A server with one tool that blocks until cancelled, plus its started/cancelled events.""" + started = anyio.Event() + cancelled = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + started.set() + try: + await anyio.Event().wait() # blocks until the cancellation interrupts it + except anyio.get_cancelled_exc_class(): + cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + return Server("blocker", on_call_tool=call_tool), started, cancelled + + +async def _initialize_via_http_json(http: httpx.AsyncClient) -> str: + """`initialize_via_http` for a json_response=True server: the answers are JSON bodies, not SSE.""" + response = await http.post("/mcp", json=initialize_body(), headers=base_headers()) + assert response.status_code == 200 + assert JSONRPCResponse.model_validate(response.json()).id == 1 + session_id = response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + +@requirement("hosting:http:cancel-ends-post-sse-stream") +async def test_cancelling_an_in_flight_request_ends_its_post_sse_stream() -> None: + """After notifications/cancelled stops the handler, the original POST's SSE stream terminates + without carrying any frame: no response for the cancelled id is ever sent (spec-mandated) and + the exchange ends instead of holding the connection open forever (SDK-defined). Raw httpx + because the SDK client retires its waiter at cancel time and never observes the POST stream's + fate.""" + server, started, cancelled = _blocking_server() + events: list[ServerSentEvent] = [] + async with mounted_app(server) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def hold_post_open() -> None: + async with aconnect_sse( + http, + "POST", + "/mcp", + json=_tool_call_body(2, "block"), + headers=base_headers(session_id=session_id), + ) as source: + assert source.response.status_code == 200 + events.extend([event async for event in source.aiter_sse()]) + + tg.start_soon(hold_post_open) + await started.wait() + cancel = await http.post("/mcp", json=_cancel_body(2), headers=base_headers(session_id=session_id)) + assert cancel.status_code == 202 + await cancelled.wait() + # The task-group join is the regression assertion: the POST stream must + # terminate once the request settles; before the settled marker existed it + # stayed open forever and the enclosing fail_after fired. + assert parse_sse_messages(events) == [] + + +@requirement("hosting:http:cancel-json-mode-204") +async def test_cancelling_an_in_flight_request_in_json_mode_completes_the_post_with_204() -> None: + """In JSON response mode the cancelled request's POST completes with 204 No Content and an + empty body (SDK-defined; see the requirement's divergence note for the spec gap), and the + session keeps answering subsequent requests. (The release of the per-request stream the + wait loop parked on is pinned in `tests/shared/test_streamable_http.py`.)""" + server, started, cancelled = _blocking_server() + responses: list[httpx.Response] = [] + async with mounted_app(server, json_response=True) as (http, _): + session_id = await _initialize_via_http_json(http) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def post_blocked_request() -> None: + responses.append( + await http.post( + "/mcp", json=_tool_call_body(2, "block"), headers=base_headers(session_id=session_id) + ) + ) + + tg.start_soon(post_blocked_request) + await started.wait() + cancel = await http.post("/mcp", json=_cancel_body(2), headers=base_headers(session_id=session_id)) + assert cancel.status_code == 202 + await cancelled.wait() + # tg join: at HEAD the POST never completed and the fail_after fired here. + (response,) = responses + assert (response.status_code, response.content) == (204, b"") + # The settled exchange is final, not fatal: the same session still answers. + ping = await http.post( + "/mcp", json={"jsonrpc": "2.0", "id": 3, "method": "ping"}, headers=base_headers(session_id=session_id) + ) + assert ping.status_code == 200 + assert JSONRPCResponse.model_validate(ping.json()).id == 3 + + +@requirement("hosting:http:cancel-receipt-keeps-stream-open") +async def test_cancel_receipt_does_not_end_the_exchange_for_a_handler_that_ignores_it() -> None: + """Steps: + + 1. The tool handler shields its body, taking the spec's MAY-ignore-the-cancellation arm; it + wakes only after the cancel has landed on its request scope. + 2. The cancel POST returns 202, then the related log notification arrives on the still-open + POST SSE stream — receipt of notifications/cancelled did not close the exchange. + 3. The handler's response is delivered after it (the pending cancel cannot eat the result + write), and the stream terminates only then — wire proof the settled marker did not fire. + """ + started = anyio.Event() + cancel_delivered = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "stubborn" + with anyio.CancelScope(shield=True): + started.set() + await cancel_delivered.wait() + await ctx.session.send_notification( + LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="still working")), + related_request_id=ctx.request_id, + ) + return CallToolResult(content=[TextContent(text="finished anyway")]) + + async def on_cancelled(ctx: ServerRequestContext, params: CancelledNotificationParams) -> None: + # Runs after the dispatcher already applied the cancel to the request's scope, + # so the shielded handler provably survives a landed cancellation, not a race. + cancel_delivered.set() + + server = Server("stubborn", on_call_tool=call_tool) + server.add_notification_handler("notifications/cancelled", CancelledNotificationParams, on_cancelled) + + events: list[ServerSentEvent] = [] + async with mounted_app(server) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def hold_post_open() -> None: + async with aconnect_sse( + http, + "POST", + "/mcp", + json=_tool_call_body(2, "stubborn"), + headers=base_headers(session_id=session_id), + ) as source: + assert source.response.status_code == 200 + events.extend([event async for event in source.aiter_sse()]) + + tg.start_soon(hold_post_open) + await started.wait() + cancel = await http.post("/mcp", json=_cancel_body(2), headers=base_headers(session_id=session_id)) + assert cancel.status_code == 202 + # tg join: the stream ends, but only after the response below arrived. + messages = parse_sse_messages(events) + assert [type(m).__name__ for m in messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) + notification, response = messages + assert isinstance(notification, JSONRPCNotification) + assert notification.method == "notifications/message" + assert (notification.params or {})["data"] == "still working" + assert isinstance(response, JSONRPCResponse) + assert response.id == 2 + assert response.result["content"][0]["text"] == "finished anyway" + + +@requirement("hosting:http:cancel-receipt-keeps-stream-open") +async def test_handler_that_ignores_the_cancel_in_json_mode_still_gets_its_response_delivered() -> None: + """The MAY-ignore arm in JSON response mode: the POST completes with the handler's real 200 + JSON response, not the settled exchange's 204.""" + started = anyio.Event() + cancel_delivered = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "stubborn" + with anyio.CancelScope(shield=True): + started.set() + await cancel_delivered.wait() + return CallToolResult(content=[TextContent(text="finished anyway")]) + + async def on_cancelled(ctx: ServerRequestContext, params: CancelledNotificationParams) -> None: + # Runs after the dispatcher already applied the cancel to the request's scope. + cancel_delivered.set() + + server = Server("stubborn", on_call_tool=call_tool) + server.add_notification_handler("notifications/cancelled", CancelledNotificationParams, on_cancelled) + responses: list[httpx.Response] = [] + async with mounted_app(server, json_response=True) as (http, _): + session_id = await _initialize_via_http_json(http) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def post_stubborn_request() -> None: + responses.append( + await http.post( + "/mcp", json=_tool_call_body(2, "stubborn"), headers=base_headers(session_id=session_id) + ) + ) + + tg.start_soon(post_stubborn_request) + await started.wait() + cancel = await http.post("/mcp", json=_cancel_body(2), headers=base_headers(session_id=session_id)) + assert cancel.status_code == 202 + (response,) = responses + assert response.status_code == 200 + body = JSONRPCResponse.model_validate(response.json()) + assert body.id == 2 + assert body.result["content"][0]["text"] == "finished anyway" diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index 779a46054..18ae86e20 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -168,3 +168,62 @@ async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) ) assert [params.message for params in asked] == snapshot(["Proceed?"]) + + +@requirement("client-transport:http:204-settled-exchange") +async def test_cancelled_call_in_json_mode_settles_cleanly_and_the_session_keeps_serving( + caplog: pytest.LogCaptureFixture, +) -> None: + """Scope-cancelling an in-flight call over JSON response mode ends with the server settling + the request as 204 No Content: the transport's parked POST task consumes it as 'settled, + nothing to deliver' (no response is synthesized from the empty body), proved by a follow-up + call on the same session succeeding after the 204 has been consumed.""" + mcp = MCPServer("cancellable") + started = anyio.Event() + handler_cancelled = anyio.Event() + + @mcp.tool() + async def block() -> str: + """Block until the cancellation interrupts the handler.""" + started.set() + try: + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + @mcp.tool() + def echo(text: str) -> str: + """Echo the text back.""" + return text + + async with connect_over_streamable_http(mcp, json_response=True) as client: + with anyio.fail_after(5): + with anyio.CancelScope() as scope: + async with anyio.create_task_group() as task_group: + + async def call() -> None: + await client.call_tool("block", {}) + raise NotImplementedError # unreachable: the scope is cancelled + + task_group.start_soon(call) + await started.wait() + scope.cancel() + assert scope.cancelled_caught + await handler_cancelled.wait() + # Quiesce so the parked POST task has consumed the server's settle (the 204) before + # the follow-up call and the caplog check below — otherwise the test could finish + # before the would-be parse even happened. + await anyio.wait_all_tasks_blocked() + result = await client.call_tool("echo", {"text": "still here"}) + + # The follow-up's success on the same session is the proof the settled exchange left the + # transport serviceable. + assert result == snapshot( + CallToolResult(content=[TextContent(text="still here")], structured_content={"result": "still here"}) + ) + # Secondary, meaningful only after the quiesce above: the consumed 204 was not fed to the + # JSON body parser (which would log this error and synthesize a PARSE_ERROR reply for the + # already-retired request id). + assert "Error parsing JSON response" not in caplog.text diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index bc2df01cd..a6fb80d06 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -14,7 +14,7 @@ from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage @pytest.mark.anyio @@ -25,7 +25,7 @@ async def test_request_id_match() -> None: # Create memory streams for communication client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) - server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) # Server task to process the request async def run_server(): diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 409ac96a2..44779a29c 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -19,7 +19,7 @@ from mcp.client.session import ClientSession from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage @pytest.mark.anyio @@ -72,8 +72,8 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar server = Server(name="test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) async def server_handler( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | RequestSettled], + write_stream: MemoryObjectSendStream[SessionMessage | RequestSettled], task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, ): with anyio.CancelScope() as scope: @@ -86,8 +86,8 @@ async def server_handler( ) async def client( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | RequestSettled], + write_stream: MemoryObjectSendStream[SessionMessage | RequestSettled], scope: anyio.CancelScope, ): # No session-level timeout to avoid race conditions with fast operations @@ -115,8 +115,8 @@ async def client( scope.cancel() # pragma: lax no cover # Run server and client in separate task groups to avoid cancellation - server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage](1) - client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage](1) + server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage | RequestSettled](1) + client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage | RequestSettled](1) async with anyio.create_task_group() as tg: scope = await tg.start(server_handler, server_reader, client_writer) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index e7fed4073..42bf2324c 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -20,7 +20,7 @@ from mcp import Client from mcp.server import Server, ServerRequestContext -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage @pytest.mark.anyio @@ -107,7 +107,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar server = Server("test", on_call_tool=handle_call_tool) to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) - server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage | RequestSettled](10) async def run_server(): await server.run(server_read, server_write, server.create_initialization_options()) @@ -181,7 +181,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar server = Server("test", on_call_tool=handle_call_tool) to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) - server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage | RequestSettled](10) async def run_server(): await server.run(server_read, server_write, server.create_initialization_options()) diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 4cfff47c3..eea4ed78c 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -23,7 +23,7 @@ from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.mcpserver import Context, MCPServer from mcp.server.models import InitializationOptions -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage @pytest.mark.anyio @@ -53,7 +53,7 @@ async def check_lifespan( # Create memory streams for testing send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage | RequestSettled](100) # Run server in background task async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: @@ -92,6 +92,7 @@ async def run_server(): ) ) response = await receive_stream2.receive() + assert isinstance(response, SessionMessage) response = response.message # Send initialized notification @@ -111,6 +112,7 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() + assert isinstance(response, SessionMessage) response = response.message assert isinstance(response, JSONRPCMessage) assert isinstance(response, JSONRPCResponse) @@ -138,7 +140,7 @@ async def test_lifespan(server: MCPServer) -> AsyncIterator[dict[str, bool]]: # Create memory streams for testing send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage | RequestSettled](100) # Add a tool that checks lifespan context @server.tool() @@ -179,6 +181,7 @@ async def run_server(): ) ) response = await receive_stream2.receive() + assert isinstance(response, SessionMessage) response = response.message # Send initialized notification @@ -198,6 +201,7 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() + assert isinstance(response, SessionMessage) response = response.message assert isinstance(response, JSONRPCMessage) assert isinstance(response, JSONRPCResponse) diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 15df7f1ce..176ca90ce 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -2,7 +2,7 @@ import pytest from mcp.server.lowlevel.server import Server -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage @pytest.mark.anyio @@ -30,7 +30,7 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo # `async with read_stream, write_stream:` block and closes the stream, at # which point the blocked send raises ClosedResourceError. This # deterministically reproduces the race without sleeps. - write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) + write_send, write_recv = anyio.create_memory_object_stream[SessionMessage | RequestSettled](0) # What the streamable HTTP transport does: push the exception, then close. read_send.send_nowait(RuntimeError("simulated transport error")) diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py index 7b53b4265..e07626eac 100644 --- a/tests/shared/conftest.py +++ b/tests/shared/conftest.py @@ -13,7 +13,7 @@ from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import Dispatcher from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import SessionMessage +from mcp.shared.message import RequestSettled, SessionMessage from mcp.shared.transport_context import TransportContext DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] @@ -32,8 +32,8 @@ def close() -> None: def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) def builder(_meta: object) -> TransportContext: return TransportContext(kind="jsonrpc", can_send_request=can_send_request) diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 23f1e0620..4d2201ff4 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -1,5 +1,6 @@ """Tests for OAuth 2.0 Resource Indicators utilities.""" +import pytest from pydantic import HttpUrl from mcp.shared.auth_utils import check_resource_allowed, check_token_audience, resource_url_from_server_url @@ -51,6 +52,36 @@ def test_check_token_audience_ignores_default_port(): assert check_token_audience("https://h:8443/mcp", "https://h/mcp") is False +def test_check_token_audience_treats_an_unparseable_audience_as_a_mismatch(): + """A token audience whose port cannot be parsed does not identify this server. + + SDK-defined: RFC 3986's grammar puts no upper bound on port digits, so an AS can + legitimately issue a token for `https://h:99999/mcp`; urllib refuses to parse such + ports, and that canonicalization failure must read as a mismatch, not an error. + """ + assert check_token_audience("https://h:99999/mcp", "https://h/mcp") is False + assert check_token_audience("https://h:abc/mcp", "https://h/mcp") is False + + +def test_check_token_audience_treats_trailing_slash_variants_as_one_resource(): + """`https://h/api/` and `https://h/api` are the same audience, in either direction. + + SDK-defined interop tolerance per authorization.mdx's canonical-URI note (both + spellings of one resource circulate; the slashless form is merely recommended), and + required at root because pydantic's `AnyHttpUrl` renders `https://h` as `https://h/` + while the spec's example token request sends the slashless form. + """ + assert check_token_audience("https://h/api/", "https://h/api") is True + assert check_token_audience("https://h/api", "https://h/api/") is True + assert check_token_audience("https://h", "https://h/") is True + + +def test_check_token_audience_rejects_sibling_and_child_paths(): + """Trailing-slash tolerance does not loosen path equality: siblings and children mismatch.""" + assert check_token_audience("https://h/api123", "https://h/api") is False + assert check_token_audience("https://h/api/sub", "https://h/api") is False + + def test_resource_url_from_server_url_lowercase_scheme_and_host(): """Scheme and host should be lowercase for canonical form.""" assert resource_url_from_server_url("HTTPS://EXAMPLE.COM/path") == "https://example.com/path" @@ -63,6 +94,19 @@ def test_resource_url_from_server_url_handles_pydantic_urls(): assert resource_url_from_server_url(url) == "https://example.com/path" +def test_resource_url_from_server_url_raises_on_unparseable_port(): + """An out-of-range or non-numeric port raises ValueError, as documented. + + SDK-defined: the canonicalizer stays strict for its trusted own-config callers; + `check_token_audience` wraps the untrusted token side. The message is urllib's, + so only the exception type is pinned. + """ + with pytest.raises(ValueError): + resource_url_from_server_url("https://example.com:99999/mcp") + with pytest.raises(ValueError): + resource_url_from_server_url("https://example.com:abc/mcp") + + # Tests for check_resource_allowed function diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 23c588bdc..b560c0c8e 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -18,6 +18,7 @@ INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT, + ElicitRequestURLParams, ErrorData, Tool, ) @@ -25,7 +26,7 @@ from mcp.shared._compat import resync_tracer from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound -from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.exceptions import MCPError, NoBackChannelError, UrlElicitationRequiredError from mcp.shared.transport_context import TransportContext from .conftest import PairFactory, direct_pair @@ -95,7 +96,11 @@ async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): +async def test_send_raw_request_surfaces_handler_mcperror_code_and_message(pair_factory: PairFactory): + """A handler-raised `MCPError`'s code and message surface to the caller on every + dispatcher (SDK-defined). The caller gets an equal-valued re-raise, not the + handler's exception object — see the subclass-flattening test below.""" + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -108,6 +113,39 @@ async def on_request( assert exc.value.error.message == "bad cursor" +@pytest.mark.anyio +async def test_send_raw_request_flattens_handler_mcperror_subclass_to_plain_mcperror(pair_factory: PairFactory): + """A handler-raised `MCPError` subclass surfaces to the caller as plain `MCPError` + with equal `ErrorData` on every dispatcher (SDK-defined): subclass identity cannot + cross the JSON-RPC wire, and `DirectDispatcher` matches so in-process callers see + the same error surface. Callers needing the subclass rehydrate it from the error + data (e.g. `UrlElicitationRequiredError.from_error`).""" + raised: list[UrlElicitationRequiredError] = [] + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + error = UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required", + url="https://example.com/authorize", + elicitation_id="auth-001", + ) + ] + ) + raised.append(error) + raise error + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/call", {}) + # Flattened: exactly MCPError, the subclass type does not survive dispatch. + assert type(exc.value) is MCPError + # ...but the full ErrorData (code/message/data) of the raised subclass does. + assert exc.value.error == raised[0].error + + @pytest.mark.anyio async def test_send_raw_request_maps_validation_error_to_invalid_params(pair_factory: PairFactory): """A pydantic `ValidationError` from the handler surfaces as the diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 79f45542e..c65c7f06c 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -43,7 +43,13 @@ _Pending, _plan_outbound, ) -from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + RequestSettled, + ServerMessageMetadata, + SessionMessage, +) from mcp.shared.transport_context import TransportContext from .conftest import jsonrpc_pair @@ -56,9 +62,9 @@ class RecordingWriteStream: """Records sends without a checkpoint, so a pending cancellation cannot interrupt the write or mask it.""" def __init__(self) -> None: - self.sent: list[SessionMessage] = [] + self.sent: list[SessionMessage | RequestSettled] = [] - async def send(self, item: SessionMessage) -> None: + async def send(self, item: SessionMessage | RequestSettled) -> None: self.sent.append(item) async def aclose(self) -> None: @@ -119,15 +125,16 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio -async def test_peer_cancel_interrupt_mode_writes_no_response(): - """Spec-mandated: a peer-cancelled request is interrupted and the receiver writes no response. +async def test_peer_cancel_interrupt_mode_writes_no_response_and_emits_settled_marker(): + """A peer-cancelled request is interrupted, writes no JSON-RPC response (spec-mandated), and + emits a `RequestSettled` marker so transports can release per-request resources (SDK-defined). Scripted at the wire: the handler-exit event proves the cancel reached the running handler; - a follow-up request's response being the *first* thing on the ordered server→client stream - proves nothing was emitted for the cancelled id. + the marker being the only thing emitted before a follow-up probe's response proves nothing + was written for the cancelled id. """ - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() handler_exited = anyio.Event() @@ -162,12 +169,16 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) await handler_exited.wait() assert seen_ctx[0].cancel_requested.is_set() - await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="probe"))) + # Receive the marker before sending the probe: it must arrive with no + # subsequent traffic required, and nothing may precede it. first = await s2c_recv.receive() - assert isinstance(first, SessionMessage) - assert isinstance(first.message, JSONRPCResponse) - assert first.message.id == 2 - assert first.message.result == {"ok": True} + assert first == RequestSettled(request_id=1) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="probe"))) + second = await s2c_recv.receive() + assert isinstance(second, SessionMessage) + assert isinstance(second.message, JSONRPCResponse) + assert second.message.id == 2 + assert second.message.result == {"ok": True} tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -176,10 +187,11 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_peer_cancel_landing_after_handlers_last_checkpoint_writes_only_the_result(): - """A peer cancel that fails to interrupt the handler writes only the result: one answer per - id goes on the wire (SDK-defined). The recording stream is needed because a memory stream's - `send` checkpoints, letting the deferred cancellation land mid-write and hide a double answer.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + """A peer cancel that fails to interrupt the handler writes only the result — one answer per + id, and no settled marker (SDK-defined). The result write runs outside the request scope, so + the deferred cancellation is revoked at scope exit and cannot eat the write at its leading + checkpoint; the recording stream is now just an ordered recorder.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) recording = RecordingWriteStream() server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) handler_started = anyio.Event() @@ -212,11 +224,144 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> finally: c2s_send.close() c2s_recv.close() - assert [m.message for m in recording.sent] == [ - JSONRPCResponse(jsonrpc="2.0", id=1, result={"completed": "after-cancel"}) + assert recording.sent == [ + SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=1, result={"completed": "after-cancel"})) + ] + + +@pytest.mark.anyio +async def test_settled_marker_follows_the_handlers_related_notifications(): + """The settled marker rides the same FIFO as the handler's own traffic: a notification sent + before the cancel landed precedes the `RequestSettled`, so a transport consuming the marker + cannot release the request's resources ahead of buffered related messages (SDK-defined).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + recording = RecordingWriteStream() + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) + notified = anyio.Event() + handler_exited = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.notify("notifications/progress", {"progressToken": "t", "progress": 1}) + notified.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="slow"))) + await notified.wait() + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ) + ) + ) + # The recording stream never checkpoints, so by the time the handler-exit + # wakeup reaches this task the marker emission has already completed. + await handler_exited.wait() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + assert recording.sent == [ + SessionMessage( + JSONRPCNotification( + jsonrpc="2.0", method="notifications/progress", params={"progressToken": "t", "progress": 1} + ), + metadata=ServerMessageMetadata(related_request_id=1), + ), + RequestSettled(request_id=1), ] +@pytest.mark.anyio +async def test_inbound_settled_marker_is_dropped_and_the_loop_keeps_serving(): + """A `RequestSettled` arriving on the read stream (the in-memory pair delivers the peer's + markers straight to the other dispatcher) is dropped without dispatch; the next request + still round-trips (SDK-defined).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(RequestSettled(request_id=9)) + # The read loop is serial, so this request reaching its handler proves + # the marker was consumed without crashing or answering it. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) + reply = await s2c_recv.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + assert reply.message.id == 1 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_settled_marker_write_after_peer_drop_is_swallowed(caplog: pytest.LogCaptureFixture): + """The settled-marker write hitting a torn-down transport is dropped with a debug log and the + loop keeps serving — same policy as the result and error writes.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + handler_started = anyio.Event() + handler_exited = anyio.Event() + probe_handled = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "probe": + probe_handled.set() + return {"ok": True} + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe + + try: + with caplog.at_level(logging.DEBUG, logger="mcp.shared.jsonrpc_dispatcher"): + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="slow"))) + await handler_started.wait() + # Peer drops: the marker write below hits BrokenResourceError. + s2c_recv.close() + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ) + ) + ) + await handler_exited.wait() + # The loop surviving the dropped marker is the behaviour under test. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="probe"))) + await probe_handled.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert "dropped settled marker for 1: write stream closed" in caplog.text + + @pytest.mark.anyio async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): handler_started = anyio.Event() @@ -252,8 +397,8 @@ async def call() -> None: @pytest.mark.anyio async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -277,8 +422,8 @@ async def caller() -> None: @pytest.mark.anyio async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): """Iterating a closed receive end is EOF, not a crash (stateless SHTTP closes it during teardown).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) on_request, on_notify = echo_handlers(Recorder()) # Close the receive end itself (not the send end): __anext__ then raises ClosedResourceError. @@ -293,8 +438,8 @@ async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): async def test_run_cancels_in_flight_handlers_when_read_stream_eofs(): """run() cancels still-running handlers at read-stream EOF; otherwise its join waits forever (over SSE, leaking the handler and the GET request hosting the session).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() handler_cancelled = anyio.Event() @@ -330,8 +475,8 @@ async def drive() -> None: @pytest.mark.anyio async def test_run_closes_write_stream_on_exit(): """run() owns both streams; the write end is released once the EOF teardown completes.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) on_request, on_notify = echo_handlers(Recorder()) async with anyio.create_task_group() as tg: @@ -365,8 +510,8 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio async def test_raise_handler_exceptions_true_propagates_out_of_run(): - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) def builder(_meta: object) -> TransportContext: return TransportContext(kind="jsonrpc", can_send_request=True) @@ -402,8 +547,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): """Server-to-client requests carry related_request_id for SHTTP routing.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -440,8 +585,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> async def test_courtesy_cancel_on_timeout_tags_outbound_with_server_message_metadata(): """The timeout-path `notifications/cancelled` carries the originating request id: SHTTP's `message_router` keys on `related_request_id`; without it the cancel would be dropped.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -486,8 +631,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> async def test_dispatch_context_request_with_dropped_resumption_hints_still_sends_courtesy_cancel(): """Resumption hints that never reach the transport must not suppress the abandon cancel: `related_request_id` takes metadata precedence and drops the hints, so the request is not resumable.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -530,8 +675,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_caller_cancel_sends_courtesy_cancellation_on_the_wire(): """Cancelling the scope around send_raw_request emits notifications/cancelled by default.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -581,10 +726,11 @@ def __init__(self) -> None: self.sent: list[SessionMessage] = [] self.first_write_started = anyio.Event() - async def send(self, item: SessionMessage) -> None: + async def send(self, item: SessionMessage | RequestSettled) -> None: if not self.first_write_started.is_set(): self.first_write_started.set() await anyio.sleep_forever() # the request write wedges until the caller is cancelled + assert isinstance(item, SessionMessage) # this client handles no inbound request, so no marker self.sent.append(item) async def aclose(self) -> None: @@ -601,7 +747,7 @@ async def __aexit__( ) -> bool | None: return None - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) wedged = FirstWriteWedgedStream() client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, wedged) on_request, on_notify = echo_handlers(Recorder()) @@ -646,8 +792,8 @@ async def test_caller_cancel_during_delivered_request_write_sends_courtesy_cance """A cancelled request write may still deliver: on a buffer-0 stream the transport can pop the parked request in the same tick the cancel lands, so send() raises CancelledError after handing the message over. The peer saw the id, so the courtesy cancel must still go out.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -701,8 +847,8 @@ async def marker_after_caller_unwinds() -> None: async def test_caller_cancelled_before_request_write_starts_sends_no_courtesy_cancellation(): """A caller whose scope is already cancelled never gets the request onto the wire, so no courtesy cancel goes out either: there is provably no id for the peer to stop.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -740,8 +886,8 @@ async def caller() -> None: @pytest.mark.anyio async def test_caller_cancel_with_resumption_hints_suppresses_the_courtesy_cancellation(): """A request sent with resumption hints is meant to be resumed; abandoning it must not stop the peer's work.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -785,8 +931,8 @@ async def caller() -> None: @pytest.mark.anyio async def test_timeout_with_resumption_hints_suppresses_the_courtesy_cancellation(): """A timed-out request that carries resumption hints stays resumable: no cancellation is sent.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -816,8 +962,8 @@ async def test_timeout_with_resumption_hints_suppresses_the_courtesy_cancellatio @pytest.mark.anyio async def test_cancel_on_abandon_false_suppresses_the_courtesy_cancellation_on_timeout(): """Callers opt out per call for requests the protocol forbids cancelling (initialize).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -851,7 +997,7 @@ def __init__(self) -> None: self.attempts = 0 self.error = TimeoutError("transport send timed out") - async def send(self, item: SessionMessage) -> None: + async def send(self, item: SessionMessage | RequestSettled) -> None: self.attempts += 1 raise self.error @@ -877,7 +1023,7 @@ async def test_transport_write_timeout_propagates_raw_when_no_request_timeout_is fired — and must propagate raw instead of being mislabelled REQUEST_TIMEOUT. (Genuine expiry after a completed write is pinned by the timeout tests above and `test_timeout_courtesy_cancel_write_is_bounded_when_the_transport_is_wedged`.)""" - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) transport = TimingOutWriteStream() client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, transport) on_request, on_notify = echo_handlers(Recorder()) @@ -909,8 +1055,8 @@ async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wed """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang: `_ABANDON_WRITE_TIMEOUT` abandons the courtesy-cancel write (SDK-defined bound). On regression the test hangs rather than failing fast - fail_after cannot cancel through the shield.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -956,8 +1102,8 @@ async def test_timeout_courtesy_cancel_write_is_bounded_when_the_transport_is_we ): """A wedged transport write cannot delay the REQUEST_TIMEOUT error indefinitely (SDK-defined bound): `_ABANDON_WRITE_TIMEOUT` abandons the courtesy cancel so the error still surfaces.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -1001,8 +1147,8 @@ async def test_shutdown_error_response_write_is_bounded_when_the_transport_is_we """Cancelling the task group hosting run() completes even when the shutdown error write wedges: only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). A 0-buffer stream nobody reads expresses the wedge: run() closes its write stream only after the join, so the send stays parked.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() @@ -1035,8 +1181,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> async def test_shutdown_answers_in_flight_request_with_connection_closed(): """Read-stream EOF answers a still-running request with CONNECTION_CLOSED (SDK-defined): run() keeps the write stream open until the task-group join, so the shielded teardown write lands.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() @@ -1072,8 +1218,8 @@ async def test_shutdown_cancel_during_delivered_result_write_writes_no_second_an the transport pops the parked send in the same tick the shutdown cancel lands. The shutdown arm must not stack a CONNECTION_CLOSED answer on top - one request id, at most one answer (peers drop a missing answer via their own close fan-out, but a duplicate id breaks JSON-RPC).""" - read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) + read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + write_send, write_recv = anyio.create_memory_object_stream[SessionMessage | RequestSettled](0) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_recv, write_send) on_request, on_notify = echo_handlers(Recorder()) outer = anyio.CancelScope() @@ -1082,7 +1228,7 @@ async def run_server() -> None: with outer: await server.run(on_request, on_notify) - received: list[SessionMessage] = [] + received: list[SessionMessage | RequestSettled] = [] try: async with anyio.create_task_group() as tg: tg.start_soon(run_server) @@ -1103,7 +1249,7 @@ async def run_server() -> None: for s in (read_send, read_recv, write_send, write_recv): s.close() assert outer.cancelled_caught - assert [m.message for m in received] == [JSONRPCResponse(jsonrpc="2.0", id=7, result={"echoed": "t", "params": {}})] + assert received == [SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=7, result={"echoed": "t", "params": {}}))] @pytest.mark.anyio @@ -1112,7 +1258,7 @@ async def test_request_write_failure_propagates_and_leaves_no_pending_entry(): boom = RuntimeError("write failed") class RaisingWriteStream: - async def send(self, item: SessionMessage) -> None: + async def send(self, item: SessionMessage | RequestSettled) -> None: raise boom async def aclose(self) -> None: @@ -1129,7 +1275,7 @@ async def __aexit__( ) -> bool | None: return None - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, RaisingWriteStream()) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1148,8 +1294,8 @@ async def __aexit__( @pytest.mark.anyio async def test_request_write_on_torn_down_transport_raises_connection_closed(): """A write onto a torn-down transport surfaces as MCPError(CONNECTION_CLOSED), not a raw `BrokenResourceError`.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1170,8 +1316,8 @@ async def test_request_write_on_torn_down_transport_raises_connection_closed(): async def test_notify_after_connection_close_is_dropped_with_debug_log(caplog: pytest.LogCaptureFixture): """notify() after run() saw EOF is fire-and-forget: dropped with a debug log, matching the response-write policy, while the sibling send_raw_request raises.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1192,8 +1338,8 @@ async def test_notify_after_connection_close_is_dropped_with_debug_log(caplog: p @pytest.mark.anyio async def test_notify_on_torn_down_transport_is_dropped_with_debug_log(caplog: pytest.LogCaptureFixture): """A notify racing transport teardown (run() hasn't seen EOF yet) is dropped, not a raw `BrokenResourceError`.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1255,8 +1401,8 @@ async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | N @pytest.mark.anyio async def test_ctx_message_metadata_carries_inbound_request_metadata(): """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) metadata = ServerMessageMetadata(request_context="request-scoped-data") seen: list[MessageMetadata] = [] @@ -1290,8 +1436,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_ctx_message_metadata_carries_inbound_notification_metadata(): """Notifications get the same metadata pass-through as requests.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) metadata = ServerMessageMetadata(request_context="request-scoped-data") seen: list[MessageMetadata] = [] @@ -1390,8 +1536,8 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio async def test_inline_methods_are_handled_before_next_message_is_dequeued(): """An `inline_methods` method runs to completion before the next message is dispatched.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( c2s_recv, s2c_send, inline_methods=frozenset({"first"}) ) @@ -1460,8 +1606,8 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio async def test_send_raw_request_before_run_raises_runtimeerror(): - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) try: with pytest.raises(RuntimeError, match="before run"): @@ -1475,8 +1621,8 @@ async def test_send_raw_request_before_run_raises_runtimeerror(): async def test_send_raw_request_after_connection_close_raises_connection_closed(): """Sending after run() saw EOF raises MCPError(CONNECTION_CLOSED) — the same contract in-flight waiters get — not RuntimeError (SDK-defined).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1493,8 +1639,8 @@ async def test_send_raw_request_after_connection_close_raises_connection_closed( @pytest.mark.anyio async def test_transport_exception_in_read_stream_is_logged_and_dropped(): - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1516,8 +1662,8 @@ async def test_transport_exception_in_read_stream_is_logged_and_dropped(): @pytest.mark.anyio async def test_on_stream_exception_observes_transport_exceptions(): """With an observer set, Exception items reach it instead of being dropped; the loop stays healthy.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) seen: list[Exception] = [] @@ -1546,8 +1692,8 @@ async def observe(exc: Exception) -> None: @pytest.mark.anyio async def test_on_stream_exception_observer_raising_is_contained(caplog: pytest.LogCaptureFixture): """A raising observer costs the item, not the connection: it runs in the read loop itself.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) async def observe(exc: Exception) -> None: raise RuntimeError("observer boom") @@ -1636,7 +1782,7 @@ async def test_handler_inherits_sender_contextvars(inline: frozenset[str]): raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) write_send = ContextSendStream[SessionMessage | Exception](raw_send) - out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send, inline_methods=inline) seen: list[str] = [] @@ -1672,8 +1818,8 @@ async def sender() -> None: @pytest.mark.anyio async def test_response_write_after_peer_drop_is_swallowed(): """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) proceed = anyio.Event() handlers_done = anyio.Event() @@ -1711,8 +1857,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) caller_done = anyio.Event() @@ -1745,8 +1891,8 @@ async def caller() -> None: def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): """White-box: a response for an id still in _pending but whose waiter has gone.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] @@ -1758,8 +1904,8 @@ def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): """White-box: the buffer=1 invariant - WouldBlock means waiter already has an outcome.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] @@ -1786,8 +1932,8 @@ def test_plan_outbound_with_resumption_token_returns_client_metadata_and_suppres async def test_send_raw_request_projects_opts_headers_onto_message_metadata(): """`opts["headers"]` alone yields `ClientMessageMetadata(headers=...)` on the outbound `SessionMessage` (SDK-defined: the headers sidecar is the path the session uses to reach the transport).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) @@ -1818,8 +1964,8 @@ async def caller() -> None: @pytest.mark.anyio async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): """A peer that echoes the request ID as a JSON string still resolves the waiter.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1849,8 +1995,8 @@ async def respond_stringly() -> None: @pytest.mark.anyio async def test_error_response_with_string_id_correlates_to_int_keyed_pending_request(): """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (same `_coerce_id` path).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1885,8 +2031,8 @@ async def reject_stringly() -> None: @pytest.mark.anyio async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) seen: list[float] = [] @@ -1936,8 +2082,8 @@ def test_coerce_id_passes_through_non_numeric_string_and_int(): @pytest.mark.anyio async def test_jsonrpc_error_response_with_null_id_is_dropped(): """Parse-error responses (id=null) have no waiter; they're dropped and the read loop stays healthy.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) try: @@ -1967,8 +2113,8 @@ async def respond() -> None: @pytest.mark.anyio async def test_notify_without_params_omits_params_key_on_the_wire(): """JSON-RPC 2.0 forbids `params: null`: `notify` leaves `params` unset (transports use `exclude_unset=True`).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](4) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) try: await d.notify("notifications/tools/list_changed", None) @@ -1989,8 +2135,8 @@ async def test_notify_without_params_omits_params_key_on_the_wire(): @pytest.mark.anyio async def test_transport_builder_exception_on_request_is_answered_with_internal_error(): """A raising builder costs only the one request, not the connection.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) calls = 0 def builder(_meta: MessageMetadata) -> TransportContext: @@ -2028,8 +2174,8 @@ def builder(_meta: MessageMetadata) -> TransportContext: @pytest.mark.anyio async def test_transport_builder_exception_on_notification_drops_only_that_notification(): """A raising builder drops the one notification; the read loop survives.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) calls = 0 def builder(_meta: MessageMetadata) -> TransportContext: @@ -2100,8 +2246,8 @@ async def call() -> None: async def test_progress_with_bool_token_or_bool_progress_does_not_fire_callback(): """Bool `progressToken`/`progress` values are malformed; the callback must not fire for the unrelated request keyed by id 1 (`True == 1`).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) on_request, on_notify = echo_handlers(Recorder()) seen: list[float] = [] @@ -2147,8 +2293,8 @@ async def on_progress(progress: float, total: float | None, message: str | None) @pytest.mark.anyio async def test_request_with_bool_meta_progress_token_is_not_adopted(): """A bool `_meta.progressToken` is malformed: `ctx.progress()` must be a no-op, not emit `progressToken: true`.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -2185,8 +2331,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) async def test_cancelled_correlates_across_string_and_int_request_id_forms(request_id: RequestId, cancel_id: object): """A peer that stringifies the id between request and cancel still cancels (same `_coerce_id` path).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_exited = anyio.Event() @@ -2227,9 +2373,9 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> async def test_completed_handler_does_not_evict_reused_request_id_from_in_flight(): """A second request reusing an id while the first handler is parked in its response write keeps its own `_in_flight` entry (a post-write pop would evict it and break peer-cancellation).""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) # buffer=0: the first handler's response write parks until the test receives. - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](0) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) calls = 0 second_started = anyio.Event() @@ -2285,8 +2431,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> async def test_duplicate_request_id_completion_of_first_handler_keeps_second_cancellable(): """A duplicate inbound id overwrites `_in_flight` (parity with v1/TS); the identity-guarded pop keeps the first handler's completion from evicting the second's entry and breaking its cancellation.""" - c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception | RequestSettled](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) first_started = anyio.Event() release_first = anyio.Event() diff --git a/tests/shared/test_message.py b/tests/shared/test_message.py new file mode 100644 index 000000000..c10e157ea --- /dev/null +++ b/tests/shared/test_message.py @@ -0,0 +1,24 @@ +"""Tests for the transport-facing helpers in `mcp.shared.message`.""" + +import anyio +import pytest +from mcp_types import JSONRPCNotification + +from mcp.shared.message import RequestSettled, SessionMessage, wire_messages + + +@pytest.mark.anyio +async def test_wire_messages_strips_settled_markers_and_preserves_frame_order(): + """`wire_messages` yields only serializable frames: `RequestSettled` markers are dropped (they + must never reach any wire) and the surviving frames keep their order.""" + send, receive = anyio.create_memory_object_stream[SessionMessage | RequestSettled](3) + first = SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/first")) + last = SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/last")) + send.send_nowait(first) + send.send_nowait(RequestSettled(request_id=1)) + send.send_nowait(last) + send.close() + + with anyio.fail_after(5): + assert [frame async for frame in wire_messages(receive)] == [first, last] + receive.close() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3bcc8be8e..d9a5b0b11 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -64,7 +64,7 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import create_context_streams -from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, RequestSettled, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from tests.interaction.transports import StreamingASGITransport @@ -2284,3 +2284,93 @@ async def asgi_receive() -> Message: assert body_chunks[-1] == {"type": "http.response.body", "body": b"", "more_body": False} assert "Error in standalone SSE writer" not in caplog.text assert "Error in standalone SSE response" not in caplog.text + + +@pytest.mark.anyio +async def test_settled_marker_closes_only_the_request_streams_send_side() -> None: + """A `RequestSettled` on the write stream ends its per-request stream after buffered events + drain (send-side close only — a receive-side close would lose them), and a marker whose id has + no registered stream is a no-op: the router keeps routing (SDK-defined). + + Drives `connect()` with raw streams because the miss arm needs a marker for an id no POST + ever registered, which no full exchange can produce. + """ + transport = StreamableHTTPServerTransport(mcp_session_id=None) + async with transport.connect() as (_read_stream, write_stream): + send, receive = anyio.create_memory_object_stream[EventMessage](16) + transport._request_streams["7"] = (send, receive) + buffered = types.JSONRPCNotification(jsonrpc="2.0", method="notifications/message") + send.send_nowait(EventMessage(buffered)) + with anyio.fail_after(5): + # Unknown id first: the lookup misses and must not break the router. + await write_stream.send(RequestSettled(request_id=9)) + await write_stream.send(RequestSettled(request_id=7)) + # The buffered event survives the close and arrives before end-of-stream. + event = await receive.receive() + assert event.message is buffered + with pytest.raises(anyio.EndOfStream): + await receive.receive() + receive.close() + + +@pytest.mark.anyio +async def test_cancelled_request_in_json_mode_releases_its_per_request_stream() -> None: + """After a peer cancellation settles a JSON-mode POST as 204, the per-request stream the wait + loop parked on is deregistered (SDK-defined). Reaches into `_request_streams` because the + pre-fix leak — the cancelled request's stream stayed registered forever — is invisible on any + public surface; the wire-level 204 contract is pinned in + `tests/interaction/transports/test_hosting_http.py`. + """ + started = anyio.Event() + cancelled = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + started.set() + try: + await anyio.Event().wait() # blocks until the cancellation interrupts it + except anyio.get_cancelled_exc_class(): + cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + session_manager = StreamableHTTPSessionManager( + app=Server("blocker", on_call_tool=call_tool), + json_response=True, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + headers = {"Accept": "application/json", "Content-Type": "application/json"} + async with session_manager.run(), make_client(app) as http: + init_response = await http.post("/mcp", json=INIT_REQUEST, headers=headers) + assert init_response.status_code == 200 + headers[MCP_SESSION_ID_HEADER] = init_response.headers[MCP_SESSION_ID_HEADER] + headers[MCP_PROTOCOL_VERSION_HEADER] = init_response.json()["result"]["protocolVersion"] + call_body: dict[str, Any] = { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": "block", "arguments": {}}, + } + cancel_body: dict[str, Any] = { + "jsonrpc": "2.0", + "method": "notifications/cancelled", + "params": {"requestId": 2}, + } + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def post_blocked_request() -> None: + response = await http.post("/mcp", json=call_body, headers=headers) + assert response.status_code == 204 + + tg.start_soon(post_blocked_request) + await started.wait() + cancel = await http.post("/mcp", json=cancel_body, headers=headers) + assert cancel.status_code == 202 + await cancelled.wait() + # Quiesce so the settled exchange's stream cleanup (which runs after the 204 + # body is sent) finishes before the leak assertion below. + await anyio.wait_all_tasks_blocked() + (transport,) = session_manager._server_instances.values() + assert transport._request_streams == {}