diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index d9e9f965b..0229d8745 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import JSONResponse, Response logger = logging.getLogger(__name__) @@ -106,7 +106,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res # Validate Host header host = request.headers.get("host") if not self._validate_host(host): - return Response("Invalid Host header", status_code=421) + return JSONResponse( + { + "error": "host_not_allowed", + "received_host": host, + "configure": "TransportSecuritySettings.allowed_hosts", + }, + status_code=421, + ) # Validate Origin header origin = request.headers.get("origin") diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 6331c2dae..c569e0201 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -367,7 +367,16 @@ async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> No assert [event async for event in ok.aiter_sse()] assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) - assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + assert (bad_host.status_code, bad_host.json()) == snapshot( + ( + 421, + { + "error": "host_not_allowed", + "received_host": "evil.example", + "configure": "TransportSecuritySettings.allowed_hosts", + }, + ) + ) async with mounted_app( Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index ca16d3354..7eb9cb903 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -86,7 +86,11 @@ async def test_sse_security_invalid_host_header() -> None: async with sse_security_client(security_settings) as client: response = await client.get("/sse", headers={"Host": "evil.com"}) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } @pytest.mark.anyio @@ -149,7 +153,11 @@ async def test_sse_security_custom_allowed_hosts() -> None: response = await client.get("/sse", headers={"Host": "evil.com"}) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } @pytest.mark.anyio diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index f13bb4a9b..2a2230a94 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -60,7 +60,11 @@ async def test_streamable_http_security_invalid_host_header() -> None: async with streamable_http_security_client(security_settings) as client: response = await client.post("/", json=_initialize_body(), headers=_base_headers() | {"Host": "evil.com"}) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } @pytest.mark.anyio @@ -121,7 +125,11 @@ async def test_streamable_http_security_get_request() -> None: async with streamable_http_security_client(security_settings) as client: response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "evil.com"}) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } response = await client.get("/", headers={"Accept": "text/event-stream", "Host": "127.0.0.1"}) # An allowed host passes security and fails on session validation instead. diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py index be28980b5..cf9870d49 100644 --- a/tests/server/test_transport_security.py +++ b/tests/server/test_transport_security.py @@ -48,6 +48,20 @@ async def test_validate_request_checks_host_then_origin( assert (None if response is None else response.status_code) == expected +@pytest.mark.anyio +async def test_validate_request_explains_host_rejection() -> None: + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("evil.example", None)) + + assert response is not None + assert response.status_code == 421 + assert response.media_type == "application/json" + assert response.body == ( + b'{"error":"host_not_allowed","received_host":"evil.example",' + b'"configure":"TransportSecuritySettings.allowed_hosts"}' + ) + + @pytest.mark.anyio async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: """With DNS-rebinding protection off, any Host/Origin is accepted."""