diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index e1464814656..d977e41247a 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -127,7 +127,9 @@ fn map_procedure_error(e: ProcedureCallError, procedure: &str) -> (StatusCode, S StatusCode::NOT_FOUND } ProcedureCallError::OutOfEnergy => StatusCode::PAYMENT_REQUIRED, - ProcedureCallError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, + ProcedureCallError::GuestPanic(_) | ProcedureCallError::InvalidReturnValue(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } }; log::error!("Error while invoking procedure {e:#}"); (status_code, format!("{:#}", anyhow::anyhow!(e))) diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index e2131abfa31..ce7eb107638 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::Display; use std::future::Future; use std::num::NonZeroUsize; @@ -23,12 +24,12 @@ use prometheus::{Histogram, IntGauge}; use scopeguard::{defer, ScopeGuard}; use serde::Deserialize; use spacetimedb::client::messages::{ - serialize, serialize_v2, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, - ToProtocol, + serialize, serialize_v2, CloseCode as StCloseCode, CloseFrame as StCloseFrame, IdentityTokenMessage, + InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ - ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, - MessageHandleError, MeteredReceiver, MeteredSender, OutboundMessage, Protocol, WsVersion, + ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, ClientDisconnectSender, DataMessage, + MessageExecutionError, MessageHandleError, MeteredReceiver, MeteredSender, OutboundMessage, Protocol, WsVersion, }; use spacetimedb::host::module_host::ClientConnectedError; use spacetimedb::host::NoSuchModule; @@ -52,7 +53,8 @@ use tokio_tungstenite::tungstenite::Utf8Bytes; use crate::auth::SpacetimeAuth; use crate::util::serde::humantime_duration; use crate::util::websocket::{ - CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError, + CloseCode as WsCloseCode, CloseFrame as WsCloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, + WebSocketUpgrade, WsError, }; use crate::util::{NameOrIdentity, XForwardedFor}; use crate::{log_and_500, Authorization, ControlStateDelegate, NodeDelegate}; @@ -267,7 +269,8 @@ where "websocket: Database accepted connection from {client_log_string}; spawning ws_client_actor and ClientConnection" ); - let actor = |client, receiver| ws_client_actor(ws_opts, client, ws, receiver); + let (disconnect_tx, disconnect_rx) = mpsc::unbounded_channel(); + let actor = |client, receiver| ws_client_actor(ws_opts, client, ws, receiver, disconnect_rx); let client = ClientConnection::spawn( client_id, auth.into(), @@ -275,6 +278,7 @@ where client_config, leader.replica_id, module_rx, + Some(ClientDisconnectSender::new(disconnect_tx)), actor, connected, ) @@ -432,13 +436,14 @@ async fn ws_client_actor( client: ClientConnection, ws: WebSocketStream, sendrx: ClientConnectionReceiver, + disconnect_rx: mpsc::UnboundedReceiver, ) { // ensure that even if this task gets cancelled, we always cleanup the connection let mut client = scopeguard::guard(client, |client| { tokio::spawn(client.disconnect()); }); - ws_client_actor_inner(&mut client, options, ws, sendrx).await; + ws_client_actor_inner(&mut client, options, ws, sendrx, disconnect_rx).await; ScopeGuard::into_inner(client).disconnect().await; } @@ -448,14 +453,21 @@ async fn ws_client_actor_inner( config: WebSocketOptions, ws: WebSocketStream, sendrx: ClientConnectionReceiver, + mut disconnect_rx: mpsc::UnboundedReceiver, ) { let database = client.module().info().database_identity; let client_id = client.id; let client_closed_metric = WORKER_METRICS.ws_clients_closed_connection.with_label_values(&database); let state = Arc::new(ActorState::new(database, client_id, config)); - // Channel for [`UnorderedWsMessage`]s. - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + // Channel for websocket control traffic (`Close`, `Ping`, unordered execution errors). + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); + let ws_control_tx_for_disconnect = ws_control_tx.clone(); + tokio::spawn(async move { + while let Some(close_frame) = disconnect_rx.recv().await { + let _ = ws_control_tx_for_disconnect.send(WsControlMessage::Close(close_frame)); + } + }); // Split websocket into send and receive halves. let (ws_send, ws_recv) = ws.split(); @@ -467,13 +479,13 @@ async fn ws_client_actor_inner( let bsatn_rlb_pool = client.module().subscriptions().bsatn_rlb_pool.clone(); // Spawn a task to send outgoing messages - // obtained from `sendrx` and `unordered_rx`. + // obtained from `sendrx` and `ws_control_rx`. let send_task = tokio::spawn(ws_send_loop( state.clone(), client.config, ws_send, sendrx, - unordered_rx, + ws_control_rx, bsatn_rlb_pool, )); // Spawn a task to handle incoming messages. @@ -488,7 +500,7 @@ async fn ws_client_actor_inner( async move { client.handle_message(data, timer.into()).await } } }, - unordered_tx.clone(), + ws_control_tx.clone(), ws_recv, client.config.version, )); @@ -501,7 +513,7 @@ async fn ws_client_actor_inner( }; ws_main_loop(state, hotswap, idle_timer, send_task, recv_task, move |msg| { - let _ = unordered_tx.send(msg); + let _ = ws_control_tx.send(msg); }) .await; log::info!("Client connection ended: {client_id}"); @@ -578,13 +590,13 @@ async fn ws_client_actor_inner( /// The idle timer should be reset whenever data is received from the websocket. /// /// * **send_task**: -/// Task handling outgoing messages. Holds the receive end of `unordered_tx`. +/// Task handling outgoing messages. Holds the receive end of `ws_control_tx`. /// /// If the task returns, the connection is considered bad, and the main loop /// exits. If the task panicked, the panic is resumed on the current thread. /// /// Note that the send task must not terminate after it has sent a `Close` -/// frame (via `unordered_tx`) -- the websocket protocol mandates that the +/// frame (via `ws_control_tx`) -- the websocket protocol mandates that the /// initiator of the close handshake wait for the other end to respond with /// a `Close` frame. Thus, the loop must continue to poll `recv_task` and not /// exit due to `send_task` being complete. @@ -600,7 +612,7 @@ async fn ws_client_actor_inner( /// /// See [`ws_recv_task`]. /// -/// * **unordered_tx**: +/// * **ws_control_tx**: /// Channel connected to `send_task` that allows the loop to send `Ping` and /// `Close` frames. /// @@ -616,7 +628,7 @@ async fn ws_main_loop( idle_timer: impl Future, mut send_task: JoinHandle<()>, mut recv_task: JoinHandle<()>, - unordered_tx: impl Fn(UnorderedWsMessage), + ws_control_tx: impl Fn(WsControlMessage), ) where HotswapWatcher: Future>, { @@ -678,11 +690,11 @@ async fn ws_main_loop( // Branch is disabled if we already sent a close frame. res = &mut watch_hotswap, if !closed => { if let Err(NoSuchModule) = res { - let close = CloseFrame { - code: CloseCode::Away, + let close = StCloseFrame { + code: StCloseCode::Away, reason: "module exited".into() }; - unordered_tx(close.into()); + ws_control_tx(close.into()); } watch_hotswap.set(hotswap()); }, @@ -701,7 +713,7 @@ async fn ws_main_loop( _ = ping_interval.tick(), if !closed => { let was_ponged = state.reset_ponged(); if was_ponged { - unordered_tx(UnorderedWsMessage::Ping(Bytes::new())); + ws_control_tx(WsControlMessage::Ping(Bytes::new())); } } } @@ -745,7 +757,7 @@ async fn ws_idle_timer(mut activity: watch::Receiver) { /// `idle_tx` is the sending end of a [`ws_idle_timer`]. The [`ws_recv_loop`] /// sends a new, extended deadline whenever it receives a message. /// -/// `unordered_tx` is used to send message execution errors +/// `ws_control_tx` is used to send message execution errors /// or to initiate a close handshake. /// /// Initiates a close handshake if the `message_handler` returns any variant @@ -754,7 +766,7 @@ async fn ws_idle_timer(mut activity: watch::Receiver) { /// Terminates if: /// /// - the `ws` stream is exhausted -/// - or, `unordered_tx` is already closed +/// - or, `ws_control_tx` is already closed /// /// In the latter case, we assume that the connection is in an errored state, /// such that we wouldn't be able to receive any more messages anyway. @@ -763,7 +775,7 @@ async fn ws_recv_task( idle_tx: watch::Sender, client_closed_metric: IntGauge, message_handler: impl Fn(DataMessage, Instant) -> MessageHandler, - unordered_tx: mpsc::UnboundedSender, + ws_control_tx: mpsc::UnboundedSender, ws: impl Stream> + Unpin + Send + 'static, ws_version: WsVersion, ) where @@ -772,7 +784,7 @@ async fn ws_recv_task( let recv_queue_gauge = WORKER_METRICS .total_incoming_queue_length .with_label_values(&state.database); - let recv_queue = ws_recv_queue(state.clone(), unordered_tx.clone(), recv_queue_gauge, ws); + let recv_queue = ws_recv_queue(state.clone(), ws_control_tx.clone(), recv_queue_gauge, ws); let recv_loop = pin!(ws_recv_loop(state.clone(), idle_tx, recv_queue)); let recv_handler = ws_client_message_handler(state.clone(), client_closed_metric, recv_loop); pin_mut!(recv_handler); @@ -785,25 +797,34 @@ async fn ws_recv_task( { log::error!("{err:#}"); // If the send task has exited, also exit this recv task. - if unordered_tx.send(err.into()).is_err() { + if ws_control_tx.send(err.into()).is_err() { break; } continue; } log::debug!("Client caused error: {e}"); - let close = CloseFrame { - code: CloseCode::Error, - reason: format!("{e:#}").into(), - }; + let close = close_frame_for_error(e); + // If the send task has exited, also exit this recv task. // No need to send the close handshake in that case; the client is already gone. - if unordered_tx.send(close.into()).is_err() { + if ws_control_tx.send(close.into()).is_err() { break; }; } } } +fn close_frame_for_error(e: MessageHandleError) -> StCloseFrame { + if let MessageHandleError::DisconnectClient(frame) = e { + frame + } else { + StCloseFrame { + code: StCloseCode::Protocol, + reason: format!("{e:#}").into(), + } + } +} + /// Stream that consumes a stream of [`WsMessage`]s and yields [`ClientMessage`]s. /// /// Terminates if: @@ -903,7 +924,7 @@ fn ws_recv_loop( /// /// The channel is initialized with [`ActorConfig::incoming_queue_length`]. /// If it is at capacity, a connection shutdown is initiated by sending -/// [`UnorderedWsMessage::Close`] via `unordered_tx`. +/// [`WsControlMessage::Close`] via `ws_control_tx`. /// /// Returns the channel receiver. /// @@ -914,13 +935,13 @@ fn ws_recv_loop( /// [#1851]: https://github.com/clockworklabs/SpacetimeDBPrivate/issues/1851 fn ws_recv_queue( state: Arc, - unordered_tx: mpsc::UnboundedSender, + ws_control_tx: mpsc::UnboundedSender, recv_queue_gauge: IntGauge, mut ws: impl Stream> + Unpin + Send + 'static, ) -> impl Stream> { - const CLOSE: UnorderedWsMessage = UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Again, - reason: Utf8Bytes::from_static("too many requests"), + const CLOSE: WsControlMessage = WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Again, + reason: Cow::Borrowed("too many requests"), }); let on_message_after_close = move |client_id| { log::warn!("client {client_id} sent message after close or error"); @@ -950,7 +971,7 @@ fn ws_recv_queue( // - Then exit the loop, as we won't be processing any // more messages, and we don't expect a close response // to arrive from the client. - if unordered_tx.send(CLOSE).is_err() { + if ws_control_tx.send(CLOSE).is_err() { state.close(); break; } @@ -1049,9 +1070,9 @@ fn ws_client_message_handler( /// Outgoing messages that don't need to be ordered wrt subscription updates. #[derive(Debug, From)] -enum UnorderedWsMessage { +enum WsControlMessage { /// Server-initiated close. - Close(CloseFrame), + Close(spacetimedb::client::messages::CloseFrame), /// Server-initiated ping. Ping(Bytes), /// Error calling a reducer. @@ -1093,32 +1114,32 @@ impl Receiver for mpsc::Receiver { /// Consumes `messages`, which yields subscription updates and reducer call /// results. Note that [`SerializableMessage`]s require serialization and /// potentially compression, which can be costly. -/// Also consumes `unordered`, which yields [`UnorderedWsMessage`]s. +/// Also consumes `ws_control`, which yields [`WsControlMessage`]s. /// /// Terminates if: /// -/// - `unordered` is closed +/// - `ws_control` is closed /// - an error occurs sending to the `ws` sink /// -/// If an [`UnorderedWsMessage::Close`] is encountered, a close frame is sent +/// If an [`WsControlMessage::Close`] is encountered, a close frame is sent /// to the `ws` sink, and `state.close()` is called. When this happens, /// `messages` will no longer be polled (no data can be sent after a close /// frame anyways), so `messages.close()` will be called. /// -/// Keeps polling `unordered` if `state.closed()`, but discards all data. +/// Keeps polling `ws_control` if `state.closed()`, but discards all data. /// This is so `ws_client_actor_inner` keeps polling the receive end of the /// socket until the close handshake completes -- it would otherwise exit early -/// when sending to `unordered` fails. +/// when sending to `ws_control` fails. async fn ws_send_loop( state: Arc, config: ClientConfig, ws: impl Sink + Unpin, messages: impl Receiver, - unordered: mpsc::UnboundedReceiver, + ws_control: mpsc::UnboundedReceiver, bsatn_rlb_pool: BsatnRowListBuilderPool, ) { let metrics = SendMetrics::new(state.database); - ws_send_loop_inner(state, ws, messages, unordered, move |encode_rx, frames_tx| { + ws_send_loop_inner(state, ws, messages, ws_control, move |encode_rx, frames_tx| { ws_encode_task(metrics, config, encode_rx, frames_tx, bsatn_rlb_pool) }) .await @@ -1128,7 +1149,7 @@ async fn ws_send_loop_inner( state: Arc, mut ws: impl Sink + Unpin, mut messages: impl Receiver, - mut unordered: mpsc::UnboundedReceiver, + mut ws_control: mpsc::UnboundedReceiver, encoder: impl FnOnce(mpsc::UnboundedReceiver, mpsc::UnboundedSender) -> Encoder, ) where T: Into, @@ -1138,7 +1159,7 @@ async fn ws_send_loop_inner( // The number of frames we'll `feed` to the `ws` sink in one iteration // of the `select!` loop. // - // This batching is done to allow control messages appearing on `unordered` + // This batching is done to allow control messages appearing on `ws_control` // to be interleaved with the sending of large messages split across some // number of frames. // @@ -1171,19 +1192,19 @@ async fn ws_send_loop_inner( biased; // Check for control messages or execution errors. - maybe_msg = unordered.recv() => { + maybe_msg = ws_control.recv() => { let Some(msg) = maybe_msg else { break; }; // We shall not send more data after a close frame, - // but keep polling `unordered` so that `ws_client_actor` keeps + // but keep polling `ws_control` so that `ws_client_actor` keeps // waiting for an acknowledgement from the client, // even if it spuriously initiates another close itself. if closed { continue; } match msg { - UnorderedWsMessage::Close(close_frame) => { + WsControlMessage::Close(close_frame) => { log::trace!("intiating close"); // Send outstanding frames until one that has the FIN // bit set. Ensures the client won't receive partial @@ -1202,7 +1223,7 @@ async fn ws_send_loop_inner( } // Then send the close frame. log::trace!("sending close frame"); - if let Err(e) = ws.send(WsMessage::Close(Some(close_frame))).await { + if let Err(e) = ws.send(WsMessage::Close(Some(convert_close_frame(close_frame)))).await { log::warn!("error sending close frame: {e:#}"); break; } @@ -1221,14 +1242,14 @@ async fn ws_send_loop_inner( // so let senders know. messages.close(); }, - UnorderedWsMessage::Ping(bytes) => { + WsControlMessage::Ping(bytes) => { log::trace!("sending ping"); if let Err(e) = ws.feed(WsMessage::Ping(bytes)).await { log::warn!("error sending ping: {e:#}"); break; } }, - UnorderedWsMessage::Error(err) => { + WsControlMessage::Error(err) => { log::trace!("encoding execution error"); encode_tx .send(err.into()) @@ -1279,6 +1300,21 @@ async fn ws_send_loop_inner( } } +fn convert_close_frame(frame: StCloseFrame) -> WsCloseFrame { + WsCloseFrame { + code: match frame.code { + StCloseCode::Again => WsCloseCode::Again, + StCloseCode::Invalid => WsCloseCode::Invalid, + StCloseCode::Away => WsCloseCode::Away, + StCloseCode::Protocol => WsCloseCode::Protocol, + }, + reason: match frame.reason { + Cow::Borrowed(reason) => Utf8Bytes::from_static(reason), + Cow::Owned(reason) => reason.into(), + }, + } +} + #[derive(From)] enum OutboundWsMessage { Error(MessageExecutionError), @@ -1523,7 +1559,7 @@ enum ClientMessage { Message(DataMessage), Ping(Bytes), Pong(Bytes), - Close(Option), + Close(Option), } impl ClientMessage { @@ -1786,17 +1822,17 @@ mod tests { } #[tokio::test] - async fn send_loop_terminates_when_unordered_closed() { + async fn send_loop_terminates_when_ws_control_closed() { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state, ClientConfig::for_test(), sink::drain(), messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); @@ -1805,7 +1841,7 @@ mod tests { drop(messages_tx); assert!(is_pending(&mut send_loop).await); - drop(unordered_tx); + drop(ws_control_tx); send_loop.await; } @@ -1813,21 +1849,21 @@ mod tests { async fn send_loop_close_message_closes_state_and_messages() { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), sink::drain(), messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); - unordered_tx - .send(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + ws_control_tx + .send(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "done".into(), })) .unwrap(); @@ -1840,12 +1876,12 @@ mod tests { #[tokio::test] async fn send_loop_terminates_if_sink_cant_be_fed() { let input = [ - Either::Left(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + Either::Left(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "bah!".into(), })), - Either::Left(UnorderedWsMessage::Ping(Bytes::new())), - Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + Either::Left(WsControlMessage::Ping(Bytes::new())), + Either::Left(WsControlMessage::Error(MessageExecutionError { reducer: None, reducer_id: None, caller_identity: Identity::ZERO, @@ -1866,20 +1902,20 @@ mod tests { for message in input { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), UnfeedableSink, messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); match message { - Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Left(ws_control) => ws_control_tx.send(ws_control).unwrap(), Either::Right(message) => messages_tx.send(message).await.unwrap(), } send_loop.await; @@ -1889,12 +1925,12 @@ mod tests { #[tokio::test] async fn send_loop_terminates_if_sink_cant_be_flushed() { let input = [ - Either::Left(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + Either::Left(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "bah!".into(), })), - Either::Left(UnorderedWsMessage::Ping(Bytes::new())), - Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + Either::Left(WsControlMessage::Ping(Bytes::new())), + Either::Left(WsControlMessage::Error(MessageExecutionError { reducer: None, reducer_id: None, caller_identity: Identity::ZERO, @@ -1915,20 +1951,20 @@ mod tests { for message in input { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), UnflushableSink, messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); match message { - Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Left(ws_control) => ws_control_tx.send(ws_control).unwrap(), Either::Right(message) => messages_tx.send(message).await.unwrap(), } send_loop.await; @@ -2006,11 +2042,11 @@ mod tests { let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline()); // Pretend we received a pong immediately after sending a ping, // but only five times. - let unordered_tx = { + let ws_control_tx = { let state = state.clone(); let pings = AtomicUsize::new(0); move |m| { - if let UnorderedWsMessage::Ping(_) = m { + if let WsControlMessage::Ping(_) = m { let n = pings.fetch_add(1, Ordering::Relaxed); if n < 5 { state.set_ponged(); @@ -2030,7 +2066,7 @@ mod tests { ws_idle_timer(idle_rx), tokio::spawn(future::pending()), tokio::spawn(future::pending()), - unordered_tx, + ws_control_tx, ) .await } @@ -2057,10 +2093,10 @@ mod tests { let state = Arc::new(dummy_actor_state()); let (_idle_tx, idle_rx) = watch::channel(state.next_idle_deadline()); - let unordered_tx = { + let ws_control_tx = { let state = state.clone(); move |m| { - if let UnorderedWsMessage::Close(_) = m { + if let WsControlMessage::Close(_) = m { state.close(); } } @@ -2087,7 +2123,7 @@ mod tests { } }), tokio::spawn(future::pending()), - unordered_tx, + ws_control_tx, ) .await }) @@ -2112,15 +2148,15 @@ mod tests { ..<_>::default() })); - let (unordered_tx, mut unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, mut ws_control_rx) = mpsc::unbounded_channel(); let input = stream::iter((0..20).map(|i| Ok(WsMessage::text(format!("message {i}"))))); let metric = IntGauge::new("bleep", "unhelpful").unwrap(); - let received = ws_recv_queue(state, unordered_tx, metric.clone(), input) + let received = ws_recv_queue(state, ws_control_tx, metric.clone(), input) .collect::>() .await; - assert_matches!(unordered_rx.recv().await, Some(UnorderedWsMessage::Close(_))); + assert_matches!(ws_control_rx.recv().await, Some(WsControlMessage::Close(_))); // Queue length metric should be zero assert_eq!(metric.get(), 0); // Should have received all of the input. @@ -2134,11 +2170,11 @@ mod tests { ..<_>::default() })); - let (unordered_tx, _) = mpsc::unbounded_channel(); + let (ws_control_tx, _) = mpsc::unbounded_channel(); let input = stream::iter((0..20).map(|i| Ok(WsMessage::text(format!("message {i}"))))); let metric = IntGauge::new("bleep", "unhelpful").unwrap(); - let received = ws_recv_queue(state.clone(), unordered_tx, metric.clone(), input) + let received = ws_recv_queue(state.clone(), ws_control_tx, metric.clone(), input) .collect::>() .await; @@ -2154,7 +2190,7 @@ mod tests { let state = Arc::new(dummy_actor_state()); let mut received = Vec::new(); let (messages_tx, messages_rx) = mpsc::channel(1); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); #[derive(From)] enum OutgoingBytes { @@ -2179,24 +2215,24 @@ mod tests { const NUM_CONTROL_FRAMES: usize = 2; let send_loop = tokio::spawn(async move { - ws_send_loop_inner(state, &mut received, messages_rx, unordered_rx, encoder).await; + ws_send_loop_inner(state, &mut received, messages_rx, ws_control_rx, encoder).await; received }); messages_tx.send(Bytes::from_static(&[1; MESSAGE_SIZE])).await.unwrap(); // Yield task to give the send loop a chance to receive the message. tokio::task::yield_now().await; // Send ping, then close. - unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new())).unwrap(); - unordered_tx - .send(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + ws_control_tx.send(WsControlMessage::Ping(Bytes::new())).unwrap(); + ws_control_tx + .send(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "we're done".into(), })) .unwrap(); // Shut down the loop. drop(messages_tx); - drop(unordered_tx); + drop(ws_control_tx); let received = send_loop.await.unwrap(); let ping_pos = received diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index 812d03c0701..b23b21c7c76 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -11,9 +11,9 @@ mod message_handlers_v3; pub mod messages; pub use client_connection::{ - ClientConfig, ClientConnection, ClientConnectionReceiver, ClientConnectionSender, ClientSendError, DataMessage, - MeteredDeque, MeteredReceiver, MeteredSender, MeteredUnboundedReceiver, MeteredUnboundedSender, Protocol, - WsVersion, + ClientConfig, ClientConnection, ClientConnectionReceiver, ClientConnectionSender, ClientDisconnectError, + ClientDisconnectSender, ClientSendError, DataMessage, MeteredDeque, MeteredReceiver, MeteredSender, + MeteredUnboundedReceiver, MeteredUnboundedSender, Protocol, WsVersion, }; pub use client_connection_index::ClientActorIndex; pub use message_handlers::MessageHandleError; diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 0a3fa198b01..de1166ad7f5 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -27,8 +27,7 @@ use spacetimedb_lib::identity::{AuthCtx, RequestId}; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Identity; use tokio::sync::mpsc::error::{SendError, TrySendError}; -use tokio::sync::{mpsc, oneshot, watch}; -use tokio::task::AbortHandle; +use tokio::sync::{mpsc, watch}; use tracing::{trace, warn}; #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] @@ -262,7 +261,15 @@ pub struct ClientConnectionSender { pub auth: ConnectionAuthCtx, pub config: ClientConfig, sendtx: mpsc::Sender, - abort_handle: AbortHandle, + /// Optional because dummy/test senders are not necessarily backed by a + /// live websocket actor control queue. + /// + /// Production websocket-backed senders receive this at spawn-time so they + /// can request a transport-level close through the websocket control path. + /// Test constructors such as [`ClientConnectionSender::dummy_with_channel`] + /// and [`ClientConnectionSender::dummy`] intentionally leave this as + /// `None` unless a test explicitly wires a control queue. + disconnect_tx: Option, cancelled: AtomicBool, /// Handles on Prometheus metrics related to connections to this database. @@ -309,6 +316,21 @@ impl ClientConnectionMetrics { } } +#[derive(Clone, Debug)] +pub struct ClientDisconnectSender(mpsc::UnboundedSender); + +impl ClientDisconnectSender { + pub fn new(inner: mpsc::UnboundedSender) -> Self { + Self(inner) + } + + pub fn send(&self, close_frame: crate::client::messages::CloseFrame) -> Result<(), ClientDisconnectError> { + self.0 + .send(close_frame) + .map_err(|_| ClientDisconnectError::Disconnected) + } +} + #[derive(Debug, thiserror::Error)] pub enum ClientSendError { #[error("client disconnected")] @@ -317,6 +339,16 @@ pub enum ClientSendError { Cancelled, } +#[derive(Debug, thiserror::Error)] +pub enum ClientDisconnectError { + #[error("client was already cancelled")] + Cancelled, + #[error("disconnect channel disconnected")] + Disconnected, + #[error("disconnect handle is not configured")] + NoDisconnectHandle, +} + impl ClientConnectionSender { pub fn dummy_with_channel( id: ClientActorId, @@ -324,11 +356,6 @@ impl ClientConnectionSender { offset_supply: impl DurableOffsetSupply + 'static, ) -> (Self, ClientConnectionReceiver) { let (sendtx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY_TEST); - // just make something up, it doesn't need to be attached to a real task - let abort_handle = match tokio::runtime::Handle::try_current() { - Ok(h) => h.spawn(async {}).abort_handle(), - Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(), - }; let receiver = ClientConnectionReceiver::new(config.confirmed_reads, MeteredReceiver::new(rx), offset_supply); let cancelled = AtomicBool::new(false); @@ -346,7 +373,7 @@ impl ClientConnectionSender { auth: ConnectionAuthCtx::try_from(dummy_claims).expect("dummy claims should always be valid"), config, sendtx, - abort_handle, + disconnect_tx: None, cancelled, metrics: None, }; @@ -357,10 +384,36 @@ impl ClientConnectionSender { Self::dummy_with_channel(id, config, offset_supply).0 } + #[cfg(test)] + pub(crate) fn dummy_with_disconnect_channel( + id: ClientActorId, + config: ClientConfig, + offset_supply: impl DurableOffsetSupply + 'static, + ) -> ( + Self, + ClientConnectionReceiver, + mpsc::UnboundedReceiver, + ) { + let (mut sender, receiver) = Self::dummy_with_channel(id, config, offset_supply); + let (disconnect_tx, disconnect_rx) = mpsc::unbounded_channel(); + sender.disconnect_tx = Some(ClientDisconnectSender::new(disconnect_tx)); + (sender, receiver, disconnect_rx) + } + pub fn is_cancelled(&self) -> bool { self.cancelled.load(Ordering::Relaxed) } + pub fn disconnect(&self, close_frame: crate::client::messages::CloseFrame) -> Result<(), ClientDisconnectError> { + if self.cancelled.load(Relaxed) { + return Err(ClientDisconnectError::Cancelled); + } + self.disconnect_tx + .as_ref() + .ok_or(ClientDisconnectError::NoDisconnectHandle)? + .send(close_frame) + } + /// Send a message to the client. For data-related messages, you should probably use /// `BroadcastQueue::send` to ensure that the client sees data messages in a consistent order. /// @@ -393,8 +446,13 @@ impl ClientConnectionSender { match self.sendtx.try_send(message) { Err(mpsc::error::TrySendError::Full(_)) => { - // we've hit CLIENT_CHANNEL_CAPACITY messages backed up in - // the channel, so forcibly kick the client + // We've hit `CLIENT_CHANNEL_CAPACITY` messages backed up in the channel, + // so forcibly kick the client. + // + // Mark the sender cancelled first so subsequent ordered sends + // fail fast immediately, then request websocket close using the + // same control-plane close path as any other server-initiated + // disconnect. tracing::warn!( identity = %self.id.identity, connection_id = %self.id.connection_id, @@ -406,8 +464,13 @@ impl ClientConnectionSender { self.id, self.sendtx.capacity(), ); - self.abort_handle.abort(); self.cancelled.store(true, Ordering::Relaxed); + if let Some(disconnect_tx) = &self.disconnect_tx { + let _ = disconnect_tx.send(crate::client::messages::CloseFrame { + code: crate::client::messages::CloseCode::Again, + reason: "client channel capacity exceeded".into(), + }); + } return Err(ClientSendError::Cancelled); } Err(mpsc::error::TrySendError::Closed(_)) => return Err(ClientSendError::Disconnected), @@ -803,6 +866,7 @@ impl ClientConnection { config: ClientConfig, replica_id: u64, mut module_rx: watch::Receiver, + disconnect_tx: Option, actor: impl FnOnce(ClientConnection, ClientConnectionReceiver) -> Fut, _proof_of_client_connected_call: Connected, ) -> ClientConnection @@ -817,25 +881,9 @@ impl ClientConnection { let (sendtx, sendrx) = mpsc::channel::(CLIENT_CHANNEL_CAPACITY); - let (fut_tx, fut_rx) = oneshot::channel::(); - // weird dance so that we can get an abort_handle into ClientConnection let module_info = module.info.clone(); let database_identity = module_info.database_identity; let client_identity = id.identity; - let abort_handle = tokio::spawn(async move { - let Ok(fut) = fut_rx.await else { return }; - - let _gauge_guard = module_info.metrics.connected_clients.inc_scope(); - module_info.metrics.ws_clients_spawned.inc(); - scopeguard::defer! { - let database_identity = module_info.database_identity; - log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`"); - module_info.metrics.ws_clients_aborted.inc(); - }; - - fut.await - }) - .abort_handle(); let metrics = ClientConnectionMetrics::new(database_identity, config.protocol); let receiver = ClientConnectionReceiver::new( @@ -849,7 +897,7 @@ impl ClientConnection { auth, config, sendtx, - abort_handle, + disconnect_tx, cancelled: AtomicBool::new(false), metrics: Some(metrics), }); @@ -861,8 +909,17 @@ impl ClientConnection { }; let actor_fut = actor(this.clone(), receiver); - // if this fails, the actor() function called .abort(), which like... okay, I guess? - let _ = fut_tx.send(actor_fut); + tokio::spawn(async move { + let _gauge_guard = module_info.metrics.connected_clients.inc_scope(); + module_info.metrics.ws_clients_spawned.inc(); + scopeguard::defer! { + let database_identity = module_info.database_identity; + log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`"); + module_info.metrics.ws_clients_aborted.inc(); + }; + + actor_fut.await + }); this } @@ -1433,4 +1490,41 @@ mod tests { offset.mark_durable_at(3); assert_received_update(receiver.recv()).await; } + + #[test] + fn disconnect_without_handle_returns_no_disconnect_handle() { + let sender = ClientConnectionSender::dummy( + ClientActorId::for_test(Identity::ZERO), + ClientConfig::for_test(), + NoneDurableOffset, + ); + let res = sender.disconnect(crate::client::messages::CloseFrame { + code: crate::client::messages::CloseCode::Away, + reason: "disconnect".into(), + }); + assert_matches!(res, Err(ClientDisconnectError::NoDisconnectHandle)); + } + + #[test] + fn send_overflow_marks_cancelled_and_emits_disconnect() { + let (sender, _receiver, mut disconnect_rx) = ClientConnectionSender::dummy_with_disconnect_channel( + ClientActorId::for_test(Identity::ZERO), + ClientConfig::for_test(), + NoneDurableOffset, + ); + + for _ in 0..CLIENT_CHANNEL_CAPACITY_TEST { + sender.send_message(None, empty_tx_update()).unwrap(); + } + + let res = sender.send_message(None, empty_tx_update()); + assert_matches!(res, Err(ClientSendError::Cancelled)); + assert!(sender.is_cancelled()); + + let close_frame = disconnect_rx.try_recv().expect("expected disconnect request"); + assert_eq!(close_frame.code, crate::client::messages::CloseCode::Again); + + let res = sender.send_message(None, empty_tx_update()); + assert_matches!(res, Err(ClientSendError::Cancelled)); + } } diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index fb85730c11c..83cf9ac8bd2 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -1,4 +1,4 @@ -use super::{ClientConnection, DataMessage, WsVersion}; +use super::{messages::CloseFrame, ClientConnection, DataMessage, WsVersion}; use crate::client::message_handlers_v1::MessageExecutionError; use spacetimedb_lib::bsatn; use std::time::Instant; @@ -15,6 +15,9 @@ pub enum MessageHandleError { #[error(transparent)] Execution(#[from] MessageExecutionError), + #[error("Client should be disconnected with close frame {0:?}")] + DisconnectClient(CloseFrame), + #[error("unsupported websocket version: {0}")] UnsupportedVersion(&'static str), } diff --git a/crates/core/src/client/message_handlers_v2.rs b/crates/core/src/client/message_handlers_v2.rs index d228fda9fcd..3d2fb0ac621 100644 --- a/crates/core/src/client/message_handlers_v2.rs +++ b/crates/core/src/client/message_handlers_v2.rs @@ -1,4 +1,4 @@ -use crate::client::MessageExecutionError; +use crate::client::{messages::CloseFrame, MessageExecutionError}; use super::{ClientConnection, DataMessage, MessageHandleError}; use serde::de::Error as _; @@ -52,22 +52,37 @@ pub(super) async fn handle_decoded_message( let res = client.enqueue_reducer_v2(reducer, args, request_id, timer, flags).await; match res { Ok(_) => { - // If this was not a success, we would have already sent an error message. + // If this was not a success, i.e. the reducer returned an error and rolled back, + // we have already sent an error message. Ok(()) } Err(e) => { let err_msg = format!("{e:#}"); - let server_message = ws_v2::ServerMessage::ReducerResult(ws_v2::ReducerResult { - request_id, - // Maybe we should use the same timestamp that was used for the reducer context, but this is probably fine for now. - timestamp: Timestamp::now(), - result: ws_v2::ReducerOutcome::InternalError(err_msg.into()), - }); - // TODO: Should we kill the client here, or does it mean the client is already dead. - if let Err(send_err) = client.send_message(None, server_message) { - log::warn!("Failed to send reducer error to client: {send_err}"); + + if let Some(code) = e.close_code() { + // pgoldman 2026-05-14: I've sorta bolted on this error path, + // which attempts to instruct clients on how to handle disconnects, + // as a way to bypass the existing error-handling code which goes through `MessageExecutionError`. + // Prior to my changes, an error in `CallReducer` never resulted in a `MessageExecutionError`, + // as all errors were treated like `ErrorClientConnectionBehavior::RespondError`. + return Err(MessageHandleError::DisconnectClient(CloseFrame { + code, + reason: err_msg.into(), + })); + } else { + let server_message = ws_v2::ServerMessage::ReducerResult(ws_v2::ReducerResult { + request_id, + // Maybe we should use the same timestamp that was used for the reducer context, but this is probably fine for now. + timestamp: Timestamp::now(), + result: ws_v2::ReducerOutcome::InternalError(err_msg.into()), + }); + + // TODO: Should we kill the client here, or does it mean the client is already dead. + if let Err(send_err) = client.send_message(None, server_message) { + log::warn!("Failed to send reducer error to client: {send_err}"); + } + Ok(()) } - Ok(()) } } } diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 798596b5bca..79ae6031ac4 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -17,6 +17,7 @@ use spacetimedb_lib::{AlgebraicValue, ConnectionId, TimeDuration, Timestamp}; use spacetimedb_primitives::TableId; use spacetimedb_sats::bsatn; use spacetimedb_schema::table_name::TableName; +use std::borrow::Cow; use std::sync::Arc; use std::time::Instant; @@ -823,3 +824,37 @@ impl ToProtocol for ProcedureResultMessage { } } } + +/// Reasons for a WebSocket connection to be closed. +/// +/// This is a subset of the WebSocket close codes defined in [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-7.4), +/// along with some extensions documented by [IANA](https://www.iana.org/assignments/websocket/websocket.xml#close-code-number). +/// We use the same names as [Tungstenite](https://docs.rs/tungstenite/latest/tungstenite/protocol/frame/coding/enum.CloseCode.html). +/// We don't use the actual Tungstenite `CloseCode` enum because the spacetimedb-core crate doesn't depend on Tungstenite. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum CloseCode { + /// Closing due to receiving an invalid message from the client. + /// + /// We send this when the client sends an invalid request, + /// e.g. a reducer or procedure call with ill-typed or unparseable arguments. + Invalid, + /// Closing, but the client should attempt to reconnect. + /// + /// We send this e.g. when a connection closes due to a leader failover. + Again, + /// Closing because the database to which the client was connected has gone away. + Away, + /// Protocol error. The catch-all. + Protocol, +} + +/// A WebSocket close frame. +/// +/// We'll convert this into a [`tungstenite::protocol::frame::CloseFrame`](https://docs.rs/tungstenite/latest/tungstenite/protocol/frame/struct.CloseFrame.html) +/// and send it to the client when issuing a server-initiated disconnection. +/// We don't directly use the Tungstenite type because the spacetimedb-core crate doesn't depend on Tungstenite. +#[derive(Clone, Debug)] +pub struct CloseFrame { + pub code: CloseCode, + pub reason: Cow<'static, str>, +} diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index daa506a8cf5..f8eb08a00c3 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -63,11 +63,24 @@ impl FunctionArgs { bsatn: OnceCell::new(), json: OnceCell::with_value(json), }, - FunctionArgs::Bsatn(bytes) => ArgsTuple { - tuple: seed.deserialize(bsatn::Deserializer::new(&mut &bytes[..]))?, - bsatn: OnceCell::with_value(bytes), - json: OnceCell::new(), - }, + FunctionArgs::Bsatn(bytes) => { + let mut reader = &bytes[..]; + let tuple = seed.deserialize(bsatn::Deserializer::new(&mut reader))?; + // Reject inputs which leave an unused suffix after parsing the arguments: + // these are most likely an erroneous input of an incorrect type, + // where by chance a prefix of that input successfully parses at the actual argument type. + // E.g. this will trigger when a function accepts an i32 argument, but a client provides an i64. + anyhow::ensure!( + reader.is_empty(), + "After reading function arguments, expected EOF but found {} bytes remaining", + reader.len() + ); + ArgsTuple { + tuple, + bsatn: OnceCell::with_value(bytes), + json: OnceCell::new(), + } + } FunctionArgs::Nullary => { anyhow::ensure!(seed.params().elements.is_empty(), "failed to typecheck args"); ArgsTuple::nullary() @@ -196,3 +209,48 @@ pub enum AbiCall { ProcedureAbortMutTransaction, ProcedureHttpRequest, } + +#[cfg(test)] +mod test { + use super::*; + use spacetimedb_lib::sats::{AlgebraicType, ProductType, WithTypespace}; + + struct TestFunctionDef { + params: ProductType, + name: Identifier, + } + + impl FunctionDef for TestFunctionDef { + fn params(&self) -> &ProductType { + &self.params + } + fn name(&self) -> &Identifier { + &self.name + } + } + + impl TestFunctionDef { + fn args_seed(&'_ self) -> ArgsSeed<'_, Self> { + ArgsSeed(WithTypespace::empty(self)) + } + } + + #[test] + fn reject_too_long_args_buffer() { + let i32_args_def = TestFunctionDef { + params: [AlgebraicType::I32].into(), + name: Identifier::for_test("i32_args"), + }; + + let args = bsatn::to_vec(&-1i64).unwrap(); + + // Sanity check: assert that the error below from `FunctionArgs::into_tuple` + // is specifically due to the extra machinery in `FunctionArgs::_into_tuple`, + // not because the prefix of `args` fails to parse at the type in `TestFunctionDef`. + assert_eq!(Ok(-1i32), bsatn::from_slice::(&args[..])); + + let args = FunctionArgs::Bsatn(args.into()); + // Assert that the `FunctionArgs` reader errors when passed a buffer that's too long. + assert!(args.into_tuple(i32_args_def.args_seed()).is_err()); + } +} diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 553c7ff685c..1639f258acd 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -2,7 +2,9 @@ use super::{ ArgsTuple, FunctionArgs, InvalidProcedureArguments, InvalidReducerArguments, ReducerCallResult, ReducerId, ReducerOutcome, Scheduler, }; -use crate::client::messages::{OneOffQueryResponseMessage, ProcedureResultMessage, SerializableMessage}; +use crate::client::messages::{ + CloseCode, CloseFrame, OneOffQueryResponseMessage, ProcedureResultMessage, SerializableMessage, +}; use crate::client::{ClientActorId, ClientConnectionSender, WsVersion}; use crate::database_logger::{DatabaseLogger, LogLevel, Record}; use crate::db::relational_db::{RelationalDB, Tx}; @@ -1435,6 +1437,27 @@ pub enum ReducerCallError { LifecycleReducer(Lifecycle), } +impl ReducerCallError { + /// When a reducer call by a client results in this error variant, + /// should the host disconnect that client, and what WebSocket close code should it use? + /// + /// Returns `None` if the connection should stay open. + pub fn close_code(&self) -> Option { + match self { + // These errors all result from outdated or incorrect client bindings. + // The developer of the client application needs to re-run `spacetime generate`. + Self::Args(_) | Self::NoSuchReducer | Self::LifecycleReducer(_) => Some(CloseCode::Invalid), + + // These errors all result most commonly from scheduling or replication changes. + // A reconnect will not necessarily result in a successful connection, but it might, + // and it will at least give a useful diagnostic about the new state of the database. + Self::NoSuchModule(_) | Self::WorkerError(_) => Some(CloseCode::Again), + + Self::ScheduleReducerNotFound => unreachable!("WebSocket client's don't directly invoke schedules"), + } + } +} + #[derive(Debug, PartialEq, Eq)] pub enum ViewOutcome { Success, @@ -1511,8 +1534,43 @@ pub enum ProcedureCallError { NoSuchProcedure, #[error("Procedure terminated due to insufficient budget")] OutOfEnergy, + #[error("Unable to deserialize the procedure's return value: {0}")] + InvalidReturnValue(bsatn::DecodeError), #[error("The module instance encountered a fatal error: {0}")] - InternalError(String), + GuestPanic(String), +} + +impl ProcedureCallError { + /// When a procedure call by a client results in this error variant, + /// should the host disconnect that client, and what WebSocket close code should it use? + /// + /// Returns `None` if the connection should stay open. + pub fn close_code(&self) -> Option { + match self { + // These errors all result from outdated or incorrect client bindings. + // The developer of the client application needs to re-run `spacetime generate`. + Self::Args(_) | Self::NoSuchProcedure => Some(CloseCode::Invalid), + + // This error results most commonly from scheduling or replication changes. + // A reconnect will not necessarily result in a successful connection, but it might, + // and it will at least give a useful diagnostic about the new state of the database. + Self::NoSuchModule(_) => Some(CloseCode::Again), + + // A guest panic is benign, and future calls by the client may succeed. + Self::GuestPanic(_) => None, + + // This is a weird error, and probably indicates that the database is broken somehow, + // but it's not a problem with the client, so I (pgoldman 2026-05-14) guess we'll just send an error message? + Self::InvalidReturnValue(_) => None, + + // TODO(procedure-energy): Re-evaluate the correct behavior here. + // This error may mean that the individual procedure call was terminated due to exceeding its budget, + // i.e. running too long and consuming too many CPU cycles, + // or it may mean that the database as a whole has been suspended. + // Because of the former case, future calls by the client may succeed, so don't disconnect. + Self::OutOfEnergy => None, + } + } } #[derive(thiserror::Error, Debug)] @@ -2616,12 +2674,25 @@ impl ModuleHost { self.subscriptions().send_procedure_message(sender, message, tx_offset) } WsVersion::V2 | WsVersion::V3 => { - let (status, timestamp, execution_duration) = match result { + let send_message = |sender, status, timestamp, execution_duration| { + let message = ws_v2::ProcedureResult { + status, + timestamp, + total_host_execution_duration: execution_duration, + request_id, + }; + + self.subscriptions() + .send_procedure_message_v2(sender, message, tx_offset) + }; + + match result { Ok(ProcedureCallResult { return_val, execution_duration, start_timestamp, - }) => ( + }) => send_message( + sender, ws_v2::ProcedureStatus::Returned( bsatn::to_vec(&return_val) .expect("Procedure return value failed to serialize to BSATN") @@ -2630,22 +2701,22 @@ impl ModuleHost { start_timestamp, TimeDuration::from(execution_duration), ), - Err(err) => ( - ws_v2::ProcedureStatus::InternalError(err.to_string().into()), - Timestamp::UNIX_EPOCH, - TimeDuration::ZERO, - ), - }; - - let message = ws_v2::ProcedureResult { - status, - timestamp, - total_host_execution_duration: execution_duration, - request_id, - }; - - self.subscriptions() - .send_procedure_message_v2(sender, message, tx_offset) + Err(err) => match err.close_code() { + None => send_message( + sender, + ws_v2::ProcedureStatus::InternalError(err.to_string().into()), + Timestamp::UNIX_EPOCH, + TimeDuration::ZERO, + ), + Some(code) => { + let _ = sender.disconnect(CloseFrame { + code, + reason: err.to_string().into(), + }); + Ok(()) + } + }, + } } } } diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index b269bbb92e4..18385cd1d5f 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -791,14 +791,14 @@ impl InstanceCommon { // return Err(ProcedureCallError::OutOfEnergy); // } else { - Err(ProcedureCallError::InternalError(format!("{err}"))) + Err(ProcedureCallError::GuestPanic(format!("{err}"))) } } Ok(return_val) => { let return_type = &procedure_def.return_type; let seed = spacetimedb_sats::WithTypespace::new(self.info.module_def.typespace(), return_type); seed.deserialize(bsatn::Deserializer::new(&mut &return_val[..])) - .map_err(|err| ProcedureCallError::InternalError(format!("{err}"))) + .map_err(ProcedureCallError::InvalidReturnValue) .map(|return_val| ProcedureCallResult { return_val, execution_duration: timer.map(|timer| timer.elapsed()).unwrap_or_default(),