From 64e3f2a13f295ffe334bfa6f65d5fa2fee6bf7c2 Mon Sep 17 00:00:00 2001 From: David Stern Date: Tue, 19 May 2026 23:11:50 -0400 Subject: [PATCH 1/2] Move MCP "runtime" logic to `mcp` crate. --- Cargo.lock | 15 + app/src/ai/agent/task.rs | 2 +- app/src/ai/mcp/gallery.rs | 2 +- app/src/ai/mcp/http_client.rs | 29 - app/src/ai/mcp/mod.rs | 6 +- app/src/ai/mcp/parsing.rs | 10 +- app/src/ai/mcp/templatable_manager.rs | 88 +- app/src/ai/mcp/templatable_manager/native.rs | 687 ++++--------- app/src/ai/mcp/templatable_manager/utils.rs | 79 -- .../ai/mcp/templatable_manager/utils_tests.rs | 309 ------ .../ai/skills/file_watchers/skill_watcher.rs | 2 +- app/src/auth/user.rs | 2 +- app/src/autoupdate/mod.rs | 6 +- app/src/server/mod.rs | 1 - app/src/server/network_logging.rs | 2 +- app/src/server/server_api/auth.rs | 4 +- .../in_band_command_executor.rs | 2 +- .../session/command_executor/tmux_executor.rs | 2 +- app/src/workspace/view/vertical_tabs.rs | 2 +- crates/mcp/Cargo.toml | 21 +- crates/mcp/src/lib.rs | 2 + .../mcp/src}/oauth.rs | 165 +--- .../mcp/src}/oauth_tests.rs | 0 crates/mcp/src/runtime.rs | 909 ++++++++++++++++++ .../warp_core/src}/datetime_ext.rs | 0 crates/warp_core/src/lib.rs | 1 + 26 files changed, 1197 insertions(+), 1151 deletions(-) delete mode 100644 app/src/ai/mcp/http_client.rs delete mode 100644 app/src/ai/mcp/templatable_manager/utils.rs delete mode 100644 app/src/ai/mcp/templatable_manager/utils_tests.rs rename {app/src/ai/mcp/templatable_manager => crates/mcp/src}/oauth.rs (71%) rename {app/src/ai/mcp/templatable_manager => crates/mcp/src}/oauth_tests.rs (100%) create mode 100644 crates/mcp/src/runtime.rs rename {app/src/server => crates/warp_core/src}/datetime_ext.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 8f2d2421d8..877e13b90c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7885,16 +7885,31 @@ dependencies = [ name = "mcp" version = "0.1.0" dependencies = [ + "anyhow", + "async-channel", + "async-trait", + "cfg-if", + "cloud_object_models", "futures", "http 1.4.0", + "log", + "oauth2", "pin-project-lite", "reqwest", "rmcp", + "serde", "serde_json", + "simple_logger", "sse-stream", "thiserror 2.0.17", "tokio", "tracing", + "url", + "uuid", + "warp_core", + "warpui", + "warpui_extras", + "windows 0.62.2", ] [[package]] diff --git a/app/src/ai/agent/task.rs b/app/src/ai/agent/task.rs index 5f3b36d135..8b6a89ee07 100644 --- a/app/src/ai/agent/task.rs +++ b/app/src/ai/agent/task.rs @@ -14,6 +14,7 @@ use itertools::Itertools; use prost_types::FieldMask; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use warp_core::datetime_ext::DateTimeExt; use warp_multi_agent_api::{ self as api, message::{tool_call::subagent::Metadata, Message}, @@ -24,7 +25,6 @@ use crate::{ agent::comment::CodeReview, document::ai_document_model::{AIDocumentId, AIDocumentVersion}, }, - server::datetime_ext::DateTimeExt, terminal::model::block::BlockId, AIAgentTodoList, }; diff --git a/app/src/ai/mcp/gallery.rs b/app/src/ai/mcp/gallery.rs index a7319578c1..773f5d7cb8 100644 --- a/app/src/ai/mcp/gallery.rs +++ b/app/src/ai/mcp/gallery.rs @@ -4,9 +4,9 @@ use crate::ai::mcp::templatable::{ GalleryData, JsonTemplate, TemplatableMCPServer, TemplateVariable, }; use crate::server::cloud_objects::update_manager::{UpdateManager, UpdateManagerEvent}; -use crate::server::datetime_ext::DateTimeExt; use chrono::DateTime; use uuid::Uuid; +use warp_core::datetime_ext::DateTimeExt; use warpui::{Entity, ModelContext, SingletonEntity}; #[derive(Clone, Debug)] diff --git a/app/src/ai/mcp/http_client.rs b/app/src/ai/mcp/http_client.rs deleted file mode 100644 index 47d0be706b..0000000000 --- a/app/src/ai/mcp/http_client.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::collections::HashMap; - -use reqwest::header::HeaderMap; - -type ReqwestHttpTransport = rmcp::transport::StreamableHttpClientTransport; - -/// Builds a `HeaderMap` from a `HashMap` of user-provided headers. -/// -/// Invalid header names or values are skipped. -fn build_header_map(headers: &HashMap) -> HeaderMap { - headers.try_into().unwrap_or_default() -} - -/// Builds a reqwest client with custom headers for MCP HTTP/SSE connections. -#[allow(clippy::result_large_err)] -pub fn build_client_with_headers( - headers: &HashMap, -) -> Result { - let header_map = build_header_map(headers); - - reqwest::Client::builder() - .default_headers(header_map) - .build() - .map_err(|e| { - rmcp::RmcpError::transport_creation::(format!( - "Failed to build client with headers: {e}", - )) - }) -} diff --git a/app/src/ai/mcp/mod.rs b/app/src/ai/mcp/mod.rs index 4ddd97c210..86be565444 100644 --- a/app/src/ai/mcp/mod.rs +++ b/app/src/ai/mcp/mod.rs @@ -1,9 +1,9 @@ #[cfg(not(target_family = "wasm"))] -use crate::server::datetime_ext::DateTimeExt; -#[cfg(not(target_family = "wasm"))] use chrono::DateTime; use std::collections::HashMap; use std::path::{Path, PathBuf}; +#[cfg(not(target_family = "wasm"))] +use warp_core::datetime_ext::DateTimeExt; #[cfg(not(target_family = "wasm"))] use crate::persistence::model::MCPEnvironmentVariables; @@ -67,8 +67,6 @@ pub use templatable_installation::{VariableType, VariableValue}; pub mod parsing; pub use parsing::ParsedTemplatableMCPServerResult; #[cfg(not(target_family = "wasm"))] -pub mod http_client; -#[cfg(not(target_family = "wasm"))] pub mod reconnecting_peer; impl CloudObjectUuid for MCPServer { diff --git a/app/src/ai/mcp/parsing.rs b/app/src/ai/mcp/parsing.rs index 06f0504aee..f97fd8058e 100644 --- a/app/src/ai/mcp/parsing.rs +++ b/app/src/ai/mcp/parsing.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use chrono::DateTime; use handlebars::{get_arguments, render_template}; +use warp_core::datetime_ext::DateTimeExt; #[cfg(feature = "local_fs")] use serde::Deserialize; @@ -9,12 +10,9 @@ use serde::Deserialize; #[cfg(feature = "local_fs")] use crate::ai::mcp::{JSONMCPServer, JSONTransportType}; -use crate::{ - ai::mcp::{ - templatable::{JsonTemplate, TemplatableMCPServer, TemplateVariable}, - templatable_installation::{TemplatableMCPServerInstallation, VariableType, VariableValue}, - }, - server::datetime_ext::DateTimeExt, +use crate::ai::mcp::{ + templatable::{JsonTemplate, TemplatableMCPServer, TemplateVariable}, + templatable_installation::{TemplatableMCPServerInstallation, VariableType, VariableValue}, }; /// Normalize MCP JSON input to ensure it has a server name wrapper. diff --git a/app/src/ai/mcp/templatable_manager.rs b/app/src/ai/mcp/templatable_manager.rs index 9131850ef5..4cf9c17bb4 100644 --- a/app/src/ai/mcp/templatable_manager.rs +++ b/app/src/ai/mcp/templatable_manager.rs @@ -1,20 +1,15 @@ #[cfg(not(target_family = "wasm"))] mod native; -#[cfg(not(target_family = "wasm"))] -pub use native::McpIntegration; -#[cfg(not(target_family = "wasm"))] -mod oauth; -#[cfg(not(target_family = "wasm"))] -mod utils; #[cfg(target_family = "wasm")] mod wasm; -#[cfg(all(test, not(target_family = "wasm")))] -mod utils_tests; - #[cfg(not(target_family = "wasm"))] use diesel::SqliteConnection; #[cfg(not(target_family = "wasm"))] +use mcp::oauth; +#[cfg(not(target_family = "wasm"))] +pub use native::McpIntegration; +#[cfg(not(target_family = "wasm"))] use parking_lot::Mutex; use std::collections::{HashMap, HashSet}; #[cfg(not(target_family = "wasm"))] @@ -25,6 +20,7 @@ use crate::ai::mcp::templatable::CloudTemplatableMCPServer; use crate::ai::mcp::FileBasedMCPManager; use crate::ai::mcp::{templatable_installation::TemplatableMCPServerInstallation, MCPServerState}; use futures_util::stream::AbortHandle; +pub use mcp::runtime::TemplatableMCPServerInfo; use uuid::Uuid; #[cfg(not(target_family = "wasm"))] use warpui::ModelSpawner; @@ -97,48 +93,6 @@ struct SpawnedServerInfo { oauth_result_tx: async_channel::Sender, } -/// Information about a single connected MCP server. -#[cfg_attr(target_family = "wasm", allow(dead_code))] -pub struct TemplatableMCPServerInfo { - name: String, - service: rmcp::service::RunningService< - rmcp::RoleClient, - Box>, - >, - resources: Vec, - tools: Vec, - installation_id: Uuid, - description: Option, - /// Whether the underlying transport uses authentication. - /// - /// TODO(vorporeal): Use this to display a toast when server authentication and connection is complete, and - /// to provide a "log out" button. - #[allow(dead_code)] - is_authenticated_transport: bool, -} - -impl TemplatableMCPServerInfo { - pub fn name(&self) -> &str { - &self.name - } - - pub fn resources(&self) -> &Vec { - &self.resources - } - - pub fn tools(&self) -> &Vec { - &self.tools - } - - pub fn installation_id(&self) -> Uuid { - self.installation_id - } - - pub fn description(&self) -> Option<&str> { - self.description.as_deref() - } -} - /// The current status of the Figma MCP server. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FigmaMcpStatus { @@ -221,13 +175,13 @@ impl TemplatableMCPServerManager { pub fn resources(&self) -> impl Iterator { self.active_servers .values() - .flat_map(|server| server.resources.iter()) + .flat_map(|server| server.resources().iter()) } pub fn tools(&self) -> impl Iterator { self.active_servers .values() - .flat_map(|server| server.tools.iter()) + .flat_map(|server| server.tools().iter()) } /// Returns a reconnecting peer for a server that has the given resource. @@ -241,12 +195,7 @@ impl TemplatableMCPServerManager { let spawner = self.spawner.as_ref()?; self.active_servers .iter() - .find(|(_, server)| { - server - .resources - .iter() - .any(|other_resource| resource.uri == other_resource.uri) - }) + .find(|(_, server)| server.has_resource(resource)) .map(|(installation_uuid, _)| { super::reconnecting_peer::ReconnectingPeer::new(*installation_uuid, spawner.clone()) }) @@ -255,7 +204,7 @@ impl TemplatableMCPServerManager { pub fn tools_for_server(&self, uuid: Uuid) -> Vec { self.active_servers .get(&uuid) - .map(|server| server.tools.clone()) + .map(|server| server.tools().clone()) .unwrap_or_default() } @@ -273,24 +222,21 @@ impl TemplatableMCPServerManager { installation_id: Option, tool_name: &str, ) -> Option> { - let candidates: Box> = + let mut candidates: Box> = if let Some(uuid) = installation_id { Box::new(self.active_servers.get(&uuid).into_iter()) } else { Box::new(self.active_servers.values()) }; - candidates - .flat_map(|server| server.tools.iter()) - .find(|t| t.name == tool_name) - .map(|t| t.input_schema.clone()) + candidates.find_map(|server| server.tool_input_schema(tool_name)) } #[cfg(not(target_family = "wasm"))] pub fn server_from_tool(&self, tool: String) -> Option<&Uuid> { self.active_servers .iter() - .find(|(_, server)| server.tools.iter().any(|t| t.name == tool)) + .find(|(_, server)| server.has_tool(&tool)) .map(|(uuid, _)| uuid) } @@ -300,15 +246,7 @@ impl TemplatableMCPServerManager { pub fn server_from_resource(&self, name: &str, uri: Option<&str>) -> Option<&Uuid> { self.active_servers .iter() - .find(|(_, server)| { - server.resources.iter().any(|r| { - if let Some(uri) = uri { - r.uri == uri - } else { - r.name == name - } - }) - }) + .find(|(_, server)| server.has_resource_name_or_uri(name, uri)) .map(|(uuid, _)| uuid) } diff --git a/app/src/ai/mcp/templatable_manager/native.rs b/app/src/ai/mcp/templatable_manager/native.rs index f664e737e4..6c0440f3eb 100644 --- a/app/src/ai/mcp/templatable_manager/native.rs +++ b/app/src/ai/mcp/templatable_manager/native.rs @@ -1,16 +1,16 @@ use crate::ai::mcp::file_based_manager::FileBasedMCPManagerEvent; -use crate::ai::mcp::templatable_manager::oauth::{ - load_credentials_from_secure_storage, write_to_secure_storage, FILE_BASED_MCP_CREDENTIALS_KEY, - TEMPLATABLE_MCP_CREDENTIALS_KEY, -}; use crate::ai::mcp::CloudMCPServer; use crate::ai::mcp::FileBasedMCPManager; use core::fmt; -use std::collections::HashSet; +use mcp::oauth::{ + self, load_credentials_from_secure_storage, write_to_secure_storage, AuthContext, + CallbackResult, FileBasedPersistedCredentialsMap, PersistedCredentials, + PersistedCredentialsMap, FILE_BASED_MCP_CREDENTIALS_KEY, TEMPLATABLE_MCP_CREDENTIALS_KEY, +}; +use mcp::runtime::{error_to_user_message, spawn_server}; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use std::{collections::HashMap, future::Future}; -use crate::ai::mcp::http_client::build_client_with_headers; use crate::ai::mcp::templatable::GalleryData; use crate::ai::mcp::templatable_manager::FigmaMcpStatus; use crate::ai::mcp::{ @@ -51,13 +51,9 @@ use crate::{ GlobalResourceHandlesProvider, }; use async_compat::CompatExt as _; -use cfg_if::cfg_if; -use futures::FutureExt as _; use parking_lot::Mutex; -use rmcp::{transport::ConfigureCommandExt as _, ServiceExt as _}; use simple_logger::manager::LogManager; -use simple_logger::SimpleLogger; -use tokio::io::AsyncBufReadExt as _; +use url::Url; use uuid::Uuid; use warp_core::safe_error; use warp_core::{execution_mode::AppExecutionMode, features::FeatureFlag, settings::Setting as _}; @@ -65,8 +61,6 @@ use warpui::AppContext; use warpui::{windowing::WindowManager, ModelContext, SingletonEntity}; use super::{ - oauth::{self, AuthContext, FileBasedPersistedCredentialsMap, PersistedCredentialsMap}, - utils::{query_resources_for, query_tools_for}, MCPServerState, SpawnedServerInfo, TemplatableMCPServerInfo, TemplatableMCPServerManager, TemplatableMCPServerManagerEvent, }; @@ -122,53 +116,6 @@ impl fmt::Display for LegacyToTemplatableMCPConversionError { } } -/// Convert an rmcp error to a user-friendly error message. -fn error_to_user_message(error: &rmcp::RmcpError) -> String { - match error { - rmcp::RmcpError::ClientInitialize(err) => { - format!("Failed to initialize client: {}", err) - } - rmcp::RmcpError::ServerInitialize(err) => { - format!("Failed to initialize server: {}", err) - } - rmcp::RmcpError::TransportCreation { error, .. } => { - format!("Failed to establish connection: {}", error) - } - rmcp::RmcpError::Runtime(err) => { - format!("Runtime error: {}", err) - } - rmcp::RmcpError::Service(err) => match err { - rmcp::ServiceError::McpError(_) => { - "Server returned an error. Please check server logs for details.".to_string() - } - rmcp::ServiceError::TransportSend(_) => { - "Failed to send data to server. Connection may have been lost.".to_string() - } - rmcp::ServiceError::TransportClosed => { - "Connection closed unexpectedly. The server may have crashed.".to_string() - } - rmcp::ServiceError::UnexpectedResponse => { - "Server sent an unexpected response. The server may be incompatible.".to_string() - } - rmcp::ServiceError::Cancelled { reason } => format!( - "Operation was cancelled with reason: {}", - reason.clone().unwrap_or("Unknown reason".to_string()) - ), - rmcp::ServiceError::Timeout { timeout } => { - format!( - "Connection timed out after {} seconds. The server may be unresponsive.", - timeout.as_secs() - ) - } - _ => format!("Service error: {}", err), - }, - // The enum is marked as non-exhaustive, so we need a catch-all. - _ => { - format!("Error: {error}") - } - } -} - /// An MCP server integration that Warp ships with bundled skills for. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum McpIntegration { @@ -183,6 +130,104 @@ impl TemplatableMCPServerManager { } } + /// Handles an incoming OAuth callback URL. + /// + /// Routes the callback to the correct in-flight OAuth flow using the `state` query + /// parameter that rmcp embedded in the authorization URL. + pub fn handle_oauth_callback(&mut self, url: &Url) -> anyhow::Result<()> { + if url.path() != "/oauth2callback" { + anyhow::bail!( + "Invalid OAuth callback path: expected '/oauth2callback', got '{}'", + url.path() + ); + } + + let query_params: HashMap<_, _> = url.query_pairs().collect(); + + let Some(state) = query_params.get("state") else { + anyhow::bail!("Missing 'state' parameter in OAuth callback"); + }; + + let code = query_params.get("code"); + let error = query_params.get("error"); + + let result = match code { + Some(code) => CallbackResult::Success { + code: code.to_string(), + // Pass the state value through as the CSRF token; rmcp will validate it + // against the token it stored when generating the authorization URL. + csrf_token: state.to_string(), + }, + None => CallbackResult::Error { + error: error.map(|e| e.to_string()), + }, + }; + + let Some(&server_uuid) = self.pending_oauth_csrf.get(state.as_ref() as &str) else { + anyhow::bail!("No active OAuth flow found for state={state}"); + }; + + let Some(server_info) = self.spawned_servers.get(&server_uuid) else { + anyhow::bail!("No spawned server found for uuid={server_uuid}"); + }; + + warpui::r#async::block_on(server_info.oauth_result_tx.send(result)).map_err(|_| { + anyhow::anyhow!( + "Failed to send OAuth result to server {server_uuid} - receiver dropped" + ) + })?; + + self.pending_oauth_csrf.remove(state.as_ref() as &str); + Ok(()) + } + + fn save_credentials_to_secure_storage( + &mut self, + app: &mut warpui::AppContext, + installation_uuid: Uuid, + credentials: PersistedCredentials, + ) { + if let Some(hash) = FileBasedMCPManager::as_ref(app).get_hash_by_uuid(installation_uuid) { + self.file_based_server_credentials.insert(hash, credentials); + write_to_secure_storage( + app, + FILE_BASED_MCP_CREDENTIALS_KEY, + &self.file_based_server_credentials, + ); + return; + } + + if let Some(template_uuid) = self.get_template_uuid(installation_uuid) { + self.server_credentials.insert(template_uuid, credentials); + write_to_secure_storage( + app, + TEMPLATABLE_MCP_CREDENTIALS_KEY, + &self.server_credentials, + ); + } else { + log::error!( + "Corresponding file or cloud-based server not found for installation UUID {installation_uuid}" + ); + } + } + + pub fn delete_credentials_from_secure_storage( + &mut self, + installation_uuid: Uuid, + app: &mut warpui::AppContext, + ) { + if let Some(template_uuid) = self.get_template_uuid(installation_uuid) { + self.server_credentials.remove(&template_uuid); + write_to_secure_storage( + app, + TEMPLATABLE_MCP_CREDENTIALS_KEY, + &self.server_credentials, + ); + } else { + log::error!("No template UUID found for installation UUID {installation_uuid}"); + } + } + /// Creates a new [`TemplatableMCPServerManager`] instance. pub fn new( locally_installed_servers: HashMap, @@ -844,17 +889,87 @@ impl TemplatableMCPServerManager { .get_hash_by_uuid(installation_uuid) .is_some(); - let auth_context = AuthContext { - oauth_result_rx, - spawner: ctx.spawner(), - uuid: installation_uuid, - persisted_credentials, - is_headless, - is_file_based, - }; - let server_name = server.name.clone(); let description = installation.templatable_mcp_server().description.clone(); + let auth_context = FeatureFlag::McpOauth.is_enabled().then(|| { + let persist_spawner = ctx.spawner(); + let requires_authentication_spawner = ctx.spawner(); + let authenticated_spawner = ctx.spawner(); + + AuthContext { + oauth_result_rx, + uuid: installation_uuid, + persisted_credentials, + is_headless, + is_file_based, + persist_credentials: Box::new(move |installation_uuid, credentials| { + let spawner = persist_spawner.clone(); + Box::pin(async move { + spawner + .spawn(move |manager, ctx| { + manager.save_credentials_to_secure_storage( + ctx, + installation_uuid, + credentials, + ); + }) + .await + .map_err(|err| { + anyhow::anyhow!( + "Failed to persist auto-refreshed MCP credentials: {err:?}" + ) + })?; + Ok(()) + }) + }), + requires_authentication: Box::new(move |uuid, csrf_state, auth_url| { + let spawner = requires_authentication_spawner.clone(); + Box::pin(async move { + spawner + .spawn(move |manager, ctx| { + if !csrf_state.is_empty() { + manager.pending_oauth_csrf.insert(csrf_state, uuid); + } + ctx.open_url(&auth_url); + manager.change_server_state(uuid, MCPServerState::Authenticating, ctx); + }) + .await + .map_err(|err| { + anyhow::anyhow!( + "Failed to emit MCP authentication-required state: {err:?}" + ) + })?; + Ok(()) + }) + }), + authenticated: Box::new(move |server_name| { + let spawner = authenticated_spawner.clone(); + Box::pin(async move { + spawner + .spawn(move |_, ctx| { + if let Some(active_window_id) = ctx.windows().active_window() { + ToastStack::handle(ctx).update(ctx, |stack, ctx| { + stack.add_ephemeral_toast( + DismissibleToast::default(format!( + "Successfully authenticated {server_name} MCP server" + )), + active_window_id, + ctx, + ); + }); + } + }) + .await + .map_err(|err| { + anyhow::anyhow!( + "Failed to emit MCP authenticated notification: {err:?}" + ) + })?; + Ok(()) + }) + }), + } + }); // Extract values from mode before moving it into the closure. let should_persist = mode.should_persist_running_state_to_sqlite(); @@ -878,7 +993,7 @@ impl TemplatableMCPServerManager { let error = match server_info { Ok(info) => { - let peer = info.service.clone(); + let peer = info.peer(); me.active_servers.insert(installation_uuid, info); // Clear any previous error message on successful connection. @@ -972,7 +1087,7 @@ impl TemplatableMCPServerManager { if let Some(server_info) = self.active_servers.remove(&installation_uuid) { self.change_server_state(installation_uuid, MCPServerState::ShuttingDown, ctx); // Cancel the server, and emit NotRunning state once it has stopped. - ctx.spawn(server_info.service.cancel(), move |me, _, ctx| { + ctx.spawn(server_info.shutdown(), move |me, _, ctx| { me.change_server_state(installation_uuid, MCPServerState::NotRunning, ctx); ctx.dispatch_global_action("workspace:save_app", ()); }); @@ -1554,13 +1669,7 @@ impl TemplatableMCPServerManager { ) -> Option> { self.active_servers .get(&installation_uuid) - .and_then(|server| { - if server.service.is_transport_closed() { - None - } else { - Some(server.service.clone()) - } - }) + .and_then(TemplatableMCPServerInfo::peer_if_connected) } /// Triggers reconnection of a server by its installation UUID. @@ -1644,7 +1753,7 @@ impl TemplatableMCPServerManager { let spawner = self.spawner.as_ref()?; self.active_servers .iter() - .find(|(_, server)| server.tools.iter().any(|t| t.name == tool_name)) + .find(|(_, server)| server.has_tool(&tool_name)) .map(|(installation_uuid, _)| { crate::ai::mcp::reconnecting_peer::ReconnectingPeer::new( *installation_uuid, @@ -1663,7 +1772,7 @@ impl TemplatableMCPServerManager { ) -> Option { let spawner = self.spawner.as_ref()?; let server = self.active_servers.get(&installation_id)?; - if server.tools.iter().any(|t| t.name == tool_name) { + if server.has_tool(&tool_name) { Some(crate::ai::mcp::reconnecting_peer::ReconnectingPeer::new( installation_id, spawner.clone(), @@ -1735,419 +1844,3 @@ impl TemplatableMCPServerManager { .contains_key(&installation_hash) } } - -type ReqwestHttpTransport = rmcp::transport::StreamableHttpClientTransport; -type ReqwestSseTransport = mcp::sse_transport::SseClientTransport; - -/// Spawns a new MCP server from a given [`TransportType`]. -async fn spawn_server( - server_name: String, - description: Option, - uuid: Uuid, - transport_type: TransportType, - logger: SimpleLogger, - auth_context: AuthContext, -) -> Result { - logger.log("[note] Attention! There may be sensitive information (such as API keys) in these logs. Make sure to redact any secrets before sharing with others.".to_string()); - - let mut is_authenticated_transport = false; - let service = match transport_type { - TransportType::CLIServer(cli_server) => { - logger.log("[info] MCP: Using stdio transport".to_string()); - - cfg_if! { - if #[cfg(windows)] { - // We wrap the command in cmd.exe /c to allow Windows to be responsible for resolving the - // PATH variable rather than depending on the `Command` implementation, which only looks for - // `.exe` files in directories found in PATH. - // https://github.com/rust-lang/rust/issues/37519 - let command = "cmd.exe".to_owned(); - let args = std::iter::once("/c".to_owned()) - .chain(std::iter::once(cli_server.command)) - .chain(cli_server.args) - .collect::>(); - } else { - let command = cli_server.command; - let args = cli_server.args; - } - } - - // Capture the command and configured cwd for diagnostics before they're - // moved into the Command builder closure. - let command_for_log = command.clone(); - let cwd_for_log = cli_server.cwd_parameter.clone(); - - // Try to spawn the child process. - let (transport, stderr) = rmcp::transport::TokioChildProcess::builder( - tokio::process::Command::new(command).configure(|cmd| { - cmd.args(args); - if let Some(cwd) = cli_server.cwd_parameter { - cmd.current_dir(cwd); - } - for StaticEnvVar { name, value } in cli_server.static_env_vars.iter() { - if value.is_empty() { - // Skip empty/unset environment variables so that, in the CLI, they can be inherited. - logger.log(format!( - "[warn] MCP: Skipping empty environment variable: {name}" - )); - continue; - } - cmd.env(name, value); - } - - // On Windows, ensure that no console window is shown. - #[cfg(windows)] - cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0); - }), - ) - .stderr(std::process::Stdio::piped()) - .spawn() - .map_err(|err| { - if err.kind() == std::io::ErrorKind::NotFound { - let cwd_display = cwd_for_log - .as_deref() - .unwrap_or(""); - logger.log(format!( - "[error] MCP: Failed to spawn '{server_name}': command '{command_for_log}' \ - not found (cwd: {cwd_display}). If your MCP server depends on a specific \ - working directory, set the `working_directory` field in your config to \ - override the default." - )); - } - rmcp::RmcpError::transport_creation::(err) - })?; - - let pid = transport - .id() - .map(|pid| pid.to_string()) - .unwrap_or("??".to_string()); - - // We always expect to have an stderr, but this is marginally safer than unwrapping. - if let Some(stderr) = stderr { - let logger = logger.clone(); - // Spawn a background task to forward from the child process's stderr to our logger. - tokio::spawn(async move { - let mut buf = String::new(); - let mut reader = tokio::io::BufReader::new(stderr); - loop { - match reader.read_line(&mut buf).await { - // EOF. - Ok(0) => return, - // Read some data. - Ok(_) => logger.log(format!("[info] MCP [pid: {pid}] stderr: {buf}")), - // Failed to read from the child process's stderr. - Err(e) => { - log::error!("Failed to read stderr: {e}"); - return; - } - } - } - }); - } - - // Wrap the transport in a logging wrapper. - let transport = TransportLoggingWrapper { - transport, - logger: logger.clone(), - }; - - // Create the MCP client and connect to the server. - Ok::<_, rmcp::RmcpError>(make_client_info().into_dyn().serve(transport).await?) - } - TransportType::ServerSentEvents(sse_server) => { - let headers: std::collections::HashMap = sse_server - .headers - .iter() - .map(|h| (h.name.clone(), h.value.clone())) - .collect(); - match determine_transport(server_name.clone(), &sse_server.url, &headers, auth_context) - .await - { - // TODO: these need headers also? - Ok(Transport::Http(Some(client))) => { - is_authenticated_transport = true; - - logger.log("[info] MCP: Using Streaming HTTP transport".to_string()); - let transport = rmcp::transport::StreamableHttpClientTransport::with_client( - client, - rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig::with_uri( - sse_server.url.clone(), - ), - ); - let transport = TransportLoggingWrapper { - transport, - logger: logger.clone(), - }; - Ok(make_client_info().into_dyn().serve(transport).await?) - } - Ok(Transport::Http(None)) => { - logger.log("[info] MCP: Using Streaming HTTP transport".to_string()); - let transport = if headers.is_empty() { - rmcp::transport::StreamableHttpClientTransport::from_uri( - sse_server.url.clone(), - ) - } else { - let client = build_client_with_headers(&headers)?; - rmcp::transport::StreamableHttpClientTransport::with_client( - client, - rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig::with_uri( - sse_server.url.clone(), - ), - ) - }; - let transport = TransportLoggingWrapper { - transport, - logger: logger.clone(), - }; - Ok(make_client_info().into_dyn().serve(transport).await?) - } - Ok(Transport::Sse(Some(client))) => { - is_authenticated_transport = true; - - logger.log("[info] MCP: Using (legacy) SSE transport (due to preflight failing with a 404)".to_string()); - let transport = mcp::sse_transport::SseClientTransport::start_with_client( - client, - mcp::sse_transport::SseClientConfig { - sse_endpoint: sse_server.url.into(), - ..Default::default() - }, - ) - .await - .map_err(rmcp::RmcpError::transport_creation::)?; - let transport = TransportLoggingWrapper { - transport, - logger: logger.clone(), - }; - Ok(make_client_info().into_dyn().serve(transport).await?) - } - Ok(Transport::Sse(None)) => { - logger.log("[info] MCP: Using (legacy) SSE transport (due to preflight failing with a 404)".to_string()); - let transport = if headers.is_empty() { - mcp::sse_transport::SseClientTransport::start(sse_server.url.clone()) - .await - .map_err(|e| { - rmcp::RmcpError::transport_creation::(e) - })? - } else { - let client = build_client_with_headers(&headers)?; - mcp::sse_transport::SseClientTransport::start_with_client( - client, - mcp::sse_transport::SseClientConfig { - sse_endpoint: sse_server.url.clone().into(), - ..Default::default() - }, - ) - .await - .map_err(rmcp::RmcpError::transport_creation::)? - }; - let transport = TransportLoggingWrapper { - transport, - logger: logger.clone(), - }; - Ok(make_client_info().into_dyn().serve(transport).await?) - } - Err(err) => { - logger.log(format!( - "[error] MCP: preflight connection to MCP server failed: {err:#}" - )); - Err(err)? - } - } - } - }?; - - let server_info = service.peer_info(); - logger.log(format!("[info] MCP: Connected to server: {server_info:#?}")); - - let capabilities = server_info.map(|info| &info.capabilities); - - let resources = - query_resources_for(capabilities, &server_name, || service.list_all_resources()).await; - let tools = query_tools_for(capabilities, &server_name, || service.list_all_tools()).await; - - Ok(TemplatableMCPServerInfo { - name: server_name, - service, - resources, - tools, - installation_id: uuid, - description, - is_authenticated_transport, - }) -} - -/// The transport to use for MCP. -enum Transport { - /// The HTTP transport, with an optional authenticated client. - Http(Option>), - /// The SSE transport, with an optional authenticated client. - Sse(Option>), -} - -/// Determines which transport to use. -/// -/// This sends a "preflight" InitializeRequest to the server to determine whether the -/// server supports the HTTP transport (or needs to use the SSE transport), and if -/// authentication is required. -async fn determine_transport( - server_name: String, - url: &str, - headers: &std::collections::HashMap, - auth_context: AuthContext, -) -> Result { - use reqwest::StatusCode; - - fn unexpected_error(status: reqwest::StatusCode) -> rmcp::RmcpError { - rmcp::RmcpError::transport_creation::(format!( - "Unexpected status code: {status}" - )) - } - match send_initialize_request(url, headers, None).await? { - StatusCode::OK => Ok(Transport::Http(None)), - StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => Ok(Transport::Sse(None)), - StatusCode::UNAUTHORIZED => { - if !FeatureFlag::McpOauth.is_enabled() { - return Err(rmcp::RmcpError::transport_creation::( - "Server requires authentication, which is not yet supported.".to_string(), - )); - } - - let spawner = auth_context.spawner.clone(); - // Go through the OAuth flow to get an authenticated client. - // This will first attempt to use cached credentials before starting interactive OAuth. - let (client, did_require_login) = oauth::make_authenticated_client(url, auth_context) - .boxed() - .await - .map_err(rmcp::RmcpError::transport_creation::)?; - let transport = match send_initialize_request(url, headers, Some(&client)).await? { - StatusCode::OK => Ok(Transport::Http(Some(client))), - StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => { - Ok(Transport::Sse(Some(client))) - } - other => Err(unexpected_error(other)), - }; - if transport.is_ok() && did_require_login { - let _ = spawner - .spawn(move |_, ctx| { - if let Some(active_window_id) = ctx.windows().active_window() { - ToastStack::handle(ctx).update(ctx, |stack, ctx| { - stack.add_ephemeral_toast( - DismissibleToast::default(format!( - "Successfully authenticated {server_name} MCP server" - )), - active_window_id, - ctx, - ); - }); - } - }) - .await; - } - - transport - } - status => Err(unexpected_error(status)), - } -} - -/// Sends an InitializeRequest to the server, and returns the HTTP status code from the response. -async fn send_initialize_request( - url: &str, - headers: &std::collections::HashMap, - auth_client: Option<&rmcp::transport::auth::AuthClient>, -) -> Result { - use rmcp::transport::common::http_header::{EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE}; - - let request = rmcp::model::InitializeRequest::new(make_client_info()); - let request = rmcp::model::ClientJsonRpcMessage::request( - rmcp::model::ClientRequest::InitializeRequest(request), - rmcp::model::RequestId::Number(0), - ); - - let mut request = build_client_with_headers(headers)? - .post(url) - .header( - http::header::ACCEPT, - [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "), - ) - .json(&request); - - if let Some(auth_client) = auth_client.as_ref() { - let access_token = auth_client - .get_access_token() - .await - .map_err(rmcp::RmcpError::transport_creation::)?; - request = request.bearer_auth(access_token); - } - - let response = request - .send() - .await - .map_err(rmcp::RmcpError::transport_creation::)?; - - Ok(response.status()) -} - -/// Creates a [`ClientInfo`] for the MCP client. -/// -/// This tells the MCP server who we are and what capabilities we have. -fn make_client_info() -> rmcp::model::ClientInfo { - rmcp::model::ClientInfo::new( - Default::default(), - rmcp::model::Implementation::new( - warp_core::channel::ChannelState::app_id().to_string(), - warp_core::channel::ChannelState::app_version() - .map(|v| v.to_string()) - .unwrap_or_default(), - ), - ) -} - -/// A wrapper around a [`rmcp::transport::Transport`] that logs all requests and responses. -struct TransportLoggingWrapper { - transport: T, - logger: SimpleLogger, -} - -impl, R: rmcp::service::ServiceRole> rmcp::transport::Transport - for TransportLoggingWrapper -{ - type Error = T::Error; - - fn send( - &mut self, - item: rmcp::service::TxJsonRpcMessage, - ) -> impl Future> + Send + 'static { - if let Ok(json) = serde_json::to_string(&item) { - self.logger - .log(format!("[info] MCP: Sending request: {json}")); - } - - let logger = self.logger.clone(); - self.transport.send(item).map(move |result| { - if let Err(e) = &result { - logger.log(format!("[warn] MCP: Failed to send request: {e:#}")); - } - result - }) - } - - fn receive( - &mut self, - ) -> impl Future>> + Send { - let logger = self.logger.clone(); - async move { - let result = self.transport.receive().await; - if let Some(item) = &result { - if let Ok(json) = serde_json::to_string(item) { - logger.log(format!("[info] MCP: Received response: {json}")); - } - } - result - } - } - - fn close(&mut self) -> impl Future> + Send { - self.transport.close() - } -} diff --git a/app/src/ai/mcp/templatable_manager/utils.rs b/app/src/ai/mcp/templatable_manager/utils.rs deleted file mode 100644 index 155e5b4852..0000000000 --- a/app/src/ai/mcp/templatable_manager/utils.rs +++ /dev/null @@ -1,79 +0,0 @@ -//! Capability-gating helpers used during MCP server startup. -//! -//! Each `query_*_for` function pairs a capability check with the actual list -//! call from rmcp, gating the call on advertisement and failing soft on errors. -//! They take the list call as a closure so unit tests can drive the gate-and- -//! fail-soft control flow with a fake `RunningService` substitute. - -/// Whether to query `resources/list` for a server with the given capabilities. -/// -/// Per the MCP spec, the client should only invoke a list method when the server -/// has advertised the corresponding capability during initialization. -pub(super) fn should_query_resources( - capabilities: Option<&rmcp::model::ServerCapabilities>, -) -> bool { - capabilities.is_some_and(|c| c.resources.is_some()) -} - -/// Whether to query `tools/list` for a server with the given capabilities. -/// -/// Per the MCP spec, the client should only invoke a list method when the server -/// has advertised the corresponding capability during initialization. -pub(super) fn should_query_tools(capabilities: Option<&rmcp::model::ServerCapabilities>) -> bool { - capabilities.is_some_and(|c| c.tools.is_some()) -} - -/// Query `resources/list` for a connected MCP server. -/// -/// Skips the call entirely when `resources` was not advertised. Treats any -/// listing error as "no resources" (fail-soft) so a flaky `resources/list` -/// does not abort the entire server startup. Mirrors the behavior of -/// [`query_tools_for`] so the two capabilities are handled symmetrically. -pub(super) async fn query_resources_for( - capabilities: Option<&rmcp::model::ServerCapabilities>, - server_name: &str, - list_resources: F, -) -> Vec -where - F: FnOnce() -> Fut, - Fut: std::future::Future, rmcp::ServiceError>>, -{ - if !should_query_resources(capabilities) { - return Vec::new(); - } - match list_resources().await { - Ok(result) => result, - Err(err) => { - log::warn!("Failed to list resources for MCP server '{server_name}': {err}"); - Vec::new() - } - } -} - -/// Query `tools/list` for a connected MCP server. -/// -/// Skips the call entirely when `tools` was not advertised. Treats any listing -/// error as "no tools" (fail-soft) so a transient `tools/list` failure does -/// not abort the entire server startup — the user-visible regression #6798 -/// was rooted in the prior asymmetric handling, where a tools-list error on -/// a server with healthy resources would propagate and fail startup. -pub(super) async fn query_tools_for( - capabilities: Option<&rmcp::model::ServerCapabilities>, - server_name: &str, - list_tools: F, -) -> Vec -where - F: FnOnce() -> Fut, - Fut: std::future::Future, rmcp::ServiceError>>, -{ - if !should_query_tools(capabilities) { - return Vec::new(); - } - match list_tools().await { - Ok(result) => result, - Err(err) => { - log::warn!("Failed to list tools for MCP server '{server_name}': {err}"); - Vec::new() - } - } -} diff --git a/app/src/ai/mcp/templatable_manager/utils_tests.rs b/app/src/ai/mcp/templatable_manager/utils_tests.rs deleted file mode 100644 index 047056728e..0000000000 --- a/app/src/ai/mcp/templatable_manager/utils_tests.rs +++ /dev/null @@ -1,309 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::ai::mcp::templatable_manager::utils::{ - query_resources_for, query_tools_for, should_query_resources, should_query_tools, - }; - use rmcp::model::{ErrorCode, ErrorData, Resource, ServerCapabilities, Tool}; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Arc; - - /// Build a `ServerCapabilities` with selected capability flags toggled on. - /// Each `Some(default)` mirrors how rmcp deserializes a capability the - /// server advertised with no inner flags set. - fn caps(tools: bool, resources: bool) -> ServerCapabilities { - match (tools, resources) { - (true, true) => ServerCapabilities::builder() - .enable_tools() - .enable_resources() - .build(), - (true, false) => ServerCapabilities::builder().enable_tools().build(), - (false, true) => ServerCapabilities::builder().enable_resources().build(), - (false, false) => ServerCapabilities::builder().build(), - } - } - - fn test_tool(name: &str) -> Tool { - serde_json::from_value(serde_json::json!({ - "name": name, - "description": "test tool", - "inputSchema": { "type": "object" }, - })) - .expect("Tool deserialization") - } - - fn test_resource(uri: &str) -> Resource { - serde_json::from_value(serde_json::json!({ - "uri": uri, - "name": "test resource", - })) - .expect("Resource deserialization") - } - - // ---------- predicate-level tests ---------- - - /// Regression test for warpdotdev/warp#6798: each capability is queried - /// independently. Previously, asymmetric handling could cause `tools/list` - /// to be skipped when a server advertised both `tools` and `resources`, - /// resulting in "No tools available" even though the server had tools. - #[test] - fn each_capability_is_queried_independently() { - for has_tools in [false, true] { - for has_resources in [false, true] { - let c = caps(has_tools, has_resources); - assert_eq!( - should_query_tools(Some(&c)), - has_tools, - "tools={has_tools}, resources={has_resources}", - ); - assert_eq!( - should_query_resources(Some(&c)), - has_resources, - "tools={has_tools}, resources={has_resources}", - ); - } - } - assert!(!should_query_tools(None)); - assert!(!should_query_resources(None)); - } - - // ---------- query_tools_for control-flow tests ---------- - - /// When `tools` is not advertised, the helper must skip the list call so - /// we don't waste a round trip and pollute the wire log with a request - /// that's destined to return `METHOD_NOT_FOUND`. - #[tokio::test] - async fn query_tools_for_skips_listing_when_capability_not_advertised() { - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - let no_caps = caps(false, false); - - let result = query_tools_for(Some(&no_caps), "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(vec![test_tool("never")]) - }) - .await; - - assert!(result.is_empty()); - assert_eq!( - calls.load(Ordering::SeqCst), - 0, - "list function must not be called when tools capability is absent", - ); - } - - /// `None` server info follows the same skip-listing path as "no capability". - #[tokio::test] - async fn query_tools_for_skips_listing_when_server_info_is_none() { - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - - let result = query_tools_for(None, "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(vec![test_tool("never")]) - }) - .await; - - assert!(result.is_empty()); - assert_eq!(calls.load(Ordering::SeqCst), 0); - } - - /// Happy path: tools advertised, list call succeeds, tools returned. - #[tokio::test] - async fn query_tools_for_returns_listed_tools_when_capability_advertised() { - let c = caps(true, false); - let expected = vec![test_tool("greet"), test_tool("review")]; - let to_return = expected.clone(); - - let result = query_tools_for(Some(&c), "srv", || async move { Ok(to_return) }).await; - - assert_eq!(result, expected); - } - - /// `tools` advertised but server returns an empty list — distinct from the - /// "skipped" case in that we still made the call. - #[tokio::test] - async fn query_tools_for_returns_empty_vec_when_server_lists_no_tools() { - let c = caps(true, false); - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - - let result = query_tools_for(Some(&c), "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(Vec::new()) - }) - .await; - - assert!(result.is_empty()); - assert_eq!( - calls.load(Ordering::SeqCst), - 1, - "list function still called when capability is advertised", - ); - } - - /// **The fail-soft test the bug ticket implicitly demands.** Transport- - /// closed errors must not abort server startup; the helper must log and - /// return an empty vec. This is the regression-protector for #6798's - /// underlying asymmetry — if anyone re-introduces a `return Err(...)` here, - /// this test fails. - #[tokio::test] - async fn query_tools_for_returns_empty_on_transport_error() { - let c = caps(true, false); - let result = query_tools_for(Some(&c), "srv", || async { - Err(rmcp::ServiceError::TransportClosed) - }) - .await; - assert!(result.is_empty()); - } - - /// MCP-protocol errors (e.g. METHOD_NOT_FOUND from a misbehaving server - /// that advertised the capability but rejects the call) also fail soft, - /// so the rest of the server surface still comes up. - #[tokio::test] - async fn query_tools_for_returns_empty_on_mcp_error() { - let c = caps(true, false); - let result = query_tools_for(Some(&c), "srv", || async { - Err(rmcp::ServiceError::McpError(ErrorData { - code: ErrorCode::METHOD_NOT_FOUND, - message: "tools/list not implemented".into(), - data: None, - })) - }) - .await; - assert!(result.is_empty()); - } - - /// The list function must be called exactly once per query — not zero - /// (that would be the skip path) and not multiple times (no implicit - /// retry inside the helper). - #[tokio::test] - async fn query_tools_for_calls_list_function_exactly_once() { - let c = caps(true, false); - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - - let _ = query_tools_for(Some(&c), "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(vec![test_tool("p")]) - }) - .await; - - assert_eq!(calls.load(Ordering::SeqCst), 1); - } - - /// Independence: tools-listing decision must not depend on whether - /// resources are also advertised. Run the full happy-path flow under - /// every (tools, resources) combination and assert tools come back iff - /// the tools capability is advertised. - #[tokio::test] - async fn query_tools_for_decision_independent_of_other_capabilities() { - let tools = vec![test_tool("x")]; - for has_tools in [false, true] { - for has_resources in [false, true] { - let c = caps(has_tools, has_resources); - let to_return = tools.clone(); - let result = - query_tools_for(Some(&c), "srv", || async move { Ok(to_return) }).await; - - if has_tools { - assert_eq!( - result, tools, - "expected tools when advertised \ - (tools={has_tools}, resources={has_resources})", - ); - } else { - assert!( - result.is_empty(), - "expected empty when tools not advertised \ - (tools={has_tools}, resources={has_resources})", - ); - } - } - } - } - - // ---------- query_resources_for control-flow tests ---------- - - #[tokio::test] - async fn query_resources_for_skips_listing_when_capability_not_advertised() { - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - let no_caps = caps(false, false); - - let result = query_resources_for(Some(&no_caps), "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(vec![test_resource("file:///nope")]) - }) - .await; - - assert!(result.is_empty()); - assert_eq!(calls.load(Ordering::SeqCst), 0); - } - - #[tokio::test] - async fn query_resources_for_skips_listing_when_server_info_is_none() { - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - - let result = query_resources_for(None, "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(vec![test_resource("file:///nope")]) - }) - .await; - - assert!(result.is_empty()); - assert_eq!(calls.load(Ordering::SeqCst), 0); - } - - #[tokio::test] - async fn query_resources_for_returns_listed_resources_when_capability_advertised() { - let c = caps(false, true); - let expected = vec![test_resource("file:///a"), test_resource("file:///b")]; - let to_return = expected.clone(); - - let result = query_resources_for(Some(&c), "srv", || async move { Ok(to_return) }).await; - - assert_eq!(result, expected); - } - - /// Fail-soft on transport errors — same contract as `query_tools_for`, - /// matching the existing behavior the predicate refactor preserved. - #[tokio::test] - async fn query_resources_for_returns_empty_on_transport_error() { - let c = caps(false, true); - let result = query_resources_for(Some(&c), "srv", || async { - Err(rmcp::ServiceError::TransportClosed) - }) - .await; - assert!(result.is_empty()); - } - - #[tokio::test] - async fn query_resources_for_returns_empty_on_mcp_error() { - let c = caps(false, true); - let result = query_resources_for(Some(&c), "srv", || async { - Err(rmcp::ServiceError::McpError(ErrorData { - code: ErrorCode::METHOD_NOT_FOUND, - message: "resources/list not implemented".into(), - data: None, - })) - }) - .await; - assert!(result.is_empty()); - } - - #[tokio::test] - async fn query_resources_for_calls_list_function_exactly_once() { - let c = caps(false, true); - let calls = Arc::new(AtomicUsize::new(0)); - let calls_clone = calls.clone(); - - let _ = query_resources_for(Some(&c), "srv", || async move { - calls_clone.fetch_add(1, Ordering::SeqCst); - Ok(vec![test_resource("file:///a")]) - }) - .await; - - assert_eq!(calls.load(Ordering::SeqCst), 1); - } -} diff --git a/app/src/ai/skills/file_watchers/skill_watcher.rs b/app/src/ai/skills/file_watchers/skill_watcher.rs index ae815ba1fb..d115ecaa49 100644 --- a/app/src/ai/skills/file_watchers/skill_watcher.rs +++ b/app/src/ai/skills/file_watchers/skill_watcher.rs @@ -14,7 +14,6 @@ use super::{ }; use watcher::{BulkFilesystemWatcherEvent, HomeDirectoryWatcher, HomeDirectoryWatcherEvent}; -use crate::server::datetime_ext::DateTimeExt; use crate::warp_managed_paths_watcher::{ filter_repository_update_by_prefix, warp_managed_skill_dirs, WarpManagedPathsWatcher, WarpManagedPathsWatcherEvent, @@ -29,6 +28,7 @@ use repo_metadata::{ repository::{Repository, SubscriberId}, DirectoryWatcher, RepoMetadataModel, RepositoryUpdate, }; +use warp_core::datetime_ext::DateTimeExt; use warpui::{AppContext, Entity, ModelContext, ModelHandle, SingletonEntity}; #[derive(Debug, PartialEq)] diff --git a/app/src/auth/user.rs b/app/src/auth/user.rs index 77f568c9c2..20d8eacc99 100644 --- a/app/src/auth/user.rs +++ b/app/src/auth/user.rs @@ -1,7 +1,7 @@ -use crate::server::datetime_ext::DateTimeExt; use anyhow::{anyhow, Result}; use chrono::{DateTime, FixedOffset}; use serde::{Deserialize, Serialize}; +use warp_core::datetime_ext::DateTimeExt; use warp_graphql::{queries::get_user::FirebaseProfile, scalars::time::ServerTimestamp}; use super::UserUid; diff --git a/app/src/autoupdate/mod.rs b/app/src/autoupdate/mod.rs index 6b23dcfb31..af0f5fa781 100644 --- a/app/src/autoupdate/mod.rs +++ b/app/src/autoupdate/mod.rs @@ -12,10 +12,7 @@ use crate::send_telemetry_sync_from_app_ctx; use crate::server::server_api::ServerApi; use crate::server::telemetry::TelemetryEvent; use crate::workspace::Workspace; -use crate::{ - channel::Channel, report_if_error, send_telemetry_from_ctx, server::datetime_ext::DateTimeExt, - ChannelState, -}; +use crate::{channel::Channel, report_if_error, send_telemetry_from_ctx, ChannelState}; use ::channel_versions::{ParsedVersion, VersionInfo}; use anyhow::{anyhow, Context as _, Result}; use chrono::{DateTime, FixedOffset, NaiveDate}; @@ -23,6 +20,7 @@ use rand::Rng as _; use std::collections::VecDeque; use std::sync::Arc; use std::time::Duration; +use warp_core::datetime_ext::DateTimeExt; use warp_core::execution_mode::AppExecutionMode; use warpui::platform::TerminationMode; use warpui::r#async::Timer; diff --git a/app/src/server/mod.rs b/app/src/server/mod.rs index 4369bec89b..29533c0207 100644 --- a/app/src/server/mod.rs +++ b/app/src/server/mod.rs @@ -1,6 +1,5 @@ pub mod block; pub mod cloud_objects; -pub mod datetime_ext; pub mod experiments; pub mod graphql; pub mod ids; diff --git a/app/src/server/network_logging.rs b/app/src/server/network_logging.rs index 6f9a68c2b6..3a500b783b 100644 --- a/app/src/server/network_logging.rs +++ b/app/src/server/network_logging.rs @@ -5,8 +5,8 @@ use chrono::{DateTime, FixedOffset}; use enclose::enclose; use warpui::{Entity, ModelContext, SingletonEntity}; -use crate::server::datetime_ext::DateTimeExt; use crate::server::server_api::ServerApiProvider; +use warp_core::datetime_ext::DateTimeExt; /// Maximum number of network log items retained in memory. Matches the /// previous file-rotation threshold so the pane surface behaves consistently diff --git a/app/src/server/server_api/auth.rs b/app/src/server/server_api/auth.rs index 0e85d8c73f..373b70995e 100644 --- a/app/src/server/server_api/auth.rs +++ b/app/src/server/server_api/auth.rs @@ -10,6 +10,7 @@ use instant::Duration; use mockall::{automock, predicate::*}; use oauth2::TokenResponse; use thiserror::Error; +use warp_core::datetime_ext::DateTimeExt as _; use warp_core::errors::{AnyhowErrorExt, ErrorExt}; use warp_graphql::client::Operation; use warp_graphql::mutations::expire_api_key::{ @@ -59,8 +60,7 @@ use crate::{ channel::ChannelState, convert_to_server_experiment, server::{ - datetime_ext::DateTimeExt as _, experiments::ServerExperiment, - graphql::get_request_context, server_api::ServerApiEvent, + experiments::ServerExperiment, graphql::get_request_context, server_api::ServerApiEvent, }, }; diff --git a/app/src/terminal/model/session/command_executor/in_band_command_executor.rs b/app/src/terminal/model/session/command_executor/in_band_command_executor.rs index e34a1a0ab1..cffb881037 100644 --- a/app/src/terminal/model/session/command_executor/in_band_command_executor.rs +++ b/app/src/terminal/model/session/command_executor/in_band_command_executor.rs @@ -14,9 +14,9 @@ use warp_terminal::model::Point; use warpui::r#async::block_on; use crate::safe_info; -use crate::server::datetime_ext::DateTimeExt; use crate::terminal::event::ExecutedExecutorCommandEvent; use crate::terminal::shell::{Shell, ShellType}; +use warp_core::datetime_ext::DateTimeExt; use warp_util::on_cancel::OnCancelFutureExt; use crate::terminal::model::session::command_executor::{ diff --git a/app/src/terminal/model/session/command_executor/tmux_executor.rs b/app/src/terminal/model/session/command_executor/tmux_executor.rs index 3b0a7ec7d6..3225e3af5d 100644 --- a/app/src/terminal/model/session/command_executor/tmux_executor.rs +++ b/app/src/terminal/model/session/command_executor/tmux_executor.rs @@ -10,10 +10,10 @@ use chrono::DateTime; use parking_lot::Mutex; use super::{ExecuteCommandOptions, ExecutorCommandEvent}; -use crate::server::datetime_ext::DateTimeExt; use crate::terminal::event::ExecutedExecutorCommandEvent; use crate::terminal::model::tmux::commands::TmuxCommand; use crate::terminal::shell::Shell; +use warp_core::datetime_ext::DateTimeExt; use super::CommandExecutor; use warp_completer::completer::{CommandExitStatus, CommandOutput}; diff --git a/app/src/workspace/view/vertical_tabs.rs b/app/src/workspace/view/vertical_tabs.rs index 31bd035c27..d0271022f1 100644 --- a/app/src/workspace/view/vertical_tabs.rs +++ b/app/src/workspace/view/vertical_tabs.rs @@ -4,8 +4,8 @@ use crate::ai::agent::conversation::{ConversationStatus, StatusColorStyle}; use crate::ai::agent_management::AgentNotificationsModel; use crate::ai::cloud_environments::CloudAmbientAgentEnvironment; use crate::ai::conversation_status_ui::render_status_element; -use crate::cloud_object::CloudObjectLookup as _; use crate::cloud_object::model::generic_string_model::StringModel; +use crate::cloud_object::CloudObjectLookup as _; use crate::code::editor::{add_color, remove_color}; use crate::code::icon_from_file_path; use crate::safe_triangle::SafeTriangle; diff --git a/crates/mcp/Cargo.toml b/crates/mcp/Cargo.toml index 2c186e9fa4..801b9ba9b4 100644 --- a/crates/mcp/Cargo.toml +++ b/crates/mcp/Cargo.toml @@ -6,13 +6,30 @@ publish.workspace = true license.workspace = true [dependencies] +anyhow.workspace = true +async-channel.workspace = true +async-trait.workspace = true +cfg-if.workspace = true +cloud_object_models.workspace = true futures.workspace = true http = "1" +log.workspace = true +oauth2.workspace = true pin-project-lite = "0.2" reqwest = { workspace = true, features = ["json", "stream"] } -rmcp = { workspace = true, features = ["client", "client-side-sse", "auth"] } +rmcp = { workspace = true, features = ["client", "client-side-sse", "auth", "transport-streamable-http-client-reqwest", "transport-child-process"] } +serde.workspace = true serde_json.workspace = true +simple_logger.workspace = true sse-stream = "0.2" thiserror.workspace = true -tokio = { workspace = true, features = ["sync", "macros", "rt", "time"] } +tokio = { workspace = true, features = ["sync", "macros", "rt", "time", "process", "io-util"] } tracing.workspace = true +url.workspace = true +uuid.workspace = true +warp_core.workspace = true +warpui.workspace = true +warpui_extras = { workspace = true, features = ["default", "user_preferences-toml"] } + +[target.'cfg(windows)'.dependencies] +windows = { workspace = true, features = ["Win32_System_Threading"] } diff --git a/crates/mcp/src/lib.rs b/crates/mcp/src/lib.rs index e9d2272fb5..65e43342f8 100644 --- a/crates/mcp/src/lib.rs +++ b/crates/mcp/src/lib.rs @@ -1 +1,3 @@ +pub mod oauth; +pub mod runtime; pub mod sse_transport; diff --git a/app/src/ai/mcp/templatable_manager/oauth.rs b/crates/mcp/src/oauth.rs similarity index 71% rename from app/src/ai/mcp/templatable_manager/oauth.rs rename to crates/mcp/src/oauth.rs index 5b3c0f3b29..c236646e2f 100644 --- a/app/src/ai/mcp/templatable_manager/oauth.rs +++ b/crates/mcp/src/oauth.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use anyhow::{anyhow, bail}; +use futures::future::BoxFuture; use oauth2::{RefreshToken, TokenResponse as _}; use rmcp::transport::{ auth::{ @@ -14,14 +14,10 @@ use serde::{Deserialize, Serialize}; use url::Url; use uuid::Uuid; use warp_core::channel::ChannelState; -use warpui::ModelSpawner; use warpui_extras::secure_storage::AppContextExt as _; -use super::{MCPServerState, TemplatableMCPServerManager}; -use {crate::ai::mcp::FileBasedMCPManager, warpui::SingletonEntity}; - -pub(crate) const TEMPLATABLE_MCP_CREDENTIALS_KEY: &str = "TemplatableMcpCredentials"; -pub(crate) const FILE_BASED_MCP_CREDENTIALS_KEY: &str = "FileBasedMcpCredentials"; +pub const TEMPLATABLE_MCP_CREDENTIALS_KEY: &str = "TemplatableMcpCredentials"; +pub const FILE_BASED_MCP_CREDENTIALS_KEY: &str = "FileBasedMcpCredentials"; /// The issuer URL for GitHub's OAuth provider. const GITHUB_ISSUER: &str = "https://github.com/login/oauth"; @@ -53,6 +49,12 @@ pub type PersistedCredentialsMap = HashMap; // Maps a consistent hash of the installation to its persisted credentials pub type FileBasedPersistedCredentialsMap = HashMap; +pub type PersistCredentialsCallback = + Box BoxFuture<'static, anyhow::Result<()>> + Send>; +pub type RequiresAuthenticationCallback = + Box BoxFuture<'static, anyhow::Result<()>> + Send>; +pub type AuthenticatedCallback = + Box BoxFuture<'static, anyhow::Result<()>> + Send>; /// A credential store that wraps [`InMemoryCredentialStore`] and persists token /// updates to Warp's secure storage via a channel. @@ -134,7 +136,7 @@ impl CredentialStore for PersistingCredentialStore { /// runtime token auto-refreshes are written back to Warp's secure storage. /// /// A background tokio task is spawned to receive credential updates and persist -/// them via the [`ModelSpawner`]. The task terminates when the auth manager (and +/// them via the provided callback. The task terminates when the auth manager (and /// thus the credential store's sender) is dropped. /// /// Note: this store is not responsible for the initial population of credentials. @@ -144,7 +146,7 @@ impl CredentialStore for PersistingCredentialStore { async fn install_persisting_credential_store( auth_manager: &mut AuthorizationManager, persisted_credentials: Option, - spawner: ModelSpawner, + persist_credentials: PersistCredentialsCallback, installation_uuid: Uuid, ) { let client_secret = persisted_credentials @@ -168,13 +170,8 @@ async fn install_persisting_credential_store( tokio::spawn(async move { while let Ok(credentials) = persist_rx.recv().await { - if let Err(e) = spawner - .spawn(move |manager, ctx| { - manager.save_credentials_to_secure_storage(ctx, installation_uuid, credentials); - }) - .await - { - log::warn!("Failed to persist auto-refreshed MCP credentials: {e:?}"); + if let Err(err) = persist_credentials(installation_uuid, credentials).await { + log::warn!("Failed to persist auto-refreshed MCP credentials: {err:?}"); } } }); @@ -183,13 +180,15 @@ async fn install_persisting_credential_store( /// Context for OAuth authentication flows. pub struct AuthContext { pub oauth_result_rx: async_channel::Receiver, - pub spawner: ModelSpawner, pub uuid: Uuid, pub persisted_credentials: Option, /// Whether the client is running in headless/CLI mode. pub is_headless: bool, /// Whether this server was auto-discovered from a repo MCP configuration file. pub is_file_based: bool, + pub persist_credentials: PersistCredentialsCallback, + pub requires_authentication: RequiresAuthenticationCallback, + pub authenticated: AuthenticatedCallback, } /// Result of OAuth callback. @@ -204,19 +203,20 @@ pub enum CallbackResult { /// This takes in the URL of the resource to authenticate for, and uses that /// to determine the authorization server. /// -/// Upon success, returns the client and a boolean indicating whether the user was required to -/// re-authenticate (e.g. re-log in). pub async fn make_authenticated_client( + server_name: &str, resource_url: &str, auth_context: AuthContext, -) -> Result<(AuthClient, bool), AuthError> { +) -> Result, AuthError> { let AuthContext { oauth_result_rx, - spawner, uuid, persisted_credentials, is_headless, is_file_based, + persist_credentials, + requires_authentication, + authenticated, } = auth_context; // Build the redirect URI using the channel's URL scheme. @@ -237,7 +237,7 @@ pub async fn make_authenticated_client( install_persisting_credential_store( &mut auth_manager, persisted_credentials, - spawner.clone(), + persist_credentials, uuid, ) .await; @@ -253,7 +253,7 @@ pub async fn make_authenticated_client( .with_client_secret(client_secret), )?; } - return Ok((AuthClient::new(reqwest::Client::new(), auth_manager), false)); + return Ok(AuthClient::new(reqwest::Client::new(), auth_manager)); } // If we're in headless mode and we reach here, it means we either have no credentials @@ -331,17 +331,8 @@ pub async fn make_authenticated_client( }) .unwrap_or_default(); - if let Err(e) = spawner - .spawn(move |manager, ctx| { - if !csrf_state.is_empty() { - manager.pending_oauth_csrf.insert(csrf_state, uuid); - } - ctx.open_url(&auth_url); - manager.change_server_state(uuid, MCPServerState::Authenticating, ctx); - }) - .await - { - log::warn!("Failed to emit RequiresAuthentication state: {e:?}"); + if let Err(err) = requires_authentication(uuid, csrf_state, auth_url).await { + log::warn!("Failed to emit RequiresAuthentication state: {err:?}"); } // Wait for the authorization code from the OAuth callback channel. @@ -366,111 +357,15 @@ pub async fn make_authenticated_client( AuthError::InternalError("Failed to create authorization manager".to_string()) })?; - Ok((AuthClient::new(reqwest::Client::new(), auth_manager), true)) -} - -impl TemplatableMCPServerManager { - /// Handles an incoming OAuth callback URL. - /// - /// Routes the callback to the correct in-flight OAuth flow using the `state` query - /// parameter (the CSRF token that rmcp embedded in the authorization URL). This avoids - /// encoding routing data in the redirect URI, keeping it RFC 6749 §3.1.2.2 compliant. - pub fn handle_oauth_callback(&mut self, url: &Url) -> anyhow::Result<()> { - // Ensure the URL has the expected path - if url.path() != "/oauth2callback" { - bail!( - "Invalid OAuth callback path: expected '/oauth2callback', got '{}'", - url.path() - ); - } - - let query_params: HashMap<_, _> = url.query_pairs().collect(); - - let Some(state) = query_params.get("state") else { - bail!("Missing 'state' parameter in OAuth callback"); - }; - - let code = query_params.get("code"); - let error = query_params.get("error"); - - let result = match code { - Some(code) => CallbackResult::Success { - code: code.to_string(), - // Pass the state value through as the CSRF token; rmcp will validate it - // against the token it stored when generating the authorization URL. - csrf_token: state.to_string(), - }, - None => CallbackResult::Error { - error: error.map(|e| e.to_string()), - }, - }; - - let Some(&server_uuid) = self.pending_oauth_csrf.get(state.as_ref() as &str) else { - bail!("No active OAuth flow found for state={state}"); - }; - - let Some(server_info) = self.spawned_servers.get(&server_uuid) else { - bail!("No spawned server found for uuid={server_uuid}"); - }; - - warpui::r#async::block_on(server_info.oauth_result_tx.send(result)).map_err(|_| { - anyhow!("Failed to send OAuth result to server {server_uuid} - receiver dropped") - })?; - - self.pending_oauth_csrf.remove(state.as_ref() as &str); - Ok(()) - } - - pub fn save_credentials_to_secure_storage( - &mut self, - app: &mut warpui::AppContext, - installation_uuid: Uuid, - credentials: PersistedCredentials, - ) { - if let Some(hash) = FileBasedMCPManager::as_ref(app).get_hash_by_uuid(installation_uuid) { - self.file_based_server_credentials.insert(hash, credentials); - write_to_secure_storage( - app, - FILE_BASED_MCP_CREDENTIALS_KEY, - &self.file_based_server_credentials, - ); - return; - } - - if let Some(template_uuid) = self.get_template_uuid(installation_uuid) { - self.server_credentials.insert(template_uuid, credentials); - write_to_secure_storage( - app, - TEMPLATABLE_MCP_CREDENTIALS_KEY, - &self.server_credentials, - ); - } else { - log::error!( - "Corresponding file or cloud-based server not found for installation UUID {installation_uuid}" - ); - } + if let Err(err) = authenticated(server_name.to_string()).await { + log::warn!("Failed to emit MCP authenticated notification: {err:?}"); } - pub fn delete_credentials_from_secure_storage( - &mut self, - installation_uuid: Uuid, - app: &mut warpui::AppContext, - ) { - if let Some(template_uuid) = self.get_template_uuid(installation_uuid) { - self.server_credentials.remove(&template_uuid); - write_to_secure_storage( - app, - TEMPLATABLE_MCP_CREDENTIALS_KEY, - &self.server_credentials, - ); - } else { - log::error!("No template UUID found for installation UUID {installation_uuid}"); - } - } + Ok(AuthClient::new(reqwest::Client::new(), auth_manager)) } /// Loads credentials from secure storage at the provided key. -pub(crate) fn load_credentials_from_secure_storage( +pub fn load_credentials_from_secure_storage( app: &mut warpui::AppContext, key: &str, ) -> T { @@ -487,7 +382,7 @@ pub(crate) fn load_credentials_from_secure_storage( +pub fn write_to_secure_storage( app: &mut warpui::AppContext, key: &str, credentials: &T, diff --git a/app/src/ai/mcp/templatable_manager/oauth_tests.rs b/crates/mcp/src/oauth_tests.rs similarity index 100% rename from app/src/ai/mcp/templatable_manager/oauth_tests.rs rename to crates/mcp/src/oauth_tests.rs diff --git a/crates/mcp/src/runtime.rs b/crates/mcp/src/runtime.rs new file mode 100644 index 0000000000..8bc66dd99e --- /dev/null +++ b/crates/mcp/src/runtime.rs @@ -0,0 +1,909 @@ +use std::collections::HashMap; +use std::future::Future; + +use cfg_if::cfg_if; +use cloud_object_models::{StaticEnvVar, TransportType}; +use futures::FutureExt as _; +use rmcp::{transport::ConfigureCommandExt as _, ServiceExt as _}; +use simple_logger::SimpleLogger; +use tokio::io::AsyncBufReadExt as _; +use uuid::Uuid; + +type ReqwestHttpTransport = rmcp::transport::StreamableHttpClientTransport; +type ReqwestSseTransport = crate::sse_transport::SseClientTransport; + +/// Information about a single connected MCP server. +#[cfg_attr(target_family = "wasm", allow(dead_code))] +pub struct TemplatableMCPServerInfo { + name: String, + service: rmcp::service::RunningService< + rmcp::RoleClient, + Box>, + >, + resources: Vec, + tools: Vec, + installation_id: Uuid, + description: Option, + /// Whether the underlying transport uses authentication. + /// + /// TODO(vorporeal): Use this to display a toast when server authentication and connection is complete, and + /// to provide a "log out" button. + #[allow(dead_code)] + is_authenticated_transport: bool, +} + +impl TemplatableMCPServerInfo { + pub fn name(&self) -> &str { + &self.name + } + + pub fn resources(&self) -> &Vec { + &self.resources + } + + pub fn tools(&self) -> &Vec { + &self.tools + } + + pub fn installation_id(&self) -> Uuid { + self.installation_id + } + + pub fn description(&self) -> Option<&str> { + self.description.as_deref() + } + + pub fn peer(&self) -> rmcp::Peer { + self.service.clone() + } + + pub fn peer_if_connected(&self) -> Option> { + if self.service.is_transport_closed() { + None + } else { + Some(self.service.clone()) + } + } + + pub fn has_tool(&self, tool_name: &str) -> bool { + self.tools.iter().any(|tool| tool.name == tool_name) + } + + pub fn has_resource(&self, resource: &rmcp::model::Resource) -> bool { + self.resources + .iter() + .any(|other_resource| resource.uri == other_resource.uri) + } + + pub fn has_resource_name_or_uri(&self, name: &str, uri: Option<&str>) -> bool { + self.resources.iter().any(|resource| { + if let Some(uri) = uri { + resource.uri == uri + } else { + resource.name == name + } + }) + } + + pub fn tool_input_schema( + &self, + tool_name: &str, + ) -> Option> { + self.tools + .iter() + .find(|tool| tool.name == tool_name) + .map(|tool| tool.input_schema.clone()) + } + + pub async fn shutdown(self) -> Result { + self.service.cancel().await + } +} + +/// Convert an rmcp error to a user-friendly error message. +pub fn error_to_user_message(error: &rmcp::RmcpError) -> String { + match error { + rmcp::RmcpError::ClientInitialize(err) => { + format!("Failed to initialize client: {}", err) + } + rmcp::RmcpError::ServerInitialize(err) => { + format!("Failed to initialize server: {}", err) + } + rmcp::RmcpError::TransportCreation { error, .. } => { + format!("Failed to establish connection: {}", error) + } + rmcp::RmcpError::Runtime(err) => { + format!("Runtime error: {}", err) + } + rmcp::RmcpError::Service(err) => match err { + rmcp::ServiceError::McpError(_) => { + "Server returned an error. Please check server logs for details.".to_string() + } + rmcp::ServiceError::TransportSend(_) => { + "Failed to send data to server. Connection may have been lost.".to_string() + } + rmcp::ServiceError::TransportClosed => { + "Connection closed unexpectedly. The server may have crashed.".to_string() + } + rmcp::ServiceError::UnexpectedResponse => { + "Server sent an unexpected response. The server may be incompatible.".to_string() + } + rmcp::ServiceError::Cancelled { reason } => format!( + "Operation was cancelled with reason: {}", + reason.clone().unwrap_or("Unknown reason".to_string()) + ), + rmcp::ServiceError::Timeout { timeout } => { + format!( + "Connection timed out after {} seconds. The server may be unresponsive.", + timeout.as_secs() + ) + } + _ => format!("Service error: {}", err), + }, + // The enum is marked as non-exhaustive, so we need a catch-all. + _ => { + format!("Error: {error}") + } + } +} + +/// Builds a `HeaderMap` from a `HashMap` of user-provided headers. +/// +/// Invalid header names or values are skipped. +fn build_header_map(headers: &HashMap) -> reqwest::header::HeaderMap { + headers.try_into().unwrap_or_default() +} + +/// Builds a reqwest client with custom headers for MCP HTTP/SSE connections. +#[allow(clippy::result_large_err)] +pub fn build_client_with_headers( + headers: &HashMap, +) -> Result { + let header_map = build_header_map(headers); + + reqwest::Client::builder() + .default_headers(header_map) + .build() + .map_err(|e| { + rmcp::RmcpError::transport_creation::(format!( + "Failed to build client with headers: {e}", + )) + }) +} + +/// Spawns a new MCP server from a given [`TransportType`]. +#[allow(clippy::result_large_err)] +pub async fn spawn_server( + server_name: String, + description: Option, + uuid: Uuid, + transport_type: TransportType, + logger: SimpleLogger, + auth_context: Option, +) -> Result { + logger.log("[note] Attention! There may be sensitive information (such as API keys) in these logs. Make sure to redact any secrets before sharing with others.".to_string()); + + let mut is_authenticated_transport = false; + let service = match transport_type { + TransportType::CLIServer(cli_server) => { + logger.log("[info] MCP: Using stdio transport".to_string()); + + cfg_if! { + if #[cfg(windows)] { + // We wrap the command in cmd.exe /c to allow Windows to be responsible for resolving the + // PATH variable rather than depending on the `Command` implementation, which only looks for + // `.exe` files in directories found in PATH. + // https://github.com/rust-lang/rust/issues/37519 + let command = "cmd.exe".to_owned(); + let args = std::iter::once("/c".to_owned()) + .chain(std::iter::once(cli_server.command)) + .chain(cli_server.args) + .collect::>(); + } else { + let command = cli_server.command; + let args = cli_server.args; + } + } + + // Capture the command and configured cwd for diagnostics before they're + // moved into the Command builder closure. + let command_for_log = command.clone(); + let cwd_for_log = cli_server.cwd_parameter.clone(); + + // Try to spawn the child process. + let (transport, stderr) = rmcp::transport::TokioChildProcess::builder( + tokio::process::Command::new(command).configure(|cmd| { + cmd.args(args); + if let Some(cwd) = cli_server.cwd_parameter { + cmd.current_dir(cwd); + } + for StaticEnvVar { name, value } in cli_server.static_env_vars.iter() { + if value.is_empty() { + // Skip empty/unset environment variables so that, in the CLI, they can be inherited. + logger.log(format!( + "[warn] MCP: Skipping empty environment variable: {name}" + )); + continue; + } + cmd.env(name, value); + } + + // On Windows, ensure that no console window is shown. + #[cfg(windows)] + cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0); + }), + ) + .stderr(std::process::Stdio::piped()) + .spawn() + .map_err(|err| { + if err.kind() == std::io::ErrorKind::NotFound { + let cwd_display = cwd_for_log + .as_deref() + .unwrap_or(""); + logger.log(format!( + "[error] MCP: Failed to spawn '{server_name}': command '{command_for_log}' \ + not found (cwd: {cwd_display}). If your MCP server depends on a specific \ + working directory, set the `working_directory` field in your config to \ + override the default." + )); + } + rmcp::RmcpError::transport_creation::(err) + })?; + + let pid = transport + .id() + .map(|pid| pid.to_string()) + .unwrap_or("??".to_string()); + + // We always expect to have an stderr, but this is marginally safer than unwrapping. + if let Some(stderr) = stderr { + let logger = logger.clone(); + // Spawn a background task to forward from the child process's stderr to our logger. + tokio::spawn(async move { + let mut buf = String::new(); + let mut reader = tokio::io::BufReader::new(stderr); + loop { + match reader.read_line(&mut buf).await { + // EOF. + Ok(0) => return, + // Read some data. + Ok(_) => logger.log(format!("[info] MCP [pid: {pid}] stderr: {buf}")), + // Failed to read from the child process's stderr. + Err(e) => { + log::error!("Failed to read stderr: {e}"); + return; + } + } + } + }); + } + + // Wrap the transport in a logging wrapper. + let transport = TransportLoggingWrapper { + transport, + logger: logger.clone(), + }; + + // Create the MCP client and connect to the server. + Ok::<_, rmcp::RmcpError>(make_client_info().into_dyn().serve(transport).await?) + } + TransportType::ServerSentEvents(sse_server) => { + let headers: HashMap = sse_server + .headers + .iter() + .map(|h| (h.name.clone(), h.value.clone())) + .collect(); + match determine_transport(server_name.clone(), &sse_server.url, &headers, auth_context) + .await + { + // TODO: these need headers also? + Ok(Transport::Http(Some(client))) => { + is_authenticated_transport = true; + + logger.log("[info] MCP: Using Streaming HTTP transport".to_string()); + let transport = rmcp::transport::StreamableHttpClientTransport::with_client( + client, + rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig::with_uri( + sse_server.url.clone(), + ), + ); + let transport = TransportLoggingWrapper { + transport, + logger: logger.clone(), + }; + Ok(make_client_info().into_dyn().serve(transport).await?) + } + Ok(Transport::Http(None)) => { + logger.log("[info] MCP: Using Streaming HTTP transport".to_string()); + let transport = if headers.is_empty() { + rmcp::transport::StreamableHttpClientTransport::from_uri( + sse_server.url.clone(), + ) + } else { + let client = build_client_with_headers(&headers)?; + rmcp::transport::StreamableHttpClientTransport::with_client( + client, + rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig::with_uri( + sse_server.url.clone(), + ), + ) + }; + let transport = TransportLoggingWrapper { + transport, + logger: logger.clone(), + }; + Ok(make_client_info().into_dyn().serve(transport).await?) + } + Ok(Transport::Sse(Some(client))) => { + is_authenticated_transport = true; + + logger.log("[info] MCP: Using (legacy) SSE transport (due to preflight failing with a 404)".to_string()); + let transport = crate::sse_transport::SseClientTransport::start_with_client( + client, + crate::sse_transport::SseClientConfig { + sse_endpoint: sse_server.url.into(), + ..Default::default() + }, + ) + .await + .map_err(rmcp::RmcpError::transport_creation::)?; + let transport = TransportLoggingWrapper { + transport, + logger: logger.clone(), + }; + Ok(make_client_info().into_dyn().serve(transport).await?) + } + Ok(Transport::Sse(None)) => { + logger.log("[info] MCP: Using (legacy) SSE transport (due to preflight failing with a 404)".to_string()); + let transport = if headers.is_empty() { + crate::sse_transport::SseClientTransport::start(sse_server.url.clone()) + .await + .map_err(|e| { + rmcp::RmcpError::transport_creation::(e) + })? + } else { + let client = build_client_with_headers(&headers)?; + crate::sse_transport::SseClientTransport::start_with_client( + client, + crate::sse_transport::SseClientConfig { + sse_endpoint: sse_server.url.clone().into(), + ..Default::default() + }, + ) + .await + .map_err(rmcp::RmcpError::transport_creation::)? + }; + let transport = TransportLoggingWrapper { + transport, + logger: logger.clone(), + }; + Ok(make_client_info().into_dyn().serve(transport).await?) + } + Err(err) => { + logger.log(format!( + "[error] MCP: preflight connection to MCP server failed: {err:#}" + )); + Err(err)? + } + } + } + }?; + + let server_info = service.peer_info(); + logger.log(format!("[info] MCP: Connected to server: {server_info:#?}")); + + let capabilities = server_info.map(|info| &info.capabilities); + + let resources = + query_resources_for(capabilities, &server_name, || service.list_all_resources()).await; + let tools = query_tools_for(capabilities, &server_name, || service.list_all_tools()).await; + + Ok(TemplatableMCPServerInfo { + name: server_name, + service, + resources, + tools, + installation_id: uuid, + description, + is_authenticated_transport, + }) +} + +/// The transport to use for MCP. +enum Transport { + /// The HTTP transport, with an optional authenticated client. + Http(Option>), + /// The SSE transport, with an optional authenticated client. + Sse(Option>), +} + +/// Determines which transport to use. +/// +/// This sends a "preflight" InitializeRequest to the server to determine whether the +/// server supports the HTTP transport (or needs to use the SSE transport), and if +/// authentication is required. +#[allow(clippy::result_large_err)] +async fn determine_transport( + server_name: String, + url: &str, + headers: &HashMap, + auth_context: Option, +) -> Result { + use reqwest::StatusCode; + + fn unexpected_error(status: reqwest::StatusCode) -> rmcp::RmcpError { + rmcp::RmcpError::transport_creation::(format!( + "Unexpected status code: {status}" + )) + } + match send_initialize_request(url, headers, None).await? { + StatusCode::OK => Ok(Transport::Http(None)), + StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => Ok(Transport::Sse(None)), + StatusCode::UNAUTHORIZED => { + let Some(auth_context) = auth_context else { + return Err(rmcp::RmcpError::transport_creation::( + "Server requires authentication, which is not yet supported.".to_string(), + )); + }; + + let client = crate::oauth::make_authenticated_client(&server_name, url, auth_context) + .await + .map_err(rmcp::RmcpError::transport_creation::)?; + match send_initialize_request(url, headers, Some(&client)).await? { + StatusCode::OK => Ok(Transport::Http(Some(client))), + StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => { + Ok(Transport::Sse(Some(client))) + } + other => Err(unexpected_error(other)), + } + } + status => Err(unexpected_error(status)), + } +} + +/// Sends an InitializeRequest to the server, and returns the HTTP status code from the response. +#[allow(clippy::result_large_err)] +async fn send_initialize_request( + url: &str, + headers: &HashMap, + auth_client: Option<&rmcp::transport::auth::AuthClient>, +) -> Result { + use rmcp::transport::common::http_header::{EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE}; + + let request = rmcp::model::InitializeRequest::new(make_client_info()); + let request = rmcp::model::ClientJsonRpcMessage::request( + rmcp::model::ClientRequest::InitializeRequest(request), + rmcp::model::RequestId::Number(0), + ); + + let mut request = build_client_with_headers(headers)? + .post(url) + .header( + http::header::ACCEPT, + [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "), + ) + .json(&request); + + if let Some(auth_client) = auth_client.as_ref() { + let access_token = auth_client + .get_access_token() + .await + .map_err(rmcp::RmcpError::transport_creation::)?; + request = request.bearer_auth(access_token); + } + + let response = request + .send() + .await + .map_err(rmcp::RmcpError::transport_creation::)?; + + Ok(response.status()) +} + +/// Creates a [`ClientInfo`] for the MCP client. +/// +/// This tells the MCP server who we are and what capabilities we have. +fn make_client_info() -> rmcp::model::ClientInfo { + rmcp::model::ClientInfo::new( + Default::default(), + rmcp::model::Implementation::new( + warp_core::channel::ChannelState::app_id().to_string(), + warp_core::channel::ChannelState::app_version() + .map(|v| v.to_string()) + .unwrap_or_default(), + ), + ) +} + +/// Whether to query `resources/list` for a server with the given capabilities. +/// +/// Per the MCP spec, the client should only invoke a list method when the server +/// has advertised the corresponding capability during initialization. +fn should_query_resources(capabilities: Option<&rmcp::model::ServerCapabilities>) -> bool { + capabilities.is_some_and(|c| c.resources.is_some()) +} + +/// Whether to query `tools/list` for a server with the given capabilities. +/// +/// Per the MCP spec, the client should only invoke a list method when the server +/// has advertised the corresponding capability during initialization. +fn should_query_tools(capabilities: Option<&rmcp::model::ServerCapabilities>) -> bool { + capabilities.is_some_and(|c| c.tools.is_some()) +} + +/// Query `resources/list` for a connected MCP server. +/// +/// Skips the call entirely when `resources` was not advertised. Treats any +/// listing error as "no resources" (fail-soft) so a flaky `resources/list` +/// does not abort the entire server startup. Mirrors the behavior of +/// [`query_tools_for`] so the two capabilities are handled symmetrically. +async fn query_resources_for( + capabilities: Option<&rmcp::model::ServerCapabilities>, + server_name: &str, + list_resources: F, +) -> Vec +where + F: FnOnce() -> Fut, + Fut: Future, rmcp::ServiceError>>, +{ + if !should_query_resources(capabilities) { + return Vec::new(); + } + match list_resources().await { + Ok(result) => result, + Err(err) => { + log::warn!("Failed to list resources for MCP server '{server_name}': {err}"); + Vec::new() + } + } +} + +/// Query `tools/list` for a connected MCP server. +/// +/// Skips the call entirely when `tools` was not advertised. Treats any listing +/// error as "no tools" (fail-soft) so a transient `tools/list` failure does +/// not abort the entire server startup — the user-visible regression #6798 +/// was rooted in the prior asymmetric handling, where a tools-list error on +/// a server with healthy resources would propagate and fail startup. +async fn query_tools_for( + capabilities: Option<&rmcp::model::ServerCapabilities>, + server_name: &str, + list_tools: F, +) -> Vec +where + F: FnOnce() -> Fut, + Fut: Future, rmcp::ServiceError>>, +{ + if !should_query_tools(capabilities) { + return Vec::new(); + } + match list_tools().await { + Ok(result) => result, + Err(err) => { + log::warn!("Failed to list tools for MCP server '{server_name}': {err}"); + Vec::new() + } + } +} + +/// A wrapper around a [`rmcp::transport::Transport`] that logs all requests and responses. +struct TransportLoggingWrapper { + transport: T, + logger: SimpleLogger, +} + +impl, R: rmcp::service::ServiceRole> rmcp::transport::Transport + for TransportLoggingWrapper +{ + type Error = T::Error; + + fn send( + &mut self, + item: rmcp::service::TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + if let Ok(json) = serde_json::to_string(&item) { + self.logger + .log(format!("[info] MCP: Sending request: {json}")); + } + + let logger = self.logger.clone(); + self.transport.send(item).map(move |result| { + if let Err(e) = &result { + logger.log(format!("[warn] MCP: Failed to send request: {e:#}")); + } + result + }) + } + + fn receive( + &mut self, + ) -> impl Future>> + Send { + let logger = self.logger.clone(); + async move { + let result = self.transport.receive().await; + if let Some(item) = &result { + if let Ok(json) = serde_json::to_string(item) { + logger.log(format!("[info] MCP: Received response: {json}")); + } + } + result + } + } + + fn close(&mut self) -> impl Future> + Send { + self.transport.close() + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + use rmcp::model::{ErrorCode, ErrorData, Resource, ServerCapabilities, Tool}; + + use super::{query_resources_for, query_tools_for, should_query_resources, should_query_tools}; + + /// Builds `ServerCapabilities` with selected capability flags toggled on. + fn caps(tools: bool, resources: bool) -> ServerCapabilities { + match (tools, resources) { + (true, true) => ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + (true, false) => ServerCapabilities::builder().enable_tools().build(), + (false, true) => ServerCapabilities::builder().enable_resources().build(), + (false, false) => ServerCapabilities::builder().build(), + } + } + + fn test_tool(name: &str) -> Tool { + serde_json::from_value(serde_json::json!({ + "name": name, + "description": "test tool", + "inputSchema": { "type": "object" }, + })) + .expect("Tool deserialization") + } + + fn test_resource(uri: &str) -> Resource { + serde_json::from_value(serde_json::json!({ + "uri": uri, + "name": "test resource", + })) + .expect("Resource deserialization") + } + + /// Regression test for warpdotdev/warp#6798. + #[test] + fn each_capability_is_queried_independently() { + for has_tools in [false, true] { + for has_resources in [false, true] { + let c = caps(has_tools, has_resources); + assert_eq!( + should_query_tools(Some(&c)), + has_tools, + "tools={has_tools}, resources={has_resources}", + ); + assert_eq!( + should_query_resources(Some(&c)), + has_resources, + "tools={has_tools}, resources={has_resources}", + ); + } + } + assert!(!should_query_tools(None)); + assert!(!should_query_resources(None)); + } + + /// Skips `tools/list` when `tools` is not advertised. + #[tokio::test] + async fn query_tools_for_skips_listing_when_capability_not_advertised() { + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + let no_caps = caps(false, false); + + let result = query_tools_for(Some(&no_caps), "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(vec![test_tool("never")]) + }) + .await; + + assert!(result.is_empty()); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + /// Skips `tools/list` when server info is absent. + #[tokio::test] + async fn query_tools_for_skips_listing_when_server_info_is_none() { + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + + let result = query_tools_for(None, "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(vec![test_tool("never")]) + }) + .await; + + assert!(result.is_empty()); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + /// Returns listed tools when `tools` is advertised. + #[tokio::test] + async fn query_tools_for_returns_listed_tools_when_capability_advertised() { + let c = caps(true, false); + let expected = vec![test_tool("greet"), test_tool("review")]; + let to_return = expected.clone(); + + let result = query_tools_for(Some(&c), "srv", || async move { Ok(to_return) }).await; + + assert_eq!(result, expected); + } + + /// Returns an empty vector when the server lists no tools. + #[tokio::test] + async fn query_tools_for_returns_empty_vec_when_server_lists_no_tools() { + let c = caps(true, false); + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + + let result = query_tools_for(Some(&c), "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(Vec::new()) + }) + .await; + + assert!(result.is_empty()); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + /// Fails soft when `tools/list` sees a transport error. + #[tokio::test] + async fn query_tools_for_returns_empty_on_transport_error() { + let c = caps(true, false); + let result = query_tools_for(Some(&c), "srv", || async { + Err(rmcp::ServiceError::TransportClosed) + }) + .await; + assert!(result.is_empty()); + } + + /// Fails soft when `tools/list` returns an MCP protocol error. + #[tokio::test] + async fn query_tools_for_returns_empty_on_mcp_error() { + let c = caps(true, false); + let result = query_tools_for(Some(&c), "srv", || async { + Err(rmcp::ServiceError::McpError(ErrorData { + code: ErrorCode::METHOD_NOT_FOUND, + message: "tools/list not implemented".into(), + data: None, + })) + }) + .await; + assert!(result.is_empty()); + } + + /// Calls the `tools/list` function exactly once per query. + #[tokio::test] + async fn query_tools_for_calls_list_function_exactly_once() { + let c = caps(true, false); + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + + let _ = query_tools_for(Some(&c), "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(vec![test_tool("p")]) + }) + .await; + + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + /// Keeps the tools-listing decision independent of resource capability state. + #[tokio::test] + async fn query_tools_for_decision_independent_of_other_capabilities() { + let tools = vec![test_tool("x")]; + for has_tools in [false, true] { + for has_resources in [false, true] { + let c = caps(has_tools, has_resources); + let to_return = tools.clone(); + let result = + query_tools_for(Some(&c), "srv", || async move { Ok(to_return) }).await; + + if has_tools { + assert_eq!(result, tools); + } else { + assert!(result.is_empty()); + } + } + } + } + + /// Skips `resources/list` when `resources` is not advertised. + #[tokio::test] + async fn query_resources_for_skips_listing_when_capability_not_advertised() { + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + let no_caps = caps(false, false); + + let result = query_resources_for(Some(&no_caps), "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(vec![test_resource("file:///nope")]) + }) + .await; + + assert!(result.is_empty()); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + /// Skips `resources/list` when server info is absent. + #[tokio::test] + async fn query_resources_for_skips_listing_when_server_info_is_none() { + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + + let result = query_resources_for(None, "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(vec![test_resource("file:///nope")]) + }) + .await; + + assert!(result.is_empty()); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + /// Returns listed resources when `resources` is advertised. + #[tokio::test] + async fn query_resources_for_returns_listed_resources_when_capability_advertised() { + let c = caps(false, true); + let expected = vec![test_resource("file:///a"), test_resource("file:///b")]; + let to_return = expected.clone(); + + let result = query_resources_for(Some(&c), "srv", || async move { Ok(to_return) }).await; + + assert_eq!(result, expected); + } + + /// Fails soft when `resources/list` sees a transport error. + #[tokio::test] + async fn query_resources_for_returns_empty_on_transport_error() { + let c = caps(false, true); + let result = query_resources_for(Some(&c), "srv", || async { + Err(rmcp::ServiceError::TransportClosed) + }) + .await; + assert!(result.is_empty()); + } + + /// Fails soft when `resources/list` returns an MCP protocol error. + #[tokio::test] + async fn query_resources_for_returns_empty_on_mcp_error() { + let c = caps(false, true); + let result = query_resources_for(Some(&c), "srv", || async { + Err(rmcp::ServiceError::McpError(ErrorData { + code: ErrorCode::METHOD_NOT_FOUND, + message: "resources/list not implemented".into(), + data: None, + })) + }) + .await; + assert!(result.is_empty()); + } + + /// Calls the `resources/list` function exactly once per query. + #[tokio::test] + async fn query_resources_for_calls_list_function_exactly_once() { + let c = caps(false, true); + let calls = Arc::new(AtomicUsize::new(0)); + let calls_clone = calls.clone(); + + let _ = query_resources_for(Some(&c), "srv", || async move { + calls_clone.fetch_add(1, Ordering::SeqCst); + Ok(vec![test_resource("file:///a")]) + }) + .await; + + assert_eq!(calls.load(Ordering::SeqCst), 1); + } +} diff --git a/app/src/server/datetime_ext.rs b/crates/warp_core/src/datetime_ext.rs similarity index 100% rename from app/src/server/datetime_ext.rs rename to crates/warp_core/src/datetime_ext.rs diff --git a/crates/warp_core/src/lib.rs b/crates/warp_core/src/lib.rs index ad447eefe0..764ed8c1a0 100644 --- a/crates/warp_core/src/lib.rs +++ b/crates/warp_core/src/lib.rs @@ -3,6 +3,7 @@ pub mod assertions; pub mod channel; pub mod command; pub mod context_flag; +pub mod datetime_ext; pub mod errors; pub mod execution_mode; pub mod features; From c38de1233aed8ab16cff700f417592d4b3a0802a Mon Sep 17 00:00:00 2001 From: David Stern Date: Thu, 21 May 2026 12:35:29 -0400 Subject: [PATCH 2/2] Restore reviewed MCP runtime comments Co-Authored-By: Oz --- app/src/ai/mcp/templatable_manager/native.rs | 4 ++- crates/mcp/src/runtime.rs | 32 +++++++++++++++++--- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/app/src/ai/mcp/templatable_manager/native.rs b/app/src/ai/mcp/templatable_manager/native.rs index 6c0440f3eb..41e2e515b0 100644 --- a/app/src/ai/mcp/templatable_manager/native.rs +++ b/app/src/ai/mcp/templatable_manager/native.rs @@ -133,8 +133,10 @@ impl TemplatableMCPServerManager { /// Handles an incoming OAuth callback URL. /// /// Routes the callback to the correct in-flight OAuth flow using the `state` query - /// parameter that rmcp embedded in the authorization URL. + /// parameter (the CSRF token that rmcp embedded in the authorization URL). This avoids + /// encoding routing data in the redirect URI, keeping it RFC 6749 §3.1.2.2 compliant. pub fn handle_oauth_callback(&mut self, url: &Url) -> anyhow::Result<()> { + // Ensure the URL has the expected path if url.path() != "/oauth2callback" { anyhow::bail!( "Invalid OAuth callback path: expected '/oauth2callback', got '{}'", diff --git a/crates/mcp/src/runtime.rs b/crates/mcp/src/runtime.rs index 8bc66dd99e..b8eda55ac8 100644 --- a/crates/mcp/src/runtime.rs +++ b/crates/mcp/src/runtime.rs @@ -1,3 +1,10 @@ +//! Capability-gating helpers used during MCP server startup. +//! +//! Each `query_*_for` function pairs a capability check with the actual list +//! call from rmcp, gating the call on advertisement and failing soft on errors. +//! They take the list call as a closure so unit tests can drive the gate-and- +//! fail-soft control flow with a fake `RunningService` substitute. + use std::collections::HashMap; use std::future::Future; @@ -446,6 +453,8 @@ async fn determine_transport( )); }; + // Go through the OAuth flow to get an authenticated client. + // This will first attempt to use cached credentials before starting interactive OAuth. let client = crate::oauth::make_authenticated_client(&server_name, url, auth_context) .await .map_err(rmcp::RmcpError::transport_creation::)?; @@ -644,7 +653,9 @@ mod tests { use super::{query_resources_for, query_tools_for, should_query_resources, should_query_tools}; - /// Builds `ServerCapabilities` with selected capability flags toggled on. + /// Build a `ServerCapabilities` with selected capability flags toggled on. + /// Each `Some(default)` mirrors how rmcp deserializes a capability the + /// server advertised with no inner flags set. fn caps(tools: bool, resources: bool) -> ServerCapabilities { match (tools, resources) { (true, true) => ServerCapabilities::builder() @@ -674,7 +685,10 @@ mod tests { .expect("Resource deserialization") } - /// Regression test for warpdotdev/warp#6798. + /// Regression test for warpdotdev/warp#6798: each capability is queried + /// independently. Previously, asymmetric handling could cause `tools/list` + /// to be skipped when a server advertised both `tools` and `resources`, + /// resulting in "No tools available" even though the server had tools. #[test] fn each_capability_is_queried_independently() { for has_tools in [false, true] { @@ -696,7 +710,9 @@ mod tests { assert!(!should_query_resources(None)); } - /// Skips `tools/list` when `tools` is not advertised. + /// When `tools` is not advertised, the helper must skip the list call so + /// we don't waste a round trip and pollute the wire log with a request + /// that's destined to return `METHOD_NOT_FOUND`. #[tokio::test] async fn query_tools_for_skips_listing_when_capability_not_advertised() { let calls = Arc::new(AtomicUsize::new(0)); @@ -758,7 +774,11 @@ mod tests { assert_eq!(calls.load(Ordering::SeqCst), 1); } - /// Fails soft when `tools/list` sees a transport error. + /// **The fail-soft test the bug ticket implicitly demands.** Transport- + /// closed errors must not abort server startup; the helper must log and + /// return an empty vec. This is the regression-protector for #6798's + /// underlying asymmetry — if anyone re-introduces a `return Err(...)` here, + /// this test fails. #[tokio::test] async fn query_tools_for_returns_empty_on_transport_error() { let c = caps(true, false); @@ -769,7 +789,9 @@ mod tests { assert!(result.is_empty()); } - /// Fails soft when `tools/list` returns an MCP protocol error. + /// MCP-protocol errors (e.g. METHOD_NOT_FOUND from a misbehaving server + /// that advertised the capability but rejects the call) also fail soft, + /// so the rest of the server surface still comes up. #[tokio::test] async fn query_tools_for_returns_empty_on_mcp_error() { let c = caps(true, false);