diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 0130467d..1fa0340b 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; use bytes::Bytes; use futures::{StreamExt, future::BoxFuture}; -use http::{Method, Request, Response, header::ALLOW}; +use http::{HeaderMap, Method, Request, Response, Uri, header::ALLOW}; use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; @@ -48,6 +48,12 @@ pub struct StreamableHttpServerConfig { /// When this token is cancelled, all active sessions are terminated and /// the server stops accepting new requests. pub cancellation_token: CancellationToken, + /// Allowed hostnames for inbound `Host` / `Origin` validation. + /// + /// By default, Streamable HTTP servers only accept loopback hosts to + /// prevent DNS rebinding attacks against locally running servers. Public + /// deployments should override this list with their own hostnames. + pub allowed_hosts: Vec, } impl Default for StreamableHttpServerConfig { @@ -58,10 +64,26 @@ impl Default for StreamableHttpServerConfig { stateful_mode: true, json_response: false, cancellation_token: CancellationToken::new(), + allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()], } } } +impl StreamableHttpServerConfig { + pub fn with_allowed_hosts( + mut self, + allowed_hosts: impl IntoIterator>, + ) -> Self { + self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect(); + self + } + /// Disable allowed hosts. This will allow requests with any `Host` or `Origin` header, which is NOT recommended for public deployments. + pub fn disable_allowed_hosts(mut self) -> Self { + self.allowed_hosts.clear(); + self + } +} + #[expect( clippy::result_large_err, reason = "BoxResponse is intentionally large; matches other handlers in this file" @@ -102,6 +124,99 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box Ok(()) } +fn forbidden_response(message: impl Into) -> BoxResponse { + Response::builder() + .status(http::StatusCode::FORBIDDEN) + .body(Full::new(Bytes::from(message.into())).boxed()) + .expect("valid response") +} + +fn normalize_host(host: &str) -> String { + host.trim_matches('[') + .trim_matches(']') + .to_ascii_lowercase() +} + +fn host_is_allowed(host: &str, allowed_hosts: &[String]) -> bool { + if allowed_hosts.is_empty() { + // If the allowed hosts list is empty, allow all hosts (not recommended). + return true; + } + let normalized = normalize_host(host); + allowed_hosts + .iter() + .any(|allowed| normalize_host(allowed) == normalized) +} + +fn parse_host_header(headers: &HeaderMap) -> Result, BoxResponse> { + let Some(host) = headers.get(http::header::HOST) else { + return Ok(None); + }; + + let host = host + .to_str() + .map_err(|_| forbidden_response("Forbidden: Invalid Host header encoding"))?; + let authority = http::uri::Authority::try_from(host) + .map_err(|_| forbidden_response("Forbidden: Invalid Host header"))?; + Ok(Some(normalize_host(authority.host()))) +} + +fn parse_origin_host(headers: &HeaderMap) -> Result, BoxResponse> { + let Some(origin) = headers.get(http::header::ORIGIN) else { + return Ok(None); + }; + + let origin = origin + .to_str() + .map_err(|_| forbidden_response("Forbidden: Invalid Origin header encoding"))?; + if origin.eq_ignore_ascii_case("null") { + return Err(forbidden_response("Forbidden: Invalid Origin header")); + } + + let uri: Uri = origin + .parse() + .map_err(|_| forbidden_response("Forbidden: Invalid Origin header"))?; + let Some(authority) = uri.authority() else { + return Err(forbidden_response("Forbidden: Invalid Origin header")); + }; + let Some(scheme) = uri.scheme_str() else { + return Err(forbidden_response("Forbidden: Invalid Origin header")); + }; + if !matches!(scheme, "http" | "https") { + return Err(forbidden_response("Forbidden: Invalid Origin header")); + } + + Ok(Some(normalize_host(authority.host()))) +} + +fn validate_dns_rebinding_headers( + headers: &HeaderMap, + config: &StreamableHttpServerConfig, +) -> Result<(), BoxResponse> { + let host = parse_host_header(headers)?; + if let Some(host) = host.as_deref() { + if !host_is_allowed(host, &config.allowed_hosts) { + return Err(forbidden_response("Forbidden: Host header is not allowed")); + } + } + + let origin = parse_origin_host(headers)?; + if let Some(origin) = origin.as_deref() { + if !host_is_allowed(origin, &config.allowed_hosts) { + return Err(forbidden_response( + "Forbidden: Origin header is not allowed", + )); + } + if let Some(host) = host.as_deref() { + if origin != host { + return Err(forbidden_response("Forbidden: Origin does not match Host")); + } + } + } + + Ok(()) +} + /// # Streamable HTTP server /// /// An HTTP service that implements the @@ -251,6 +366,9 @@ where B: Body + Send + 'static, B::Error: Display, { + if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) { + return response; + } let method = request.method().clone(); let allowed_methods = match self.config.stateful_mode { true => "GET, POST, DELETE", diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index 7d4316d3..e092834b 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -870,3 +870,86 @@ fn test_protocol_version_utilities() { assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26)); assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18)); } + +/// Integration test: Verify server validates Host and Origin headers for DNS rebinding protection +#[tokio::test] +#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))] +async fn test_server_validates_host_and_origin_headers() { + use std::sync::Arc; + + use bytes::Bytes; + use http::{Method, Request, header::CONTENT_TYPE}; + use http_body_util::Full; + use rmcp::{ + handler::server::ServerHandler, + model::{ServerCapabilities, ServerInfo}, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }; + use serde_json::json; + + #[derive(Clone)] + struct TestHandler; + + impl ServerHandler for TestHandler { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().build()) + } + } + + let service = StreamableHttpService::new( + || Ok(TestHandler), + Arc::new(LocalSessionManager::default()), + StreamableHttpServerConfig::default(), + ); + + let init_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + let allowed_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") + .header("Origin", "http://localhost:8080") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(allowed_request).await; + assert_eq!(response.status(), http::StatusCode::OK); + + let bad_host_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "attacker.example") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(bad_host_request).await; + assert_eq!(response.status(), http::StatusCode::FORBIDDEN); + + let bad_origin_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") + .header("Origin", "http://attacker.example") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(bad_origin_request).await; + assert_eq!(response.status(), http::StatusCode::FORBIDDEN); +}