From 46e82180d3c82df79714e1b6149a3b710d989190 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:38:53 +0000 Subject: [PATCH 1/7] tests: eliminate port-allocation races in SSE/StreamableHTTP tests Replace the multiprocessing.Process + socket.bind(("",0)) pattern with a run_server_in_thread helper that uses uvicorn with port=0 in a background thread. The kernel atomically assigns the port at bind time and we read it back from the bound socket, eliminating the TOCTOU window that caused test_response and others to intermittently connect to the wrong server under pytest-xdist parallel execution. Where possible (StreamableHTTP security tests, Unicode tests), use httpx.ASGITransport for fully in-process testing with no real sockets. Legacy SSE tests keep real HTTP because ASGITransport buffers responses and cannot support the long-lived GET /sse stream pattern. Github-Issue: #1573 --- tests/client/test_http_unicode.py | 265 ++++++-------- tests/server/test_sse_security.py | 228 ++++-------- tests/server/test_streamable_http_security.py | 333 +++++++----------- tests/shared/test_sse.py | 227 +++--------- tests/shared/test_streamable_http.py | 284 +++++---------- tests/test_helpers.py | 65 ++++ 6 files changed, 495 insertions(+), 907 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e46..ee105505f 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,11 +4,10 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +import httpx import pytest from starlette.applications import Starlette from starlette.routing import Mount @@ -19,7 +18,6 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,197 +39,132 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - import uvicorn - - # Need to recreate the server setup in this process - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, - }, - "required": ["text"], - }, - ), - ] - ) +async def _handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": {"text": {"type": "string", "description": "Text to echo back"}}, + "required": ["text"], + }, + ), + ] + ) + + +async def _handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "echo_unicode": + text = params.arguments.get("text", "") if params.arguments else "" + return types.CallToolResult(content=[TextContent(type="text", text=f"Echo: {text}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - if params.name == "echo_unicode": - text = params.arguments.get("text", "") if params.arguments else "" - return types.CallToolResult( - content=[ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] +async def _handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], ) - else: - raise ValueError(f"Unknown tool: {params.name}") - - async def handle_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - return types.ListPromptsResult( - prompts=[ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], + ] + ) + + +async def _handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + if params.name == "unicode_prompt": + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), ) ] ) + raise ValueError(f"Unknown prompt: {params.name}") # pragma: no cover - async def handle_get_prompt( - ctx: ServerRequestContext, params: types.GetPromptRequestParams - ) -> types.GetPromptResult: - if params.name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] - ) - raise ValueError(f"Unknown prompt: {params.name}") +def _make_unicode_app() -> Starlette: server = Server( name="unicode_test_server", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - on_list_prompts=handle_list_prompts, - on_get_prompt=handle_get_prompt, - ) - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + on_list_prompts=_handle_list_prompts, + on_get_prompt=_handle_get_prompt, ) + session_manager = StreamableHTTPSessionManager(app=server, json_response=False) @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with session_manager.run(): yield - # Create an ASGI application - app = Starlette( + return Starlette( debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], + routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=lifespan, ) - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - @pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) +async def unicode_session() -> AsyncGenerator[ClientSession, None]: + """Create an initialized client session connected to the in-process unicode server.""" + app = _make_unicode_app() + async with app.router.lifespan_context(app): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, follow_redirects=True) as http_client: + async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): + async with ClientSession(rs, ws) as session: + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call(unicode_session: ClientSession) -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 - - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Test 1: List tools (server→client Unicode in descriptions) + tools = await unicode_session.list_tools() + assert len(tools.tools) == 1 - # Test 2: Send Unicode text in tool call (client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await unicode_session.call_tool("echo_unicode", arguments={"text": test_string}) + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts(unicode_session: ClientSession) -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await unicode_session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await unicode_session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a2..bd9a174cd 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,155 +1,127 @@ """Tests for SSE server DNS rebinding protection.""" +import contextlib import logging -import multiprocessing -import socket +from collections.abc import Generator import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_server_in_thread + +# Several tests open an SSE stream, check the status code, then exit without +# consuming the stream. When uvicorn shuts down, it cancels the still-running +# SSE handler mid-operation, and SseServerTransport's internal memory streams +# may be GC'd without their cleanup finalizers running. These ResourceWarnings +# are artifacts of the abrupt-disconnect test pattern, not production bugs. +pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover +class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: - return [] + return [] # pragma: no cover -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" +def make_app(security_settings: TransportSecuritySettings | None = None) -> Starlette: + """Build a Starlette app with SSE transport and the given security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): - try: + async def handle_sse(request: Request) -> Response: + # connect_sse sends responses directly via ASGI `send` (both the SSE stream + # and any validation error responses), so by the time we return here the + # response has already been sent. Starlette will still try to send our + # return value, which fails with "Unexpected ASGI message". We suppress + # ValueError from connect_sse and wrap the final Response() send in a + # no-op so Starlette's machinery doesn't conflict. + with contextlib.suppress(ValueError): async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: + if streams: # pragma: no branch await app.run(streams[0], streams[1], app.create_initialization_options()) - except ValueError as e: - # Validation error was already handled inside connect_sse - logger.debug(f"SSE connection failed validation: {e}") - return Response() + return _AlreadySentResponse() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] + return Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") +class _AlreadySentResponse(Response): + """No-op Response for handlers that already sent via raw ASGI `send`.""" -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + pass -@pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +@pytest.fixture +def server_url() -> Generator[str, None, None]: + """Default-settings server for tests that don't need custom security config.""" + with run_server_in_thread(make_app(), lifespan="off") as url: + yield url - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() +@pytest.mark.anyio +async def test_sse_security_default_settings(server_url: str): + """Test SSE with default security settings (protection disabled).""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} + async with httpx.AsyncClient(timeout=5.0) as client: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(): """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{url}/sse", headers={"Host": "evil.com"}) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(): """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{url}/sse", headers={"Origin": "http://evil.com"}) assert response.status_code == 403 assert response.text == "Invalid Origin header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(): """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type fake_session_id = "12345678123456781234567812345678" + # Test POST with invalid content type response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"{url}/messages/?session_id={fake_session_id}", headers={"Content-Type": "text/plain"}, content="test", ) @@ -157,137 +129,85 @@ async def test_sse_security_post_invalid_content_type(server_port: int): assert response.text == "Invalid Content-Type header" # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) + response = await client.post(f"{url}/messages/?session_id={fake_session_id}", content="test") assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - + with run_server_in_thread(make_app(settings), lifespan="off") as url: async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{url}/sse", headers={"Host": "evil.com"}) as response: # Should connect successfully even with invalid host assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_in_thread(make_app(settings), lifespan="off") as url: # Test with custom allowed host - headers = {"Host": "custom.host"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host + async with client.stream("GET", f"{url}/sse", headers={"Host": "custom.host"}) as response: assert response.status_code == 200 # Test with non-allowed host - headers = {"Host": "evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{url}/sse", headers={"Host": "evil.com"}) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_in_thread(make_app(settings), lifespan="off") as url: # Test with various port numbers for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port + async with client.stream("GET", f"{url}/sse", headers={"Host": f"localhost:{test_port}"}) as response: assert response.status_code == 200 - headers = {"Origin": f"http://localhost:{test_port}"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port + headers = {"Origin": f"http://localhost:{test_port}"} + async with client.stream("GET", f"{url}/sse", headers=headers) as response: assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(): """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient() as client: - # Test with various valid content types valid_content_types = [ "application/json", "application/json; charset=utf-8", "application/json;charset=utf-8", "APPLICATION/JSON", # Case insensitive ] - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) fake_session_id = "12345678123456781234567812345678" response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"{url}/messages/?session_id={fake_session_id}", headers={"Content-Type": content_type}, json={"test": "data"}, ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check + # Will get 404 because session doesn't exist — that means we passed content-type validation assert response.status_code == 404 assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353..11f75e9a3 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,13 +1,10 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import multiprocessing -import socket from collections.abc import AsyncGenerator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount from starlette.types import Receive, Scope, Send @@ -16,36 +13,21 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server SERVER_NAME = "test_streamable_http_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover +class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: - return [] + return [] # pragma: no cover -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" +def make_app(security_settings: TransportSecuritySettings | None = None) -> Starlette: + """Build a Starlette app with the given security settings.""" app = SecurityTestServer() - - # Create session manager with security settings session_manager = StreamableHTTPSessionManager( app=app, json_response=False, @@ -53,239 +35,164 @@ def run_server_with_settings(port: int, security_settings: TransportSecuritySett security_settings=security_settings, ) - # Create the ASGI handler async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await session_manager.handle_request(scope, receive, send) - # Create Starlette app with lifespan @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with session_manager.run(): yield - routes = [ - Mount("/", app=handle_streamable_http), - ] + return Starlette(routes=[Mount("/", app=handle_streamable_http)], lifespan=lifespan) - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") +@asynccontextmanager +async def make_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create an httpx client wired to an in-process ASGI app via ASGITransport. -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + StreamableHTTP POST requests return promptly (SSE body then close), so the + ASGITransport buffering behavior is not an issue here. + """ + app = make_app(security_settings) + async with app.router.lifespan_context(app): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver", timeout=5.0) as client: + yield client @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): +async def test_streamable_http_security_default_settings(): """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() + async with make_client() as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): +async def test_streamable_http_security_invalid_host_header(): """Test StreamableHTTP with invalid Host header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with make_client(security_settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): +async def test_streamable_http_security_invalid_origin_header(): """Test StreamableHTTP with invalid Origin header.""" - security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["testserver"]) + async with make_client(security_settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Origin": "http://evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): +async def test_streamable_http_security_invalid_content_type(): """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() + async with make_client() as client: + # Test POST with invalid content type + response = await client.post( + "/", + headers={ + "Content-Type": "text/plain", + "Accept": "application/json, text/event-stream", + }, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post( + "/", + headers={"Accept": "application/json, text/event-stream"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): +async def test_streamable_http_security_disabled(): """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with make_client(settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + # Should connect successfully even with invalid host + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): +async def test_streamable_http_security_custom_allowed_hosts(): """Test StreamableHTTP with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost", "127.0.0.1", "custom.host"], - allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], + allowed_hosts=["localhost", "testserver", "custom.host"], + allowed_origins=["http://localhost", "http://testserver", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with make_client(settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "custom.host", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + # Should connect successfully with custom host + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): +async def test_streamable_http_security_get_request(): """Test StreamableHTTP GET request with security.""" - security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["testserver"]) + async with make_client(security_settings) as client: # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + response = await client.get("/", headers={"Host": "evil.com", "Accept": "text/event-stream"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + # Test GET request with valid host header but no session ID + # Should pass security but fail on session validation + response = await client.get("/", headers={"Host": "testserver", "Accept": "text/event-stream"}) + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707..63b285ae0 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,6 +1,4 @@ import json -import multiprocessing -import socket from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -9,7 +7,6 @@ import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from starlette.applications import Starlette @@ -41,31 +38,24 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server - -SERVER_NAME = "test_server_for_SSE" - - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +from tests.test_helpers import run_server_in_thread +# When SSE clients disconnect abruptly (exiting sse_client context while the +# server's long-lived SSE stream is open), uvicorn cancels the server handler +# mid-operation and SseServerTransport's internal memory streams may be GC'd +# without their finalizers running. This is a test-lifecycle artifact of +# abrupt disconnect, not a production bug — real clients consume the stream. +pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" +SERVER_NAME = "test_server_for_SSE" -async def _handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: +async def _handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": + elif parsed.scheme == "slow": # pragma: no cover await anyio.sleep(2.0) text = f"Slow response from {parsed.netloc}" else: @@ -73,39 +63,15 @@ async def _handle_read_resource( # pragma: no cover return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) -async def _handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - - -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover +def _create_server() -> Server: return Server( SERVER_NAME, on_read_resource=_handle_read_resource, - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool, ) -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover +def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) @@ -117,47 +83,25 @@ async def handle_sse(request: Request) -> Response: await server.run(streams[0], streams[1], server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - @pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") +def server() -> Generator[str, None, None]: + """Run the basic SSE server in a background thread, yielding its URL.""" + with run_server_in_thread(make_server_app(), lifespan="off") as url: + yield url @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server: str) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: + async with httpx.AsyncClient(base_url=server) as client: yield client @@ -188,8 +132,8 @@ async def connection_test() -> None: @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection(server: str) -> None: + async with sse_client(server + "/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -202,10 +146,10 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: +async def test_sse_client_on_session_created(server: str) -> None: captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client(server + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -231,7 +175,7 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non @pytest.mark.anyio async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch + server: str, monkeypatch: pytest.MonkeyPatch ) -> None: callback_mock = Mock() @@ -240,7 +184,7 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with sse_client(server + "/sse", on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -250,8 +194,8 @@ def mock_extract(url: str) -> None: @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session(server: str) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -297,37 +241,18 @@ async def test_sse_client_timeout( # pragma: no cover pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: # pragma: no cover +@pytest.fixture() +def mounted_server() -> Generator[str, None, None]: + """Run the SSE server mounted under a sub-path, yielding its base URL.""" app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") + with run_server_in_thread(main_app, lifespan="off") as url: + yield url @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: +async def test_sse_client_basic_connection_mounted_app(mounted_server: str) -> None: + async with sse_client(mounted_server + "/mounted_app/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -339,26 +264,22 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser assert isinstance(ping_result, EmptyResult) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - headers_info: dict[str, Any] = {} - if ctx.request: - headers_info = dict(ctx.request.headers) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert ctx.request is not None + headers_info = dict(ctx.request.headers) if params.name == "echo_headers": return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - elif params.name == "echo_context": - context_data = { - "request_id": (params.arguments or {}).get("request_id"), - "headers": headers_info, - } - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + assert params.name == "echo_context" + context_data = { + "request_id": (params.arguments or {}).get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -381,9 +302,8 @@ async def _handle_context_list_tools( # pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing +def make_context_server_app() -> Starlette: + """Create a Starlette app with an SSE server that echoes request context.""" security_settings = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) @@ -399,73 +319,47 @@ async def handle_sse(request: Request) -> Response: await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") +def context_server() -> Generator[str, None, None]: + """Run the context-echoing SSE server in a background thread, yielding its URL.""" + with run_server_in_thread(make_context_server_app(), lifespan="off") as url: + yield url @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(context_server: str) -> None: """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): + async with sse_client(context_server + "/sse", headers=custom_headers) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: - # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(context_server: str) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] @@ -473,14 +367,10 @@ async def test_request_context_isolation(context_server: None, server_url: str) for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): + async with sse_client(context_server + "/sse", headers=headers) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 @@ -489,7 +379,6 @@ async def test_request_context_isolation(context_server: None, server_url: str) ) contexts.append(context_data) - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -605,7 +494,7 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: @pytest.mark.anyio -async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: +async def test_sse_session_cleanup_on_disconnect(server: str) -> None: """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 When a client disconnects, the server should remove the session from @@ -616,7 +505,7 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) captured: list[str] = [] # Connect a client session, then disconnect - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client(server + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: await session.initialize() @@ -624,7 +513,7 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) # (not 202 as it did before the fix) async with httpx.AsyncClient() as client: response = await client.post( - f"{server_url}/messages/?session_id={captured[0]}", + f"{server}/messages/?session_id={captured[0]}", json={"jsonrpc": "2.0", "method": "ping", "id": 99}, headers={"Content-Type": "application/json"}, ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..344deab4f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -6,10 +6,7 @@ from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -21,7 +18,6 @@ import httpx import pytest import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -65,7 +61,7 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_server_in_thread # Test constants SERVER_NAME = "test_streamable_http_server" @@ -431,74 +427,18 @@ def create_app( return app -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. - """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, - ) - - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests +# Test fixtures — uvicorn in a background thread with port=0 (no port races) @pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +def basic_server() -> Generator[str, None, None]: + """Start a basic server. Yields the server URL.""" + with run_server_in_thread(create_app()) as url: + yield url @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +def basic_server_url(basic_server: str) -> str: + """Alias for basic_server (kept for test signature compatibility).""" + return basic_server @pytest.fixture @@ -508,69 +448,32 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield +def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store and retry_interval enabled. - # Clean up - proc.kill() - proc.join(timeout=2) + Yields (event_store, server_url). Unlike the old multiprocessing fixture, the + event_store is now the SAME object used by the server (same process), so tests + can inspect server-side state directly if needed. + """ + with run_server_in_thread(create_app(event_store=event_store, retry_interval=500)) as url: + yield event_store, url @pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +def json_response_server() -> Generator[str, None, None]: + """Start a server with JSON response enabled. Yields the server URL.""" + with run_server_in_thread(create_app(is_json_response_enabled=True)) as url: + yield url @pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +def json_server_url(json_response_server: str) -> str: + """Alias for json_response_server (kept for test signature compatibility).""" + return json_response_server # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): +def test_accept_header_validation(basic_server: str, basic_server_url: str): """Test that Accept header is properly validated.""" # Test without Accept header (suppress requests library default Accept: */*) session = requests.Session() @@ -595,7 +498,7 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_wildcard(basic_server: str, basic_server_url: str, accept_header: str): """Test that wildcard Accept headers are accepted per RFC 7231.""" response = requests.post( f"{basic_server_url}/mcp", @@ -616,7 +519,7 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_incompatible(basic_server: str, basic_server_url: str, accept_header: str): """Test that incompatible Accept headers are rejected for SSE mode.""" response = requests.post( f"{basic_server_url}/mcp", @@ -630,7 +533,7 @@ def test_accept_header_incompatible(basic_server: None, basic_server_url: str, a assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server: None, basic_server_url: str): +def test_content_type_validation(basic_server: str, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -646,7 +549,7 @@ def test_content_type_validation(basic_server: None, basic_server_url: str): assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): +def test_json_validation(basic_server: str, basic_server_url: str): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -661,7 +564,7 @@ def test_json_validation(basic_server: None, basic_server_url: str): assert "Parse error" in response.text -def test_json_parsing(basic_server: None, basic_server_url: str): +def test_json_parsing(basic_server: str, basic_server_url: str): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -676,7 +579,7 @@ def test_json_parsing(basic_server: None, basic_server_url: str): assert "Validation error" in response.text -def test_method_not_allowed(basic_server: None, basic_server_url: str): +def test_method_not_allowed(basic_server: str, basic_server_url: str): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -691,7 +594,7 @@ def test_method_not_allowed(basic_server: None, basic_server_url: str): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server: None, basic_server_url: str): +def test_session_validation(basic_server: str, basic_server_url: str): """Test session ID validation.""" # session_id not used directly in this test @@ -766,7 +669,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): +def test_session_termination(basic_server: str, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -806,7 +709,7 @@ def test_session_termination(basic_server: None, basic_server_url: str): assert "Session has been terminated" in response.text -def test_response(basic_server: None, basic_server_url: str): +def test_response(basic_server: str, basic_server_url: str): """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -841,7 +744,7 @@ def test_response(basic_server: None, basic_server_url: str): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server: None, json_server_url: str): +def test_json_response(json_response_server: str, json_server_url: str): """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -856,7 +759,7 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): +def test_json_response_accept_json_only(json_response_server: str, json_server_url: str): """Test that json_response servers only require application/json in Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -871,7 +774,7 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ assert response.headers.get("Content-Type") == "application/json" -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_missing_accept_header(json_response_server: str, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" # Suppress requests library default Accept: */* header @@ -888,7 +791,7 @@ def test_json_response_missing_accept_header(json_response_server: None, json_se assert "Not Acceptable" in response.text -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_incorrect_accept_header(json_response_server: str, json_server_url: str): """Test that json_response servers reject requests with incorrect Accept header.""" mcp_url = f"{json_server_url}/mcp" # Test with only text/event-stream (wrong for JSON server) @@ -912,7 +815,7 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): +def test_json_response_wildcard_accept_header(json_response_server: str, json_server_url: str, accept_header: str): """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -927,7 +830,7 @@ def test_json_response_wildcard_accept_header(json_response_server: None, json_s assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server: None, basic_server_url: str): +def test_get_sse_stream(basic_server: str, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -987,7 +890,7 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): assert second_get.status_code == 409 -def test_get_validation(basic_server: None, basic_server_url: str): +def test_get_validation(basic_server: str, basic_server_url: str): """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -1044,14 +947,14 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover +async def http_client(basic_server: str, basic_server_url: str): # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_server: str, basic_server_url: str): """Create initialized StreamableHTTP client session.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1060,7 +963,7 @@ async def initialized_client_session(basic_server: None, basic_server_url: str): @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_basic_connection(basic_server: str, basic_server_url: str): """Test basic client connection with initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1105,7 +1008,7 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_persistence(basic_server: str, basic_server_url: str): """Test that session ID persists across requests.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1126,7 +1029,7 @@ async def test_streamable_http_client_session_persistence(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): +async def test_streamable_http_client_json_response(json_response_server: str, json_server_url: str): """Test client with JSON response mode.""" async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1147,7 +1050,7 @@ async def test_streamable_http_client_json_response(json_response_server: None, @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_get_stream(basic_server: str, basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" notifications_received: list[types.ServerNotification] = [] @@ -1198,7 +1101,7 @@ async def capture_session_id(response: httpx.Response) -> None: @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_termination(basic_server: str, basic_server_url: str): """Test client session termination functionality.""" # Use httpx client with event hooks to capture session ID httpx_client, captured_ids = create_session_id_capturing_client() @@ -1234,7 +1137,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch + basic_server: str, basic_server_url: str, monkeypatch: pytest.MonkeyPatch ): """Test client session termination functionality with a 204 response. @@ -1412,7 +1315,7 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): +async def test_streamablehttp_server_sampling(basic_server: str, basic_server_url: str): """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1516,59 +1419,30 @@ async def _handle_context_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +def create_context_aware_app() -> Starlette: + """Create the context-aware test server app.""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, on_call_tool=_handle_context_call_tool, ) - - session_manager = StreamableHTTPSessionManager( - app=server, - event_store=None, - json_response=False, - ) - - app = Starlette( + session_manager = StreamableHTTPSessionManager(app=server, event_store=None, json_response=False) + return Starlette( debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], + routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=lambda app: session_manager.run(), ) - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - @pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") +def context_aware_server() -> Generator[str, None, None]: + """Start the context-aware server. Yields the server URL.""" + with run_server_in_thread(create_context_aware_app()) as url: + yield url @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_propagation(context_aware_server: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1577,7 +1451,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: } async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1601,7 +1475,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_isolation(context_aware_server: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts: list[dict[str, Any]] = [] @@ -1614,7 +1488,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No } async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1639,9 +1513,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def test_client_includes_protocol_version_header_after_init(context_aware_server: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # Initialize and get the negotiated version init_result = await session.initialize() @@ -1659,7 +1533,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_ assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): +def test_server_validates_protocol_version_header(basic_server: str, basic_server_url: str): """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1717,7 +1591,7 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): +def test_server_backwards_compatibility_no_protocol_version(basic_server: str, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1747,7 +1621,7 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server: None, @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): +async def test_client_crash_handled(basic_server: str, basic_server_url: str): """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init @@ -2219,9 +2093,7 @@ async def message_handler( @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_does_not_mutate_provided_client(basic_server: str, basic_server_url: str) -> None: """Test that streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { @@ -2252,9 +2124,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_mcp_headers_override_defaults(context_aware_server: str) -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests @@ -2263,7 +2133,10 @@ async def test_streamable_http_client_mcp_headers_override_defaults( # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2283,9 +2156,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_aware_server: str) -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", @@ -2294,7 +2165,10 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( } async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269f..98a901207 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,16 @@ """Common test utilities for MCP server tests.""" +import contextlib +import gc import socket +import threading import time +import warnings +from collections.abc import Generator +from typing import Literal + +import uvicorn +from starlette.types import ASGIApp def wait_for_server(port: int, timeout: float = 20.0) -> None: @@ -29,3 +38,59 @@ def wait_for_server(port: int, timeout: float = 20.0) -> None: # Server not ready yet, retry quickly time.sleep(0.01) raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover + + +@contextlib.contextmanager +def run_server_in_thread(app: ASGIApp, lifespan: Literal["auto", "on", "off"] = "on") -> Generator[str, None, None]: + """Run a Starlette/ASGI app in a uvicorn server on a background thread. + + Uses `port=0` so the kernel atomically assigns an available port, eliminating + the TOCTOU port-allocation race that affects subprocess-based fixtures. The + actual bound port is read back from the server's socket after binding. + + Unlike multiprocessing, this runs in-process so: + - No port race (port=0 is assigned atomically at bind time) + - No pickling of app/state (the app runs in the same process) + - Faster startup (no fork/exec overhead) + - Works with both asyncio and trio test backends (uvicorn runs its own + asyncio loop in the thread; uvicorn skips signal handlers automatically + when not on the main thread) + + Args: + app: The ASGI application to serve. + lifespan: uvicorn lifespan mode — "on" to run app lifespan events, + "off" to skip them (default "on"). + + Yields: + Base URL of the running server (e.g., "http://127.0.0.1:54321"). + """ + config = uvicorn.Config(app=app, host="127.0.0.1", port=0, log_level="error", lifespan=lifespan) + server = uvicorn.Server(config=config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for uvicorn to bind and start accepting connections + start_time = time.time() + while not server.started: + if time.time() - start_time > 20.0: # pragma: no cover + raise TimeoutError("uvicorn server did not start within 20 seconds") + time.sleep(0.01) + + # Read back the kernel-assigned port from the bound socket + port = server.servers[0].sockets[0].getsockname()[1] + try: + yield f"http://127.0.0.1:{port}" + finally: + server.should_exit = True + thread.join(timeout=5) + # When uvicorn shuts down with in-flight SSE connections, the server + # cancels request handlers mid-operation. SseServerTransport's internal + # memory streams may not get their `finally` cleanup run before GC, + # causing ResourceWarnings. These are artifacts of test abrupt-disconnect + # patterns (open SSE stream → check status → exit without consuming), + # not bugs. Force GC here and suppress the warnings so they don't leak + # into the next test's PytestUnraisableExceptionWarning collector. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + gc.collect() From 663f0b561c6b3dad67db7565c42013f489776224 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:47:08 +0000 Subject: [PATCH 2/7] chore: downgrade stale no-cover pragmas to lax no cover With server-side test code now running in-thread (same process), coverage tracks code paths that were previously only hit in subprocesses. Several # pragma: no cover annotations now cover lines that ARE covered, failing the strict-no-cover check. Downgrade these to # pragma: lax no cover, which permits partial coverage without failing. --- src/mcp/server/session.py | 2 +- src/mcp/server/sse.py | 4 +-- src/mcp/server/streamable_http.py | 46 ++++++++++++++-------------- src/mcp/server/transport_security.py | 18 +++++------ tests/shared/test_streamable_http.py | 18 +++++------ 5 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a..f95b1a74a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -222,7 +222,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: lax no cover """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f7..cab122443 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,7 +116,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: lax no cover if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -195,7 +195,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: lax no cover logger.debug("Handling POST message") request = Request(scope, receive) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..62140cd36 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -177,7 +177,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: # pragma: lax no cover """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -205,7 +205,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: # pragma: lax no cover """Close the standalone GET SSE stream, triggering client reconnection. This method closes the HTTP connection for the standalone GET stream used @@ -240,10 +240,10 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: # pragma: lax no cover self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: # pragma: lax no cover self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -291,7 +291,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: # pragma: lax no cover response_headers.update(headers) if self.mcp_session_id: @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: # pragma: lax no cover event_data["id"] = event_message.event_id return event_data @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # pragma: lax no cover # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: # pragma: lax no cover await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: # pragma: lax no cover response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): # pragma: lax no cover return # For notifications and responses only, return 202 Accepted @@ -659,7 +659,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: # pragma: lax no cover response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -667,11 +667,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover + if not await self._validate_request_headers(request, send): # pragma: lax no cover return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: lax no cover await self._replay_events(last_event_id, request, send) return @@ -685,7 +685,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: # pragma: lax no cover response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -714,7 +714,7 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in standalone SSE writer") finally: logger.debug("Closing standalone SSE writer") @@ -791,7 +791,7 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: lax no cover """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, @@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: # pragma: lax no cover response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: # pragma: lax no cover protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: lax no cover supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -865,7 +865,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: lax no cover """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. @@ -991,7 +991,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( # pragma: lax no cover session_message.metadata is not None and isinstance( session_message.metadata, @@ -1015,10 +1015,10 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover + except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: lax no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: # pragma: lax no cover logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..e3009ae62 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: # pragma: lax no cover """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: # pragma: lax no cover """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -104,13 +104,13 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res return None # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + host = request.headers.get("host") # pragma: lax no cover + if not self._validate_host(host): # pragma: lax no cover + return Response("Invalid Host header", status_code=421) # pragma: lax no cover # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + origin = request.headers.get("origin") # pragma: lax no cover + if not self._validate_origin(origin): # pragma: lax no cover + return Response("Invalid Origin header", status_code=403) # pragma: lax no cover - return None # pragma: no cover + return None # pragma: lax no cover diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 344deab4f..374fc5fbb 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -104,7 +104,7 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( # pragma: lax no cover self, last_event_id: EventId, send_callback: EventCallback, @@ -140,11 +140,11 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: lax no cover yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( # pragma: lax no cover ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) @@ -159,7 +159,7 @@ async def _handle_read_resource( # pragma: no cover return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( # pragma: lax no cover ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -224,7 +224,7 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover +async def _handle_call_tool( # pragma: lax no cover ctx: ServerRequestContext[ServerState], params: CallToolRequestParams ) -> CallToolResult: name = params.name @@ -378,7 +378,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: # pragma: lax no cover return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -392,7 +392,7 @@ def create_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover +) -> Starlette: # pragma: lax no cover """Create a Starlette application for testing using the session manager. Args: @@ -1365,7 +1365,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( # pragma: lax no cover ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1390,7 +1390,7 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover +async def _handle_context_call_tool( # pragma: lax no cover ctx: ServerRequestContext, params: CallToolRequestParams ) -> CallToolResult: name = params.name From 09ad2b601f45ce0ff9d9ce84b818277dc4771814 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Mar 2026 10:50:42 +0000 Subject: [PATCH 3/7] perf(tests): skip thread.join in run_server_in_thread teardown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit uvicorn polls should_exit every 0.1s in its main_loop and adds another 0.1s sleep in shutdown(), so thread.join() blocks ~200ms per fixture teardown. With 60+ fixture-using tests in test_streamable_http.py, that was ~12s of pure shutdown latency. The thread is a daemon, so signaling exit and moving on is safe — the interpreter reaps it. The socket is still closed by uvicorn's shutdown, releasing the port. Sequential runtime for affected files: 24s → 13s (old multiprocessing baseline was 16s). Parallel (-n 14) runtime unchanged at ~4.5s. --- tests/test_helpers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 98a901207..e800e4278 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -83,7 +83,12 @@ def run_server_in_thread(app: ASGIApp, lifespan: Literal["auto", "on", "off"] = yield f"http://127.0.0.1:{port}" finally: server.should_exit = True - thread.join(timeout=5) + server.force_exit = True + # Don't block on thread.join() — uvicorn polls should_exit every 0.1s and + # its shutdown() adds another 0.1s, totaling ~200ms teardown latency per + # fixture. The thread is a daemon so it will be reaped by the interpreter; + # we just signal exit and move on. The socket is closed by uvicorn's + # shutdown regardless, so the port is released for the next test. # When uvicorn shuts down with in-flight SSE connections, the server # cancels request handlers mid-operation. SseServerTransport's internal # memory streams may not get their `finally` cleanup run before GC, From e44156d7d185ca117aa215316cc1d13e17ae5510 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Mar 2026 11:21:16 +0000 Subject: [PATCH 4/7] perf(tests): module-scope server fixtures, ASGI for context-aware tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bulk of test_streamable_http.py's runtime was fixture overhead: 46 function-scoped uvicorn starts at ~39ms each = 1.8s of pure server lifecycle for tests that are mostly stateless HTTP validation. - Make basic_server, json_response_server module-scoped. The session manager keys sessions by ID; each test uses fresh IDs so there is no cross-contamination. One server now handles ~30 validation tests. - Convert the 5 context_aware_server tests (header propagation checks) and initialized_client_session to ASGITransport via a new asgi_client helper. These test that headers flow from ASGI scope → ctx.request, which ASGITransport exercises directly. - Collapse repeated header-echo boilerplate into _echo_headers(). 53 non-reconnection tests: 3.59s → 1.53s. Setup+teardown 1.77s → 0.20s. Remaining call time is legitimate: ~15 real-TCP MCP sessions (init + request + terminate ≈ 60ms each) and one event_server test with a sleep-polling loop (out of scope, tracked separately). --- src/mcp/server/streamable_http.py | 2 +- tests/shared/test_streamable_http.py | 237 +++++++++++---------------- 2 files changed, 101 insertions(+), 138 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 62140cd36..eec0c56f5 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -633,7 +633,7 @@ async def sse_writer(): # pragma: lax no cover finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 374fc5fbb..23657f2ee 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -7,7 +7,7 @@ import json import time -from collections.abc import AsyncIterator, Generator +from collections.abc import AsyncGenerator, AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any @@ -405,7 +405,8 @@ def create_app( # Create the session manager security_settings = TransportSecuritySettings( - allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + allowed_hosts=["127.0.0.1:*", "localhost:*", "localhost"], + allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://localhost"], ) session_manager = StreamableHTTPSessionManager( app=server, @@ -427,8 +428,14 @@ def create_app( return app -# Test fixtures — uvicorn in a background thread with port=0 (no port races) -@pytest.fixture +# Test fixtures — uvicorn in a background thread with port=0 (no port races). +# +# basic_server and json_response_server are module-scoped: one server instance +# handles ~30 tests. StreamableHTTPSessionManager accumulates sessions in a dict +# keyed by session-id; each test uses distinct session IDs so there is no +# cross-contamination. Tests that terminate sessions or crash clients leave +# entries in the dict, but subsequent tests simply create new sessions. +@pytest.fixture(scope="module") def basic_server() -> Generator[str, None, None]: """Start a basic server. Yields the server URL.""" with run_server_in_thread(create_app()) as url: @@ -451,15 +458,15 @@ def event_store() -> SimpleEventStore: def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: """Start a server with event store and retry_interval enabled. - Yields (event_store, server_url). Unlike the old multiprocessing fixture, the - event_store is now the SAME object used by the server (same process), so tests - can inspect server-side state directly if needed. + Yields (event_store, server_url). The event_store is the same object used by + the server (same process), so tests can inspect server-side state directly. + Function-scoped because the reconnection tests depend on event_store state. """ with run_server_in_thread(create_app(event_store=event_store, retry_interval=500)) as url: yield event_store, url -@pytest.fixture +@pytest.fixture(scope="module") def json_response_server() -> Generator[str, None, None]: """Start a server with JSON response enabled. Yields the server URL.""" with run_server_in_thread(create_app(is_json_response_enabled=True)) as url: @@ -472,6 +479,23 @@ def json_server_url(json_response_server: str) -> str: return json_response_server +@asynccontextmanager +async def asgi_client(app: Starlette, **client_kwargs: Any) -> AsyncGenerator[httpx.AsyncClient, None]: + """Run a Starlette app in-process via ASGITransport. + + Manages the app lifespan and yields an httpx.AsyncClient wired to the app. + No threads, no sockets — requests call the ASGI app directly in the same + event loop. Use when tests only need POST→response cycles (not the + long-lived GET stream, which would deadlock since ASGITransport buffers + the full response body before returning). + """ + async with app.router.lifespan_context(app): + transport = httpx.ASGITransport(app=app) + client_kwargs.setdefault("follow_redirects", True) + async with httpx.AsyncClient(transport=transport, **client_kwargs) as client: + yield client + + # Basic request validation tests def test_accept_header_validation(basic_server: str, basic_server_url: str): """Test that Accept header is properly validated.""" @@ -947,19 +971,14 @@ def test_get_validation(basic_server: str, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: str, basic_server_url: str): # pragma: no cover - """Create test client matching the SSE test pattern.""" - async with httpx.AsyncClient(base_url=basic_server_url) as client: - yield client - - -@pytest.fixture -async def initialized_client_session(basic_server: str, basic_server_url: str): - """Create initialized StreamableHTTP client session.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - yield session +async def initialized_client_session() -> AsyncGenerator[ClientSession, None]: + """Create an initialized StreamableHTTP client session over ASGI (in-process).""" + async with asgi_client(create_app()) as http_client: + # Use localhost so create_app's TransportSecuritySettings allowlist accepts it + async with streamable_http_client("http://localhost/mcp", http_client=http_client) as (rs, ws): + async with ClientSession(rs, ws) as session: + await session.initialize() + yield session @pytest.mark.anyio @@ -1434,76 +1453,60 @@ def create_context_aware_app() -> Starlette: ) -@pytest.fixture -def context_aware_server() -> Generator[str, None, None]: - """Start the context-aware server. Yields the server URL.""" - with run_server_in_thread(create_context_aware_app()) as url: - yield url +@asynccontextmanager +async def context_aware_session( + **client_kwargs: Any, +) -> AsyncGenerator[ClientSession, None]: + """Initialized ClientSession against an in-process context-aware server (ASGI).""" + async with asgi_client(create_context_aware_app(), **client_kwargs) as http_client: + async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): + async with ClientSession(rs, ws) as session: + await session.initialize() + yield session + + +async def _echo_headers(session: ClientSession) -> dict[str, Any]: + tool_result = await session.call_tool("echo_headers", {}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + return json.loads(tool_result.content[0].text) @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: str) -> None: +async def test_streamablehttp_request_context_propagation() -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - - async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: # pragma: no branch - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "ContextAwareServer" - - # Call the tool that echoes headers back - tool_result = await session.call_tool("echo_headers", {}) - - # Parse the JSON response - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify headers were propagated - assert headers_data.get("authorization") == "Bearer test-token" - assert headers_data.get("x-custom-header") == "test-value" - assert headers_data.get("x-trace-id") == "trace-123" + async with context_aware_session(headers=custom_headers) as session: + headers_data = await _echo_headers(session) + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: str) -> None: +async def test_streamablehttp_request_context_isolation() -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" + # Each client hits a fresh ASGI app instance, so isolation is guaranteed at + # the app level. What we're really testing is that the server plumbs headers + # from the ASGI scope into ctx.request correctly per-request, not that one + # connection can't see another's state (the session manager already keys by + # session-id). contexts: list[dict[str, Any]] = [] - - # Create multiple clients with different headers for i in range(3): headers = { "X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}", "Authorization": f"Bearer token-{i}", } + async with context_aware_session(headers=headers) as session: + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) + assert isinstance(tool_result.content[0], TextContent) + contexts.append(json.loads(tool_result.content[0].text)) - async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: # pragma: no branch - await session.initialize() - - # Call the tool that echoes context - tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - context_data = json.loads(tool_result.content[0].text) - contexts.append(context_data) - - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -1513,24 +1516,17 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: st @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: str): +async def test_client_includes_protocol_version_header_after_init(): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{context_aware_server}/mcp") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Initialize and get the negotiated version - init_result = await session.initialize() - negotiated_version = init_result.protocol_version - - # Call a tool that echoes headers to verify the header is present - tool_result = await session.call_tool("echo_headers", {}) - - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) + async with asgi_client(create_context_aware_app()) as http_client: + async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): + async with ClientSession(rs, ws) as session: + init_result = await session.initialize() + negotiated_version = init_result.protocol_version - # Verify protocol version header is present - assert "mcp-protocol-version" in headers_data - assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version + headers_data = await _echo_headers(session) + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version def test_server_validates_protocol_version_header(basic_server: str, basic_server_url: str): @@ -2124,69 +2120,36 @@ async def test_streamable_http_client_does_not_mutate_provided_client(basic_serv @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults(context_aware_server: str) -> None: +async def test_streamable_http_client_mcp_headers_override_defaults() -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" - # httpx.AsyncClient has default "accept: */*" header - # We need to verify that our MCP accept header overrides it in actual requests - - async with httpx.AsyncClient(follow_redirects=True) as client: - # Verify client has default accept header + # httpx.AsyncClient sets "accept: */*" by default — verify the MCP client + # overrides it per-request rather than relying on the client default. + async with asgi_client(create_context_aware_app()) as client: assert client.headers.get("accept") == "*/*" - - async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + async with streamable_http_client("http://testserver/mcp", http_client=client) as (rs, ws): + async with ClientSession(rs, ws) as session: await session.initialize() - - # Use echo_headers tool to see what headers the server actually received - tool_result = await session.call_tool("echo_headers", {}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify MCP protocol headers were sent (not httpx defaults) - assert "accept" in headers_data + headers_data = await _echo_headers(session) assert "application/json" in headers_data["accept"] assert "text/event-stream" in headers_data["accept"] - - assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_aware_server: str) -> None: +async def test_streamable_http_client_preserves_custom_with_mcp_headers() -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", "X-Request-Id": "req-123", "Authorization": "Bearer test-token", } - - async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( - read_stream, - write_stream, - ): - async with ClientSession(read_stream, write_stream) as session: # pragma: no branch - await session.initialize() - - # Use echo_headers tool to verify both custom and MCP headers are present - tool_result = await session.call_tool("echo_headers", {}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - headers_data = json.loads(tool_result.content[0].text) - - # Verify custom headers are present - assert headers_data.get("x-custom-header") == "custom-value" - assert headers_data.get("x-request-id") == "req-123" - assert headers_data.get("authorization") == "Bearer test-token" - - # Verify MCP protocol headers are also present - assert "accept" in headers_data - assert "application/json" in headers_data["accept"] - assert "text/event-stream" in headers_data["accept"] - - assert "content-type" in headers_data - assert headers_data["content-type"] == "application/json" + async with context_aware_session(headers=custom_headers) as session: + headers_data = await _echo_headers(session) + # Custom headers preserved + assert headers_data.get("x-custom-header") == "custom-value" + assert headers_data.get("x-request-id") == "req-123" + assert headers_data.get("authorization") == "Bearer test-token" + # MCP protocol headers also present + assert "application/json" in headers_data["accept"] + assert "text/event-stream" in headers_data["accept"] + assert headers_data["content-type"] == "application/json" From 6b77efe19a0e7a7d4081102fa8d5dca2882ec848 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:13:42 +0000 Subject: [PATCH 5/7] fix(tests): add no-branch pragma for nested async with ClientSession On Python 3.11+, coverage.py fails to track the ->exit arc from the innermost async with body when it yields inside 3+ nested async with blocks (the async generator frame unwinds through __aexit__ chains differently than 3.10). The codebase already uses # pragma: no branch for this pattern on existing ClientSession fixtures; apply it to the new ASGI-backed fixtures added in the previous commit. Fixes the 5 ->exit branch misses reported on CI 3.11-3.14: test_http_unicode.py:124->exit test_streamable_http.py:979->exit, 1463->exit, 1523->exit, 2130->exit --- tests/client/test_http_unicode.py | 4 +++- tests/shared/test_streamable_http.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index ee105505f..7e02798c8 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -121,7 +121,9 @@ async def unicode_session() -> AsyncGenerator[ClientSession, None]: transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, follow_redirects=True) as http_client: async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): - async with ClientSession(rs, ws) as session: + async with ClientSession(rs, ws) as session: # pragma: no branch + # ^ coverage.py misses the ->exit arc on 3.11+ when yield is + # nested inside multiple async with blocks await session.initialize() yield session diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 23657f2ee..06803d036 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -976,7 +976,7 @@ async def initialized_client_session() -> AsyncGenerator[ClientSession, None]: async with asgi_client(create_app()) as http_client: # Use localhost so create_app's TransportSecuritySettings allowlist accepts it async with streamable_http_client("http://localhost/mcp", http_client=http_client) as (rs, ws): - async with ClientSession(rs, ws) as session: + async with ClientSession(rs, ws) as session: # pragma: no branch await session.initialize() yield session @@ -1460,7 +1460,7 @@ async def context_aware_session( """Initialized ClientSession against an in-process context-aware server (ASGI).""" async with asgi_client(create_context_aware_app(), **client_kwargs) as http_client: async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): - async with ClientSession(rs, ws) as session: + async with ClientSession(rs, ws) as session: # pragma: no branch await session.initialize() yield session @@ -1520,7 +1520,7 @@ async def test_client_includes_protocol_version_header_after_init(): """Test that client includes mcp-protocol-version header after initialization.""" async with asgi_client(create_context_aware_app()) as http_client: async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): - async with ClientSession(rs, ws) as session: + async with ClientSession(rs, ws) as session: # pragma: no branch init_result = await session.initialize() negotiated_version = init_result.protocol_version @@ -2127,7 +2127,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults() -> None: async with asgi_client(create_context_aware_app()) as client: assert client.headers.get("accept") == "*/*" async with streamable_http_client("http://testserver/mcp", http_client=client) as (rs, ws): - async with ClientSession(rs, ws) as session: + async with ClientSession(rs, ws) as session: # pragma: no branch await session.initialize() headers_data = await _echo_headers(session) assert "application/json" in headers_data["accept"] From 319dd7c5a0cedfc8eb3a2fa081a9a6711c66d6ad Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:30:20 +0000 Subject: [PATCH 6/7] fix(tests): suppress uvicorn's 3.14 DeprecationWarning in thread, rejoin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two CI failures, same root cause of running uvicorn in-process instead of a subprocess: 1. Python 3.14: asyncio.iscoroutinefunction is deprecated. uvicorn's config.load() calls it, emitting a DeprecationWarning. Pytest's filterwarnings=error applies to the uvicorn thread (same process), so the warning becomes an exception that kills the thread before server.started is set. The old subprocess fixtures masked this — subprocesses don't inherit pytest's warning filters. Fix: wrap server.run in a catch_warnings block that ignores DeprecationWarning. Also detect a dead thread in the startup poll so we fail fast rather than spinning for 20s on future thread crashes. 2. Windows 3.12/3.13: Proactor pipe transports leak as unclosed sockets when the daemon thread is reaped before its event loop finishes shutdown. The prior perf commit removed thread.join() to save ~200ms per fixture, but that doesn't hold on Windows. Bring back the join. Module-scoped fixtures absorb most of the cost anyway (one server per module instead of per-test). --- tests/test_helpers.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index e800e4278..51c893dac 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -67,12 +67,28 @@ def run_server_in_thread(app: ASGIApp, lifespan: Literal["auto", "on", "off"] = config = uvicorn.Config(app=app, host="127.0.0.1", port=0, log_level="error", lifespan=lifespan) server = uvicorn.Server(config=config) - thread = threading.Thread(target=server.run, daemon=True) + def _run() -> None: + # On Python 3.14, asyncio.iscoroutinefunction raises a DeprecationWarning + # when uvicorn's config.load() calls it. Pytest's `filterwarnings = error` + # applies to this thread (same process), so the warning becomes an + # exception that kills the thread before server.started is set. The old + # subprocess fixtures masked this because subprocess doesn't inherit + # pytest's warning filters. catch_warnings is thread-local since 3.12; + # on 3.10/3.11 it's process-global but the warning doesn't fire there. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + server.run() + + thread = threading.Thread(target=_run, daemon=True) thread.start() - # Wait for uvicorn to bind and start accepting connections + # Wait for uvicorn to bind and start accepting connections. If the thread + # dies (e.g., config.load() raised), bail early with the thread's exception + # rather than spinning for 20s. start_time = time.time() while not server.started: + if not thread.is_alive(): # pragma: no cover + raise RuntimeError("uvicorn thread exited before server started") if time.time() - start_time > 20.0: # pragma: no cover raise TimeoutError("uvicorn server did not start within 20 seconds") time.sleep(0.01) @@ -82,20 +98,18 @@ def run_server_in_thread(app: ASGIApp, lifespan: Literal["auto", "on", "off"] = try: yield f"http://127.0.0.1:{port}" finally: + # force_exit skips uvicorn's graceful connection drain. We still join + # the thread: on Windows, Proactor pipe transports leak as unclosed + # sockets if the daemon thread is reaped before its event loop finishes + # shutdown. Module-scoped fixtures make the ~200ms join cost acceptable + # (one server per module instead of one per test). server.should_exit = True server.force_exit = True - # Don't block on thread.join() — uvicorn polls should_exit every 0.1s and - # its shutdown() adds another 0.1s, totaling ~200ms teardown latency per - # fixture. The thread is a daemon so it will be reaped by the interpreter; - # we just signal exit and move on. The socket is closed by uvicorn's - # shutdown regardless, so the port is released for the next test. - # When uvicorn shuts down with in-flight SSE connections, the server - # cancels request handlers mid-operation. SseServerTransport's internal - # memory streams may not get their `finally` cleanup run before GC, - # causing ResourceWarnings. These are artifacts of test abrupt-disconnect - # patterns (open SSE stream → check status → exit without consuming), - # not bugs. Force GC here and suppress the warnings so they don't leak - # into the next test's PytestUnraisableExceptionWarning collector. + thread.join(timeout=5) + # SseServerTransport's internal memory streams may be GC'd without + # finalizers when uvicorn cancels in-flight SSE handlers. Force GC here + # and suppress the ResourceWarnings so they don't leak into the next + # test's PytestUnraisableExceptionWarning collector. with warnings.catch_warnings(): warnings.simplefilter("ignore", ResourceWarning) gc.collect() From 4c12359aa438c101651a12112bf1d5af750a908b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:37:20 +0000 Subject: [PATCH 7/7] fix(tests): 3.14 phantom branch arcs, Windows Proactor transport leak MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two more cases where running in-process exposes issues subprocess hid: 1. Python 3.14 coverage: phantom ->exit arcs on nested async with CMs. The WITH_EXCEPT_START suppression-check branch (did __aexit__ suppress?) gets misattributed through the exception table to outer CM lines when the inner body yields. Previously the SSE security tests never ran on 3.14 (uvicorn thread crashed before server started) so this didn't surface. Add # pragma: no branch to the two specific CM lines coverage flags. 2. Windows 3.13 lowest-direct: Proactor socket transports from MCP client connections don't always close before GC when clients disconnect abruptly from the module-scoped server. The transport __del__ ResourceWarning is collected by pytest's unraisable hook during a later test. Older httpx (lowest-direct) has worse transport cleanup than the locked version. Subprocess-based tests hid this — resource warnings in a subprocess die with the subprocess. Add the same PytestUnraisableExceptionWarning filter that test_sse.py uses. --- tests/client/test_http_unicode.py | 11 ++++++++--- tests/server/test_sse_security.py | 4 +++- tests/shared/test_streamable_http.py | 9 +++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index 7e02798c8..6e2274f05 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -120,10 +120,15 @@ async def unicode_session() -> AsyncGenerator[ClientSession, None]: async with app.router.lifespan_context(app): transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient(transport=transport, follow_redirects=True) as http_client: - async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): + async with streamable_http_client( # pragma: no branch + "http://testserver/mcp", http_client=http_client + ) as (rs, ws): + # ^ coverage.py on 3.11+ misses phantom ->exit arcs from nested + # async with CMs when the innermost body yields (async generator + # frame unwinds through __aexit__ chains). On 3.14 the + # WITH_EXCEPT_START suppression-check branch is misattributed + # through the exception table to the *outer* CM line too. async with ClientSession(rs, ws) as session: # pragma: no branch - # ^ coverage.py misses the ->exit arc on 3.11+ when yield is - # nested inside multiple async with blocks await session.initialize() yield session diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index bd9a174cd..d3bfe0fe5 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -140,7 +140,9 @@ async def test_sse_security_disabled(): settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) with run_server_in_thread(make_app(settings), lifespan="off") as url: async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"{url}/sse", headers={"Host": "evil.com"}) as response: + async with client.stream( # pragma: no branch + "GET", f"{url}/sse", headers={"Host": "evil.com"} + ) as response: # Should connect successfully even with invalid host assert response.status_code == 200 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 06803d036..c4961319f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -63,6 +63,15 @@ ) from tests.test_helpers import run_server_in_thread +# On Windows, the Proactor event loop's socket transports don't always close +# cleanly before GC when MCP clients disconnect abruptly from the module-scoped +# server (e.g. test_streamable_http_client_session_termination opens and drops +# two clients). The transport __del__ emits ResourceWarning which pytest's +# unraisable collector picks up during a LATER test. These are Windows asyncio +# internals, not bugs in the client code under test. Previously hidden by +# subprocess isolation — subprocess resource warnings die with the subprocess. +pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") + # Test constants SERVER_NAME = "test_streamable_http_server" TEST_SESSION_ID = "test-session-id-12345"