From 0221739db45380b5d39a83e9248b0a36e62b6f70 Mon Sep 17 00:00:00 2001 From: Phoebe Goldman Date: Wed, 13 May 2026 11:15:53 -0400 Subject: [PATCH 1/3] Reject BSATN function args where the buffer is too long BSATN parsing generally accepts the case where there are unused trailing bytes in a buffer after parsing a type. This allows both building up compound-typed objects by repeatedly parsing their members, and packing multiple values into the same buffer sequentially. However, it has an unfortunate consequence when parsing untrusted inputs: if a client submits an input e.g. for a reducer call which is not of the expected type, but has a prefix that parses at the expected type, a direct use of the BSATN parser will accept it and silently ignore the suffix. One simple example is a client attempting to submit an i64 when SpacetimeDB expects an i32, resulting in the high 4 bytes of the client submission being ignored, potentially resulting in a different number being parsed than the one submitted. In this commit, we check when parsing user-submitted function arguments that not only did the parse succeed, but that it also consumed the entire input. If the entire input was not consumed, we treat it as an error in the class described above. 
--- crates/core/src/host/mod.rs | 68 ++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index daa506a8cf5..f8eb08a00c3 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -63,11 +63,24 @@ impl FunctionArgs { bsatn: OnceCell::new(), json: OnceCell::with_value(json), }, - FunctionArgs::Bsatn(bytes) => ArgsTuple { - tuple: seed.deserialize(bsatn::Deserializer::new(&mut &bytes[..]))?, - bsatn: OnceCell::with_value(bytes), - json: OnceCell::new(), - }, + FunctionArgs::Bsatn(bytes) => { + let mut reader = &bytes[..]; + let tuple = seed.deserialize(bsatn::Deserializer::new(&mut reader))?; + // Reject inputs which leave an unused suffix after parsing the arguments: + // these are most likely an erroneous input of an incorrect type, + // where by chance a prefix of that input successfully parses at the actual argument type. + // E.g. this will trigger when a function accepts an i32 argument, but a client provides an i64. 
+ anyhow::ensure!( + reader.is_empty(), + "After reading function arguments, expected EOF but found {} bytes remaining", + reader.len() + ); + ArgsTuple { + tuple, + bsatn: OnceCell::with_value(bytes), + json: OnceCell::new(), + } + } FunctionArgs::Nullary => { anyhow::ensure!(seed.params().elements.is_empty(), "failed to typecheck args"); ArgsTuple::nullary() @@ -196,3 +209,48 @@ pub enum AbiCall { ProcedureAbortMutTransaction, ProcedureHttpRequest, } + +#[cfg(test)] +mod test { + use super::*; + use spacetimedb_lib::sats::{AlgebraicType, ProductType, WithTypespace}; + + struct TestFunctionDef { + params: ProductType, + name: Identifier, + } + + impl FunctionDef for TestFunctionDef { + fn params(&self) -> &ProductType { + &self.params + } + fn name(&self) -> &Identifier { + &self.name + } + } + + impl TestFunctionDef { + fn args_seed(&'_ self) -> ArgsSeed<'_, Self> { + ArgsSeed(WithTypespace::empty(self)) + } + } + + #[test] + fn reject_too_long_args_buffer() { + let i32_args_def = TestFunctionDef { + params: [AlgebraicType::I32].into(), + name: Identifier::for_test("i32_args"), + }; + + let args = bsatn::to_vec(&-1i64).unwrap(); + + // Sanity check: assert that the error below from `FunctionArgs::into_tuple` + // is specifically due to the extra machinery in `FunctionArgs::_into_tuple`, + // not because the prefix of `args` fails to parse at the type in `TestFunctionDef`. + assert_eq!(Ok(-1i32), bsatn::from_slice::(&args[..])); + + let args = FunctionArgs::Bsatn(args.into()); + // Assert that the `FunctionArgs` reader errors when passed a buffer that's too long. 
+ assert!(args.into_tuple(i32_args_def.args_seed()).is_err()); + } +} From aafae305ab072db37962a8be6d1aa148d0065748 Mon Sep 17 00:00:00 2001 From: Phoebe Goldman Date: Fri, 15 May 2026 12:45:22 -0400 Subject: [PATCH 2/3] Partial change: disconnect clients when they send invalid calls When a client sends an invalid `CallReducer` or `CallProcedure` message, previously, we'd send them an error response but continue their connection. That was silly; there are significant classes of error which mean the connection is broken and should be killed. With this change, a client-supplied invalid `CallReducer` will result in a disconnect. A client-supplied invalid `CallProcedure` will panic due to hitting a `todo!()` which will be filled in in a subsequent commit. As part of this change, I defined `CloseFrame` in spacetimedb-core as a mirror to tungstenite's CloseFrame. This led to a minor audit of our close codes, and revealed one incorrect use: `CloseCode::Error` is defined by the spec as "internal error", but we were sending it in response to a client error. I have replaced it with `CloseCode::Protocol`, which is "protocol error". 
--- crates/client-api/src/routes/database.rs | 4 +- crates/client-api/src/routes/subscribe.rs | 64 +++++++--- crates/core/src/client/message_handlers.rs | 5 +- crates/core/src/client/message_handlers_v2.rs | 39 ++++-- crates/core/src/client/messages.rs | 35 ++++++ crates/core/src/host/module_host.rs | 111 ++++++++++++++---- .../src/host/wasm_common/module_host_actor.rs | 4 +- .../subscription/module_subscription_actor.rs | 10 +- 8 files changed, 216 insertions(+), 56 deletions(-) diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index e1464814656..d977e41247a 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -127,7 +127,9 @@ fn map_procedure_error(e: ProcedureCallError, procedure: &str) -> (StatusCode, S StatusCode::NOT_FOUND } ProcedureCallError::OutOfEnergy => StatusCode::PAYMENT_REQUIRED, - ProcedureCallError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, + ProcedureCallError::GuestPanic(_) | ProcedureCallError::InvalidReturnValue(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } }; log::error!("Error while invoking procedure {e:#}"); (status_code, format!("{:#}", anyhow::anyhow!(e))) diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index e2131abfa31..c3b84d137db 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt::Display; use std::future::Future; use std::num::NonZeroUsize; @@ -23,8 +24,8 @@ use prometheus::{Histogram, IntGauge}; use scopeguard::{defer, ScopeGuard}; use serde::Deserialize; use spacetimedb::client::messages::{ - serialize, serialize_v2, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, - ToProtocol, + serialize, serialize_v2, CloseCode as StCloseCode, CloseFrame as StCloseFrame, IdentityTokenMessage, + InUseSerializeBuffer, SerializeBuffer, 
SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, @@ -52,7 +53,8 @@ use tokio_tungstenite::tungstenite::Utf8Bytes; use crate::auth::SpacetimeAuth; use crate::util::serde::humantime_duration; use crate::util::websocket::{ - CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError, + CloseCode as WsCloseCode, CloseFrame as WsCloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, + WebSocketUpgrade, WsError, }; use crate::util::{NameOrIdentity, XForwardedFor}; use crate::{log_and_500, Authorization, ControlStateDelegate, NodeDelegate}; @@ -678,8 +680,8 @@ async fn ws_main_loop( // Branch is disabled if we already sent a close frame. res = &mut watch_hotswap, if !closed => { if let Err(NoSuchModule) = res { - let close = CloseFrame { - code: CloseCode::Away, + let close = StCloseFrame { + code: StCloseCode::Away, reason: "module exited".into() }; unordered_tx(close.into()); @@ -791,10 +793,8 @@ async fn ws_recv_task( continue; } log::debug!("Client caused error: {e}"); - let close = CloseFrame { - code: CloseCode::Error, - reason: format!("{e:#}").into(), - }; + let close = close_frame_for_error(e); + // If the send task has exited, also exit this recv task. // No need to send the close handshake in that case; the client is already gone. if unordered_tx.send(close.into()).is_err() { @@ -804,6 +804,17 @@ async fn ws_recv_task( } } +fn close_frame_for_error(e: MessageHandleError) -> StCloseFrame { + if let MessageHandleError::DisconnectClient(frame) = e { + frame + } else { + StCloseFrame { + code: StCloseCode::Protocol, + reason: format!("{e:#}").into(), + } + } +} + /// Stream that consumes a stream of [`WsMessage`]s and yields [`ClientMessage`]s. 
/// /// Terminates if: @@ -918,9 +929,9 @@ fn ws_recv_queue( recv_queue_gauge: IntGauge, mut ws: impl Stream> + Unpin + Send + 'static, ) -> impl Stream> { - const CLOSE: UnorderedWsMessage = UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Again, - reason: Utf8Bytes::from_static("too many requests"), + const CLOSE: UnorderedWsMessage = UnorderedWsMessage::Close(StCloseFrame { + code: StCloseCode::Again, + reason: Cow::Borrowed("too many requests"), }); let on_message_after_close = move |client_id| { log::warn!("client {client_id} sent message after close or error"); @@ -1051,7 +1062,7 @@ fn ws_client_message_handler( #[derive(Debug, From)] enum UnorderedWsMessage { /// Server-initiated close. - Close(CloseFrame), + Close(spacetimedb::client::messages::CloseFrame), /// Server-initiated ping. Ping(Bytes), /// Error calling a reducer. @@ -1202,7 +1213,7 @@ async fn ws_send_loop_inner( } // Then send the close frame. log::trace!("sending close frame"); - if let Err(e) = ws.send(WsMessage::Close(Some(close_frame))).await { + if let Err(e) = ws.send(WsMessage::Close(Some(convert_close_frame(close_frame)))).await { log::warn!("error sending close frame: {e:#}"); break; } @@ -1279,6 +1290,21 @@ async fn ws_send_loop_inner( } } +fn convert_close_frame(frame: StCloseFrame) -> WsCloseFrame { + WsCloseFrame { + code: match frame.code { + StCloseCode::Again => WsCloseCode::Again, + StCloseCode::Invalid => WsCloseCode::Invalid, + StCloseCode::Away => WsCloseCode::Away, + StCloseCode::Protocol => WsCloseCode::Protocol, + }, + reason: match frame.reason { + Cow::Borrowed(reason) => Utf8Bytes::from_static(reason), + Cow::Owned(reason) => reason.into(), + }, + } +} + #[derive(From)] enum OutboundWsMessage { Error(MessageExecutionError), @@ -1523,7 +1549,7 @@ enum ClientMessage { Message(DataMessage), Ping(Bytes), Pong(Bytes), - Close(Option), + Close(Option), } impl ClientMessage { @@ -1827,7 +1853,7 @@ mod tests { unordered_tx .send(UnorderedWsMessage::Close(CloseFrame { 
- code: CloseCode::Away, + code: WsCloseCode::Away, reason: "done".into(), })) .unwrap(); @@ -1841,7 +1867,7 @@ mod tests { async fn send_loop_terminates_if_sink_cant_be_fed() { let input = [ Either::Left(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + code: WsCloseCode::Away, reason: "bah!".into(), })), Either::Left(UnorderedWsMessage::Ping(Bytes::new())), @@ -1890,7 +1916,7 @@ mod tests { async fn send_loop_terminates_if_sink_cant_be_flushed() { let input = [ Either::Left(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + code: WsCloseCode::Away, reason: "bah!".into(), })), Either::Left(UnorderedWsMessage::Ping(Bytes::new())), @@ -2189,7 +2215,7 @@ mod tests { unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new())).unwrap(); unordered_tx .send(UnorderedWsMessage::Close(CloseFrame { - code: CloseCode::Away, + code: WsCloseCode::Away, reason: "we're done".into(), })) .unwrap(); diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index fb85730c11c..83cf9ac8bd2 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -1,4 +1,4 @@ -use super::{ClientConnection, DataMessage, WsVersion}; +use super::{messages::CloseFrame, ClientConnection, DataMessage, WsVersion}; use crate::client::message_handlers_v1::MessageExecutionError; use spacetimedb_lib::bsatn; use std::time::Instant; @@ -15,6 +15,9 @@ pub enum MessageHandleError { #[error(transparent)] Execution(#[from] MessageExecutionError), + #[error("Client should be disconnected with close frame {0:?}")] + DisconnectClient(CloseFrame), + #[error("unsupported websocket version: {0}")] UnsupportedVersion(&'static str), } diff --git a/crates/core/src/client/message_handlers_v2.rs b/crates/core/src/client/message_handlers_v2.rs index d228fda9fcd..3d2fb0ac621 100644 --- a/crates/core/src/client/message_handlers_v2.rs +++ b/crates/core/src/client/message_handlers_v2.rs @@ -1,4 +1,4 @@ -use 
crate::client::MessageExecutionError; +use crate::client::{messages::CloseFrame, MessageExecutionError}; use super::{ClientConnection, DataMessage, MessageHandleError}; use serde::de::Error as _; @@ -52,22 +52,37 @@ pub(super) async fn handle_decoded_message( let res = client.enqueue_reducer_v2(reducer, args, request_id, timer, flags).await; match res { Ok(_) => { - // If this was not a success, we would have already sent an error message. + // If this was not a success, i.e. the reducer returned an error and rolled back, + // we have already sent an error message. Ok(()) } Err(e) => { let err_msg = format!("{e:#}"); - let server_message = ws_v2::ServerMessage::ReducerResult(ws_v2::ReducerResult { - request_id, - // Maybe we should use the same timestamp that was used for the reducer context, but this is probably fine for now. - timestamp: Timestamp::now(), - result: ws_v2::ReducerOutcome::InternalError(err_msg.into()), - }); - // TODO: Should we kill the client here, or does it mean the client is already dead. - if let Err(send_err) = client.send_message(None, server_message) { - log::warn!("Failed to send reducer error to client: {send_err}"); + + if let Some(code) = e.close_code() { + // pgoldman 2026-05-14: I've sorta bolted on this error path, + // which attempts to instruct clients on how to handle disconnects, + // as a way to bypass the existing error-handling code which goes through `MessageExecutionError`. + // Prior to my changes, an error in `CallReducer` never resulted in a `MessageExecutionError`, + // as all errors were treated like `ErrorClientConnectionBehavior::RespondError`. + return Err(MessageHandleError::DisconnectClient(CloseFrame { + code, + reason: err_msg.into(), + })); + } else { + let server_message = ws_v2::ServerMessage::ReducerResult(ws_v2::ReducerResult { + request_id, + // Maybe we should use the same timestamp that was used for the reducer context, but this is probably fine for now. 
+ timestamp: Timestamp::now(), + result: ws_v2::ReducerOutcome::InternalError(err_msg.into()), + }); + + // TODO: Should we kill the client here, or does it mean the client is already dead. + if let Err(send_err) = client.send_message(None, server_message) { + log::warn!("Failed to send reducer error to client: {send_err}"); + } + Ok(()) } - Ok(()) } } } diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 798596b5bca..79ae6031ac4 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -17,6 +17,7 @@ use spacetimedb_lib::{AlgebraicValue, ConnectionId, TimeDuration, Timestamp}; use spacetimedb_primitives::TableId; use spacetimedb_sats::bsatn; use spacetimedb_schema::table_name::TableName; +use std::borrow::Cow; use std::sync::Arc; use std::time::Instant; @@ -823,3 +824,37 @@ impl ToProtocol for ProcedureResultMessage { } } } + +/// Reasons for a WebSocket connection to be closed. +/// +/// This is a subset of the WebSocket close codes defined in [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-7.4), +/// along with some extensions documented by [IANA](https://www.iana.org/assignments/websocket/websocket.xml#close-code-number). +/// We use the same names as [Tungstenite](https://docs.rs/tungstenite/latest/tungstenite/protocol/frame/coding/enum.CloseCode.html). +/// We don't use the actual Tungstenite `CloseCode` enum because the spacetimedb-core crate doesn't depend on Tungstenite. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum CloseCode { + /// Closing due to receiving an invalid message from the client. + /// + /// We send this when the client sends an invalid request, + /// e.g. a reducer or procedure call with ill-typed or unparseable arguments. + Invalid, + /// Closing, but the client should attempt to reconnect. + /// + /// We send this e.g. when a connection closes due to a leader failover. 
+ Again, + /// Closing because the database to which the client was connected has gone away. + Away, + /// Protocol error. The catch-all. + Protocol, +} + +/// A WebSocket close frame. +/// +/// We'll convert this into a [`tungstenite::protocol::frame::CloseFrame`](https://docs.rs/tungstenite/latest/tungstenite/protocol/frame/struct.CloseFrame.html) +/// and send it to the client when issuing a server-initiated disconnection. +/// We don't directly use the Tungstenite type because the spacetimedb-core crate doesn't depend on Tungstenite. +#[derive(Clone, Debug)] +pub struct CloseFrame { + pub code: CloseCode, + pub reason: Cow<'static, str>, +} diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 553c7ff685c..dc9d2e232c4 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -2,7 +2,9 @@ use super::{ ArgsTuple, FunctionArgs, InvalidProcedureArguments, InvalidReducerArguments, ReducerCallResult, ReducerId, ReducerOutcome, Scheduler, }; -use crate::client::messages::{OneOffQueryResponseMessage, ProcedureResultMessage, SerializableMessage}; +use crate::client::messages::{ + CloseCode, CloseFrame, OneOffQueryResponseMessage, ProcedureResultMessage, SerializableMessage, +}; use crate::client::{ClientActorId, ClientConnectionSender, WsVersion}; use crate::database_logger::{DatabaseLogger, LogLevel, Record}; use crate::db::relational_db::{RelationalDB, Tx}; @@ -1435,6 +1437,27 @@ pub enum ReducerCallError { LifecycleReducer(Lifecycle), } +impl ReducerCallError { + /// When a reducer call by a client results in this error variant, + /// should the host disconnect that client, and what WebSocket close code should it use? + /// + /// Returns `None` if the connection should stay open. + pub fn close_code(&self) -> Option { + match self { + // These errors all result from outdated or incorrect client bindings. + // The developer of the client application needs to re-run `spacetime generate`. 
+ Self::Args(_) | Self::NoSuchReducer | Self::LifecycleReducer(_) => Some(CloseCode::Invalid), + + // These errors all result most commonly from scheduling or replication changes. + // A reconnect will not necessarily result in a successful connection, but it might, + // and it will at least give a useful diagnostic about the new state of the database. + Self::NoSuchModule(_) | Self::WorkerError(_) => Some(CloseCode::Again), + + Self::ScheduleReducerNotFound => unreachable!("WebSocket client's don't directly invoke schedules"), + } + } +} + #[derive(Debug, PartialEq, Eq)] pub enum ViewOutcome { Success, @@ -1511,8 +1534,43 @@ pub enum ProcedureCallError { NoSuchProcedure, #[error("Procedure terminated due to insufficient budget")] OutOfEnergy, + #[error("Unable to deserialize the procedure's return value: {0}")] + InvalidReturnValue(bsatn::DecodeError), #[error("The module instance encountered a fatal error: {0}")] - InternalError(String), + GuestPanic(String), +} + +impl ProcedureCallError { + /// When a procedure call by a client results in this error variant, + /// should the host disconnect that client, and what WebSocket close code should it use? + /// + /// Returns `None` if the connection should stay open. + pub fn close_code(&self) -> Option { + match self { + // These errors all result from outdated or incorrect client bindings. + // The developer of the client application needs to re-run `spacetime generate`. + Self::Args(_) | Self::NoSuchProcedure => Some(CloseCode::Invalid), + + // This error results most commonly from scheduling or replication changes. + // A reconnect will not necessarily result in a successful connection, but it might, + // and it will at least give a useful diagnostic about the new state of the database. + Self::NoSuchModule(_) => Some(CloseCode::Again), + + // A guest panic is benign, and future calls by the client may succeed. 
+ Self::GuestPanic(_) => None, + + // This is a weird error, and probably indicates that the database is broken somehow, + // but it's not a problem with the client, so I (pgoldman 2026-05-14) guess we'll just send an error message? + Self::InvalidReturnValue(_) => None, + + // TODO(procedure-energy): Re-evaluate the correct behavior here. + // This error may mean that the individual procedure call was terminated due to exceeding its budget, + // i.e. running too long and consuming too many CPU cycles, + // or it may mean that the database as a whole has been suspended. + // Because of the former case, future calls by the client may succeed, so don't disconnect. + Self::OutOfEnergy => None, + } + } } #[derive(thiserror::Error, Debug)] @@ -2616,12 +2674,25 @@ impl ModuleHost { self.subscriptions().send_procedure_message(sender, message, tx_offset) } WsVersion::V2 | WsVersion::V3 => { - let (status, timestamp, execution_duration) = match result { + let send_message = |sender, status, timestamp, execution_duration| { + let message = ws_v2::ProcedureResult { + status, + timestamp, + total_host_execution_duration: execution_duration, + request_id, + }; + + self.subscriptions() + .send_procedure_message_v2(sender, message, tx_offset) + }; + + match result { Ok(ProcedureCallResult { return_val, execution_duration, start_timestamp, - }) => ( + }) => send_message( + sender, ws_v2::ProcedureStatus::Returned( bsatn::to_vec(&return_val) .expect("Procedure return value failed to serialize to BSATN") @@ -2630,22 +2701,22 @@ impl ModuleHost { start_timestamp, TimeDuration::from(execution_duration), ), - Err(err) => ( - ws_v2::ProcedureStatus::InternalError(err.to_string().into()), - Timestamp::UNIX_EPOCH, - TimeDuration::ZERO, - ), - }; - - let message = ws_v2::ProcedureResult { - status, - timestamp, - total_host_execution_duration: execution_duration, - request_id, - }; - - self.subscriptions() - .send_procedure_message_v2(sender, message, tx_offset) + Err(err) => match 
err.close_code() { + None => send_message( + sender, + ws_v2::ProcedureStatus::InternalError(err.to_string().into()), + Timestamp::UNIX_EPOCH, + TimeDuration::ZERO, + ), + Some(code) => self.subscriptions().disconnect_client( + sender, + CloseFrame { + code, + reason: err.to_string().into(), + }, + ), + }, + } } } } diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index b269bbb92e4..18385cd1d5f 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -791,14 +791,14 @@ impl InstanceCommon { // return Err(ProcedureCallError::OutOfEnergy); // } else { - Err(ProcedureCallError::InternalError(format!("{err}"))) + Err(ProcedureCallError::GuestPanic(format!("{err}"))) } } Ok(return_val) => { let return_type = &procedure_def.return_type; let seed = spacetimedb_sats::WithTypespace::new(self.info.module_def.typespace(), return_type); seed.deserialize(bsatn::Deserializer::new(&mut &return_val[..])) - .map_err(|err| ProcedureCallError::InternalError(format!("{err}"))) + .map_err(ProcedureCallError::InvalidReturnValue) .map(|return_val| ProcedureCallResult { return_val, execution_duration: timer.map(|timer| timer.elapsed()).unwrap_or_default(), diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index d748fdd09ab..381fc3f965b 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -8,7 +8,7 @@ use super::query::compile_query_with_hashes; use super::tx::DeltaTx; use super::{collect_table_update, TableUpdateType}; use crate::client::messages::{ - ProcedureResultMessage, SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, + CloseFrame, ProcedureResultMessage, SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, 
SubscriptionResult, SubscriptionRows, SubscriptionUpdateMessage, TransactionUpdateMessage, }; use crate::client::{ClientActorId, ClientConnectionSender, Protocol, WsVersion}; @@ -1127,6 +1127,14 @@ impl ModuleSubscriptions { .send_client_message_v2(recipient, tx_offset, message) } + pub fn disconnect_client( + &self, + recipient: Arc, + close_frame: CloseFrame, + ) -> Result<(), BroadcastError> { + todo!() + } + pub fn send_one_off_query_message_v2( &self, recipient: Arc, From a687af3bf73f2747e207a784a0a2313206c56c38 Mon Sep 17 00:00:00 2001 From: Phoebe Goldman Date: Fri, 15 May 2026 14:03:17 -0400 Subject: [PATCH 3/3] Add a disconnect channel to `ClientConnectionSender` This commit adds and implements `ClientConnectionSender::disconnect`, which does what you expect. This required adding a new member to `ClientConnectionSender`, the `disconnect_tx`, along which disconnect messages are sent. Adding the `disconnect_tx` also made it convenient to tidy the `UnorderedWsMessage` channel and rename that concept to `WsControlMessage`. 
--- crates/client-api/src/routes/subscribe.rs | 174 +++++++++--------- crates/core/src/client.rs | 6 +- crates/core/src/client/client_connection.rs | 156 ++++++++++++---- crates/core/src/host/module_host.rs | 10 +- .../subscription/module_subscription_actor.rs | 10 +- 5 files changed, 226 insertions(+), 130 deletions(-) diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index c3b84d137db..ce7eb107638 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -28,8 +28,8 @@ use spacetimedb::client::messages::{ InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ - ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, - MessageHandleError, MeteredReceiver, MeteredSender, OutboundMessage, Protocol, WsVersion, + ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, ClientDisconnectSender, DataMessage, + MessageExecutionError, MessageHandleError, MeteredReceiver, MeteredSender, OutboundMessage, Protocol, WsVersion, }; use spacetimedb::host::module_host::ClientConnectedError; use spacetimedb::host::NoSuchModule; @@ -269,7 +269,8 @@ where "websocket: Database accepted connection from {client_log_string}; spawning ws_client_actor and ClientConnection" ); - let actor = |client, receiver| ws_client_actor(ws_opts, client, ws, receiver); + let (disconnect_tx, disconnect_rx) = mpsc::unbounded_channel(); + let actor = |client, receiver| ws_client_actor(ws_opts, client, ws, receiver, disconnect_rx); let client = ClientConnection::spawn( client_id, auth.into(), @@ -277,6 +278,7 @@ where client_config, leader.replica_id, module_rx, + Some(ClientDisconnectSender::new(disconnect_tx)), actor, connected, ) @@ -434,13 +436,14 @@ async fn ws_client_actor( client: ClientConnection, ws: WebSocketStream, sendrx: ClientConnectionReceiver, + disconnect_rx: 
mpsc::UnboundedReceiver, ) { // ensure that even if this task gets cancelled, we always cleanup the connection let mut client = scopeguard::guard(client, |client| { tokio::spawn(client.disconnect()); }); - ws_client_actor_inner(&mut client, options, ws, sendrx).await; + ws_client_actor_inner(&mut client, options, ws, sendrx, disconnect_rx).await; ScopeGuard::into_inner(client).disconnect().await; } @@ -450,14 +453,21 @@ async fn ws_client_actor_inner( config: WebSocketOptions, ws: WebSocketStream, sendrx: ClientConnectionReceiver, + mut disconnect_rx: mpsc::UnboundedReceiver, ) { let database = client.module().info().database_identity; let client_id = client.id; let client_closed_metric = WORKER_METRICS.ws_clients_closed_connection.with_label_values(&database); let state = Arc::new(ActorState::new(database, client_id, config)); - // Channel for [`UnorderedWsMessage`]s. - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + // Channel for websocket control traffic (`Close`, `Ping`, unordered execution errors). + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); + let ws_control_tx_for_disconnect = ws_control_tx.clone(); + tokio::spawn(async move { + while let Some(close_frame) = disconnect_rx.recv().await { + let _ = ws_control_tx_for_disconnect.send(WsControlMessage::Close(close_frame)); + } + }); // Split websocket into send and receive halves. let (ws_send, ws_recv) = ws.split(); @@ -469,13 +479,13 @@ async fn ws_client_actor_inner( let bsatn_rlb_pool = client.module().subscriptions().bsatn_rlb_pool.clone(); // Spawn a task to send outgoing messages - // obtained from `sendrx` and `unordered_rx`. + // obtained from `sendrx` and `ws_control_rx`. let send_task = tokio::spawn(ws_send_loop( state.clone(), client.config, ws_send, sendrx, - unordered_rx, + ws_control_rx, bsatn_rlb_pool, )); // Spawn a task to handle incoming messages. 
@@ -490,7 +500,7 @@ async fn ws_client_actor_inner( async move { client.handle_message(data, timer.into()).await } } }, - unordered_tx.clone(), + ws_control_tx.clone(), ws_recv, client.config.version, )); @@ -503,7 +513,7 @@ async fn ws_client_actor_inner( }; ws_main_loop(state, hotswap, idle_timer, send_task, recv_task, move |msg| { - let _ = unordered_tx.send(msg); + let _ = ws_control_tx.send(msg); }) .await; log::info!("Client connection ended: {client_id}"); @@ -580,13 +590,13 @@ async fn ws_client_actor_inner( /// The idle timer should be reset whenever data is received from the websocket. /// /// * **send_task**: -/// Task handling outgoing messages. Holds the receive end of `unordered_tx`. +/// Task handling outgoing messages. Holds the receive end of `ws_control_tx`. /// /// If the task returns, the connection is considered bad, and the main loop /// exits. If the task panicked, the panic is resumed on the current thread. /// /// Note that the send task must not terminate after it has sent a `Close` -/// frame (via `unordered_tx`) -- the websocket protocol mandates that the +/// frame (via `ws_control_tx`) -- the websocket protocol mandates that the /// initiator of the close handshake wait for the other end to respond with /// a `Close` frame. Thus, the loop must continue to poll `recv_task` and not /// exit due to `send_task` being complete. @@ -602,7 +612,7 @@ async fn ws_client_actor_inner( /// /// See [`ws_recv_task`]. /// -/// * **unordered_tx**: +/// * **ws_control_tx**: /// Channel connected to `send_task` that allows the loop to send `Ping` and /// `Close` frames. 
/// @@ -618,7 +628,7 @@ async fn ws_main_loop( idle_timer: impl Future, mut send_task: JoinHandle<()>, mut recv_task: JoinHandle<()>, - unordered_tx: impl Fn(UnorderedWsMessage), + ws_control_tx: impl Fn(WsControlMessage), ) where HotswapWatcher: Future>, { @@ -684,7 +694,7 @@ async fn ws_main_loop( code: StCloseCode::Away, reason: "module exited".into() }; - unordered_tx(close.into()); + ws_control_tx(close.into()); } watch_hotswap.set(hotswap()); }, @@ -703,7 +713,7 @@ async fn ws_main_loop( _ = ping_interval.tick(), if !closed => { let was_ponged = state.reset_ponged(); if was_ponged { - unordered_tx(UnorderedWsMessage::Ping(Bytes::new())); + ws_control_tx(WsControlMessage::Ping(Bytes::new())); } } } @@ -747,7 +757,7 @@ async fn ws_idle_timer(mut activity: watch::Receiver) { /// `idle_tx` is the sending end of a [`ws_idle_timer`]. The [`ws_recv_loop`] /// sends a new, extended deadline whenever it receives a message. /// -/// `unordered_tx` is used to send message execution errors +/// `ws_control_tx` is used to send message execution errors /// or to initiate a close handshake. /// /// Initiates a close handshake if the `message_handler` returns any variant @@ -756,7 +766,7 @@ async fn ws_idle_timer(mut activity: watch::Receiver) { /// Terminates if: /// /// - the `ws` stream is exhausted -/// - or, `unordered_tx` is already closed +/// - or, `ws_control_tx` is already closed /// /// In the latter case, we assume that the connection is in an errored state, /// such that we wouldn't be able to receive any more messages anyway. 
@@ -765,7 +775,7 @@ async fn ws_recv_task( idle_tx: watch::Sender, client_closed_metric: IntGauge, message_handler: impl Fn(DataMessage, Instant) -> MessageHandler, - unordered_tx: mpsc::UnboundedSender, + ws_control_tx: mpsc::UnboundedSender, ws: impl Stream> + Unpin + Send + 'static, ws_version: WsVersion, ) where @@ -774,7 +784,7 @@ async fn ws_recv_task( let recv_queue_gauge = WORKER_METRICS .total_incoming_queue_length .with_label_values(&state.database); - let recv_queue = ws_recv_queue(state.clone(), unordered_tx.clone(), recv_queue_gauge, ws); + let recv_queue = ws_recv_queue(state.clone(), ws_control_tx.clone(), recv_queue_gauge, ws); let recv_loop = pin!(ws_recv_loop(state.clone(), idle_tx, recv_queue)); let recv_handler = ws_client_message_handler(state.clone(), client_closed_metric, recv_loop); pin_mut!(recv_handler); @@ -787,7 +797,7 @@ async fn ws_recv_task( { log::error!("{err:#}"); // If the send task has exited, also exit this recv task. - if unordered_tx.send(err.into()).is_err() { + if ws_control_tx.send(err.into()).is_err() { break; } continue; @@ -797,7 +807,7 @@ async fn ws_recv_task( // If the send task has exited, also exit this recv task. // No need to send the close handshake in that case; the client is already gone. - if unordered_tx.send(close.into()).is_err() { + if ws_control_tx.send(close.into()).is_err() { break; }; } @@ -914,7 +924,7 @@ fn ws_recv_loop( /// /// The channel is initialized with [`ActorConfig::incoming_queue_length`]. /// If it is at capacity, a connection shutdown is initiated by sending -/// [`UnorderedWsMessage::Close`] via `unordered_tx`. +/// [`WsControlMessage::Close`] via `ws_control_tx`. /// /// Returns the channel receiver. 
/// @@ -925,11 +935,11 @@ fn ws_recv_loop( /// [#1851]: https://github.com/clockworklabs/SpacetimeDBPrivate/issues/1851 fn ws_recv_queue( state: Arc, - unordered_tx: mpsc::UnboundedSender, + ws_control_tx: mpsc::UnboundedSender, recv_queue_gauge: IntGauge, mut ws: impl Stream> + Unpin + Send + 'static, ) -> impl Stream> { - const CLOSE: UnorderedWsMessage = UnorderedWsMessage::Close(StCloseFrame { + const CLOSE: WsControlMessage = WsControlMessage::Close(StCloseFrame { code: StCloseCode::Again, reason: Cow::Borrowed("too many requests"), }); @@ -961,7 +971,7 @@ fn ws_recv_queue( // - Then exit the loop, as we won't be processing any // more messages, and we don't expect a close response // to arrive from the client. - if unordered_tx.send(CLOSE).is_err() { + if ws_control_tx.send(CLOSE).is_err() { state.close(); break; } @@ -1060,7 +1070,7 @@ fn ws_client_message_handler( /// Outgoing messages that don't need to be ordered wrt subscription updates. #[derive(Debug, From)] -enum UnorderedWsMessage { +enum WsControlMessage { /// Server-initiated close. Close(spacetimedb::client::messages::CloseFrame), /// Server-initiated ping. @@ -1104,32 +1114,32 @@ impl Receiver for mpsc::Receiver { /// Consumes `messages`, which yields subscription updates and reducer call /// results. Note that [`SerializableMessage`]s require serialization and /// potentially compression, which can be costly. -/// Also consumes `unordered`, which yields [`UnorderedWsMessage`]s. +/// Also consumes `ws_control`, which yields [`WsControlMessage`]s. /// /// Terminates if: /// -/// - `unordered` is closed +/// - `ws_control` is closed /// - an error occurs sending to the `ws` sink /// -/// If an [`UnorderedWsMessage::Close`] is encountered, a close frame is sent +/// If an [`WsControlMessage::Close`] is encountered, a close frame is sent /// to the `ws` sink, and `state.close()` is called. 
When this happens, /// `messages` will no longer be polled (no data can be sent after a close /// frame anyways), so `messages.close()` will be called. /// -/// Keeps polling `unordered` if `state.closed()`, but discards all data. +/// Keeps polling `ws_control` if `state.closed()`, but discards all data. /// This is so `ws_client_actor_inner` keeps polling the receive end of the /// socket until the close handshake completes -- it would otherwise exit early -/// when sending to `unordered` fails. +/// when sending to `ws_control` fails. async fn ws_send_loop( state: Arc, config: ClientConfig, ws: impl Sink + Unpin, messages: impl Receiver, - unordered: mpsc::UnboundedReceiver, + ws_control: mpsc::UnboundedReceiver, bsatn_rlb_pool: BsatnRowListBuilderPool, ) { let metrics = SendMetrics::new(state.database); - ws_send_loop_inner(state, ws, messages, unordered, move |encode_rx, frames_tx| { + ws_send_loop_inner(state, ws, messages, ws_control, move |encode_rx, frames_tx| { ws_encode_task(metrics, config, encode_rx, frames_tx, bsatn_rlb_pool) }) .await @@ -1139,7 +1149,7 @@ async fn ws_send_loop_inner( state: Arc, mut ws: impl Sink + Unpin, mut messages: impl Receiver, - mut unordered: mpsc::UnboundedReceiver, + mut ws_control: mpsc::UnboundedReceiver, encoder: impl FnOnce(mpsc::UnboundedReceiver, mpsc::UnboundedSender) -> Encoder, ) where T: Into, @@ -1149,7 +1159,7 @@ async fn ws_send_loop_inner( // The number of frames we'll `feed` to the `ws` sink in one iteration // of the `select!` loop. // - // This batching is done to allow control messages appearing on `unordered` + // This batching is done to allow control messages appearing on `ws_control` // to be interleaved with the sending of large messages split across some // number of frames. // @@ -1182,19 +1192,19 @@ async fn ws_send_loop_inner( biased; // Check for control messages or execution errors. 
- maybe_msg = unordered.recv() => { + maybe_msg = ws_control.recv() => { let Some(msg) = maybe_msg else { break; }; // We shall not send more data after a close frame, - // but keep polling `unordered` so that `ws_client_actor` keeps + // but keep polling `ws_control` so that `ws_client_actor` keeps // waiting for an acknowledgement from the client, // even if it spuriously initiates another close itself. if closed { continue; } match msg { - UnorderedWsMessage::Close(close_frame) => { + WsControlMessage::Close(close_frame) => { log::trace!("intiating close"); // Send outstanding frames until one that has the FIN // bit set. Ensures the client won't receive partial @@ -1232,14 +1242,14 @@ async fn ws_send_loop_inner( // so let senders know. messages.close(); }, - UnorderedWsMessage::Ping(bytes) => { + WsControlMessage::Ping(bytes) => { log::trace!("sending ping"); if let Err(e) = ws.feed(WsMessage::Ping(bytes)).await { log::warn!("error sending ping: {e:#}"); break; } }, - UnorderedWsMessage::Error(err) => { + WsControlMessage::Error(err) => { log::trace!("encoding execution error"); encode_tx .send(err.into()) @@ -1812,17 +1822,17 @@ mod tests { } #[tokio::test] - async fn send_loop_terminates_when_unordered_closed() { + async fn send_loop_terminates_when_ws_control_closed() { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state, ClientConfig::for_test(), sink::drain(), messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); @@ -1831,7 +1841,7 @@ mod tests { drop(messages_tx); assert!(is_pending(&mut send_loop).await); - drop(unordered_tx); + drop(ws_control_tx); send_loop.await; } @@ -1839,21 +1849,21 @@ mod tests { async fn send_loop_close_message_closes_state_and_messages() { let state = 
Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), sink::drain(), messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); - unordered_tx - .send(UnorderedWsMessage::Close(CloseFrame { - code: WsCloseCode::Away, + ws_control_tx + .send(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "done".into(), })) .unwrap(); @@ -1866,12 +1876,12 @@ mod tests { #[tokio::test] async fn send_loop_terminates_if_sink_cant_be_fed() { let input = [ - Either::Left(UnorderedWsMessage::Close(CloseFrame { - code: WsCloseCode::Away, + Either::Left(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "bah!".into(), })), - Either::Left(UnorderedWsMessage::Ping(Bytes::new())), - Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + Either::Left(WsControlMessage::Ping(Bytes::new())), + Either::Left(WsControlMessage::Error(MessageExecutionError { reducer: None, reducer_id: None, caller_identity: Identity::ZERO, @@ -1892,20 +1902,20 @@ mod tests { for message in input { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), UnfeedableSink, messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); match message { - Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Left(ws_control) => ws_control_tx.send(ws_control).unwrap(), Either::Right(message) => messages_tx.send(message).await.unwrap(), } send_loop.await; @@ -1915,12 +1925,12 @@ mod tests { #[tokio::test] 
async fn send_loop_terminates_if_sink_cant_be_flushed() { let input = [ - Either::Left(UnorderedWsMessage::Close(CloseFrame { - code: WsCloseCode::Away, + Either::Left(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "bah!".into(), })), - Either::Left(UnorderedWsMessage::Ping(Bytes::new())), - Either::Left(UnorderedWsMessage::Error(MessageExecutionError { + Either::Left(WsControlMessage::Ping(Bytes::new())), + Either::Left(WsControlMessage::Error(MessageExecutionError { reducer: None, reducer_id: None, caller_identity: Identity::ZERO, @@ -1941,20 +1951,20 @@ mod tests { for message in input { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), UnflushableSink, messages_rx, - unordered_rx, + ws_control_rx, BsatnRowListBuilderPool::new(), ); pin_mut!(send_loop); match message { - Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), + Either::Left(ws_control) => ws_control_tx.send(ws_control).unwrap(), Either::Right(message) => messages_tx.send(message).await.unwrap(), } send_loop.await; @@ -2032,11 +2042,11 @@ mod tests { let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline()); // Pretend we received a pong immediately after sending a ping, // but only five times. 
- let unordered_tx = { + let ws_control_tx = { let state = state.clone(); let pings = AtomicUsize::new(0); move |m| { - if let UnorderedWsMessage::Ping(_) = m { + if let WsControlMessage::Ping(_) = m { let n = pings.fetch_add(1, Ordering::Relaxed); if n < 5 { state.set_ponged(); @@ -2056,7 +2066,7 @@ mod tests { ws_idle_timer(idle_rx), tokio::spawn(future::pending()), tokio::spawn(future::pending()), - unordered_tx, + ws_control_tx, ) .await } @@ -2083,10 +2093,10 @@ mod tests { let state = Arc::new(dummy_actor_state()); let (_idle_tx, idle_rx) = watch::channel(state.next_idle_deadline()); - let unordered_tx = { + let ws_control_tx = { let state = state.clone(); move |m| { - if let UnorderedWsMessage::Close(_) = m { + if let WsControlMessage::Close(_) = m { state.close(); } } @@ -2113,7 +2123,7 @@ mod tests { } }), tokio::spawn(future::pending()), - unordered_tx, + ws_control_tx, ) .await }) @@ -2138,15 +2148,15 @@ mod tests { ..<_>::default() })); - let (unordered_tx, mut unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, mut ws_control_rx) = mpsc::unbounded_channel(); let input = stream::iter((0..20).map(|i| Ok(WsMessage::text(format!("message {i}"))))); let metric = IntGauge::new("bleep", "unhelpful").unwrap(); - let received = ws_recv_queue(state, unordered_tx, metric.clone(), input) + let received = ws_recv_queue(state, ws_control_tx, metric.clone(), input) .collect::>() .await; - assert_matches!(unordered_rx.recv().await, Some(UnorderedWsMessage::Close(_))); + assert_matches!(ws_control_rx.recv().await, Some(WsControlMessage::Close(_))); // Queue length metric should be zero assert_eq!(metric.get(), 0); // Should have received all of the input. 
@@ -2160,11 +2170,11 @@ mod tests { ..<_>::default() })); - let (unordered_tx, _) = mpsc::unbounded_channel(); + let (ws_control_tx, _) = mpsc::unbounded_channel(); let input = stream::iter((0..20).map(|i| Ok(WsMessage::text(format!("message {i}"))))); let metric = IntGauge::new("bleep", "unhelpful").unwrap(); - let received = ws_recv_queue(state.clone(), unordered_tx, metric.clone(), input) + let received = ws_recv_queue(state.clone(), ws_control_tx, metric.clone(), input) .collect::>() .await; @@ -2180,7 +2190,7 @@ mod tests { let state = Arc::new(dummy_actor_state()); let mut received = Vec::new(); let (messages_tx, messages_rx) = mpsc::channel(1); - let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); + let (ws_control_tx, ws_control_rx) = mpsc::unbounded_channel(); #[derive(From)] enum OutgoingBytes { @@ -2205,24 +2215,24 @@ mod tests { const NUM_CONTROL_FRAMES: usize = 2; let send_loop = tokio::spawn(async move { - ws_send_loop_inner(state, &mut received, messages_rx, unordered_rx, encoder).await; + ws_send_loop_inner(state, &mut received, messages_rx, ws_control_rx, encoder).await; received }); messages_tx.send(Bytes::from_static(&[1; MESSAGE_SIZE])).await.unwrap(); // Yield task to give the send loop a chance to receive the message. tokio::task::yield_now().await; // Send ping, then close. - unordered_tx.send(UnorderedWsMessage::Ping(Bytes::new())).unwrap(); - unordered_tx - .send(UnorderedWsMessage::Close(CloseFrame { - code: WsCloseCode::Away, + ws_control_tx.send(WsControlMessage::Ping(Bytes::new())).unwrap(); + ws_control_tx + .send(WsControlMessage::Close(StCloseFrame { + code: StCloseCode::Away, reason: "we're done".into(), })) .unwrap(); // Shut down the loop. 
drop(messages_tx); - drop(unordered_tx); + drop(ws_control_tx); let received = send_loop.await.unwrap(); let ping_pos = received diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index 812d03c0701..b23b21c7c76 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -11,9 +11,9 @@ mod message_handlers_v3; pub mod messages; pub use client_connection::{ - ClientConfig, ClientConnection, ClientConnectionReceiver, ClientConnectionSender, ClientSendError, DataMessage, - MeteredDeque, MeteredReceiver, MeteredSender, MeteredUnboundedReceiver, MeteredUnboundedSender, Protocol, - WsVersion, + ClientConfig, ClientConnection, ClientConnectionReceiver, ClientConnectionSender, ClientDisconnectError, + ClientDisconnectSender, ClientSendError, DataMessage, MeteredDeque, MeteredReceiver, MeteredSender, + MeteredUnboundedReceiver, MeteredUnboundedSender, Protocol, WsVersion, }; pub use client_connection_index::ClientActorIndex; pub use message_handlers::MessageHandleError; diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 0a3fa198b01..de1166ad7f5 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -27,8 +27,7 @@ use spacetimedb_lib::identity::{AuthCtx, RequestId}; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Identity; use tokio::sync::mpsc::error::{SendError, TrySendError}; -use tokio::sync::{mpsc, oneshot, watch}; -use tokio::task::AbortHandle; +use tokio::sync::{mpsc, watch}; use tracing::{trace, warn}; #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] @@ -262,7 +261,15 @@ pub struct ClientConnectionSender { pub auth: ConnectionAuthCtx, pub config: ClientConfig, sendtx: mpsc::Sender, - abort_handle: AbortHandle, + /// Optional because dummy/test senders are not necessarily backed by a + /// live websocket actor control queue. 
+ /// + /// Production websocket-backed senders receive this at spawn-time so they + /// can request a transport-level close through the websocket control path. + /// Test constructors such as [`ClientConnectionSender::dummy_with_channel`] + /// and [`ClientConnectionSender::dummy`] intentionally leave this as + /// `None` unless a test explicitly wires a control queue. + disconnect_tx: Option, cancelled: AtomicBool, /// Handles on Prometheus metrics related to connections to this database. @@ -309,6 +316,21 @@ impl ClientConnectionMetrics { } } +#[derive(Clone, Debug)] +pub struct ClientDisconnectSender(mpsc::UnboundedSender); + +impl ClientDisconnectSender { + pub fn new(inner: mpsc::UnboundedSender) -> Self { + Self(inner) + } + + pub fn send(&self, close_frame: crate::client::messages::CloseFrame) -> Result<(), ClientDisconnectError> { + self.0 + .send(close_frame) + .map_err(|_| ClientDisconnectError::Disconnected) + } +} + #[derive(Debug, thiserror::Error)] pub enum ClientSendError { #[error("client disconnected")] @@ -317,6 +339,16 @@ pub enum ClientSendError { Cancelled, } +#[derive(Debug, thiserror::Error)] +pub enum ClientDisconnectError { + #[error("client was already cancelled")] + Cancelled, + #[error("disconnect channel disconnected")] + Disconnected, + #[error("disconnect handle is not configured")] + NoDisconnectHandle, +} + impl ClientConnectionSender { pub fn dummy_with_channel( id: ClientActorId, @@ -324,11 +356,6 @@ impl ClientConnectionSender { offset_supply: impl DurableOffsetSupply + 'static, ) -> (Self, ClientConnectionReceiver) { let (sendtx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY_TEST); - // just make something up, it doesn't need to be attached to a real task - let abort_handle = match tokio::runtime::Handle::try_current() { - Ok(h) => h.spawn(async {}).abort_handle(), - Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(), - }; let receiver = ClientConnectionReceiver::new(config.confirmed_reads, 
MeteredReceiver::new(rx), offset_supply); let cancelled = AtomicBool::new(false); @@ -346,7 +373,7 @@ impl ClientConnectionSender { auth: ConnectionAuthCtx::try_from(dummy_claims).expect("dummy claims should always be valid"), config, sendtx, - abort_handle, + disconnect_tx: None, cancelled, metrics: None, }; @@ -357,10 +384,36 @@ impl ClientConnectionSender { Self::dummy_with_channel(id, config, offset_supply).0 } + #[cfg(test)] + pub(crate) fn dummy_with_disconnect_channel( + id: ClientActorId, + config: ClientConfig, + offset_supply: impl DurableOffsetSupply + 'static, + ) -> ( + Self, + ClientConnectionReceiver, + mpsc::UnboundedReceiver, + ) { + let (mut sender, receiver) = Self::dummy_with_channel(id, config, offset_supply); + let (disconnect_tx, disconnect_rx) = mpsc::unbounded_channel(); + sender.disconnect_tx = Some(ClientDisconnectSender::new(disconnect_tx)); + (sender, receiver, disconnect_rx) + } + pub fn is_cancelled(&self) -> bool { self.cancelled.load(Ordering::Relaxed) } + pub fn disconnect(&self, close_frame: crate::client::messages::CloseFrame) -> Result<(), ClientDisconnectError> { + if self.cancelled.load(Ordering::Relaxed) { + return Err(ClientDisconnectError::Cancelled); + } + self.disconnect_tx + .as_ref() + .ok_or(ClientDisconnectError::NoDisconnectHandle)? + .send(close_frame) + } + /// Send a message to the client. For data-related messages, you should probably use /// `BroadcastQueue::send` to ensure that the client sees data messages in a consistent order. /// @@ -393,8 +446,13 @@ impl ClientConnectionSender { match self.sendtx.try_send(message) { Err(mpsc::error::TrySendError::Full(_)) => { - // We've hit CLIENT_CHANNEL_CAPACITY messages backed up in - // the channel, so forcibly kick the client + // We've hit `CLIENT_CHANNEL_CAPACITY` messages backed up in the channel, + // so forcibly kick the client.
+ // + // Mark the sender cancelled first so subsequent ordered sends + // fail fast immediately, then request websocket close using the + // same control-plane close path as any other server-initiated + // disconnect. tracing::warn!( identity = %self.id.identity, connection_id = %self.id.connection_id, @@ -406,8 +464,13 @@ impl ClientConnectionSender { self.id, self.sendtx.capacity(), ); - self.abort_handle.abort(); self.cancelled.store(true, Ordering::Relaxed); + if let Some(disconnect_tx) = &self.disconnect_tx { + let _ = disconnect_tx.send(crate::client::messages::CloseFrame { + code: crate::client::messages::CloseCode::Again, + reason: "client channel capacity exceeded".into(), + }); + } return Err(ClientSendError::Cancelled); } Err(mpsc::error::TrySendError::Closed(_)) => return Err(ClientSendError::Disconnected), @@ -803,6 +866,7 @@ impl ClientConnection { config: ClientConfig, replica_id: u64, mut module_rx: watch::Receiver, + disconnect_tx: Option, actor: impl FnOnce(ClientConnection, ClientConnectionReceiver) -> Fut, _proof_of_client_connected_call: Connected, ) -> ClientConnection @@ -817,25 +881,9 @@ impl ClientConnection { let (sendtx, sendrx) = mpsc::channel::(CLIENT_CHANNEL_CAPACITY); - let (fut_tx, fut_rx) = oneshot::channel::(); - // weird dance so that we can get an abort_handle into ClientConnection let module_info = module.info.clone(); let database_identity = module_info.database_identity; let client_identity = id.identity; - let abort_handle = tokio::spawn(async move { - let Ok(fut) = fut_rx.await else { return }; - - let _gauge_guard = module_info.metrics.connected_clients.inc_scope(); - module_info.metrics.ws_clients_spawned.inc(); - scopeguard::defer! 
{ - let database_identity = module_info.database_identity; - log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`"); - module_info.metrics.ws_clients_aborted.inc(); - }; - - fut.await - }) - .abort_handle(); let metrics = ClientConnectionMetrics::new(database_identity, config.protocol); let receiver = ClientConnectionReceiver::new( @@ -849,7 +897,7 @@ impl ClientConnection { auth, config, sendtx, - abort_handle, + disconnect_tx, cancelled: AtomicBool::new(false), metrics: Some(metrics), }); @@ -861,8 +909,17 @@ impl ClientConnection { }; let actor_fut = actor(this.clone(), receiver); - // if this fails, the actor() function called .abort(), which like... okay, I guess? - let _ = fut_tx.send(actor_fut); + tokio::spawn(async move { + let _gauge_guard = module_info.metrics.connected_clients.inc_scope(); + module_info.metrics.ws_clients_spawned.inc(); + scopeguard::defer! { + let database_identity = module_info.database_identity; + log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`"); + module_info.metrics.ws_clients_aborted.inc(); + }; + + actor_fut.await + }); this } @@ -1433,4 +1490,41 @@ mod tests { offset.mark_durable_at(3); assert_received_update(receiver.recv()).await; } + + #[test] + fn disconnect_without_handle_returns_no_disconnect_handle() { + let sender = ClientConnectionSender::dummy( + ClientActorId::for_test(Identity::ZERO), + ClientConfig::for_test(), + NoneDurableOffset, + ); + let res = sender.disconnect(crate::client::messages::CloseFrame { + code: crate::client::messages::CloseCode::Away, + reason: "disconnect".into(), + }); + assert_matches!(res, Err(ClientDisconnectError::NoDisconnectHandle)); + } + + #[test] + fn send_overflow_marks_cancelled_and_emits_disconnect() { + let (sender, _receiver, mut disconnect_rx) = ClientConnectionSender::dummy_with_disconnect_channel( + 
ClientActorId::for_test(Identity::ZERO), + ClientConfig::for_test(), + NoneDurableOffset, + ); + + for _ in 0..CLIENT_CHANNEL_CAPACITY_TEST { + sender.send_message(None, empty_tx_update()).unwrap(); + } + + let res = sender.send_message(None, empty_tx_update()); + assert_matches!(res, Err(ClientSendError::Cancelled)); + assert!(sender.is_cancelled()); + + let close_frame = disconnect_rx.try_recv().expect("expected disconnect request"); + assert_eq!(close_frame.code, crate::client::messages::CloseCode::Again); + + let res = sender.send_message(None, empty_tx_update()); + assert_matches!(res, Err(ClientSendError::Cancelled)); + } } diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index dc9d2e232c4..1639f258acd 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -2708,13 +2708,13 @@ impl ModuleHost { Timestamp::UNIX_EPOCH, TimeDuration::ZERO, ), - Some(code) => self.subscriptions().disconnect_client( - sender, - CloseFrame { + Some(code) => { + let _ = sender.disconnect(CloseFrame { code, reason: err.to_string().into(), - }, - ), + }); + Ok(()) + } }, } } diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 381fc3f965b..d748fdd09ab 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -8,7 +8,7 @@ use super::query::compile_query_with_hashes; use super::tx::DeltaTx; use super::{collect_table_update, TableUpdateType}; use crate::client::messages::{ - CloseFrame, ProcedureResultMessage, SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, + ProcedureResultMessage, SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionRows, SubscriptionUpdateMessage, TransactionUpdateMessage, }; use crate::client::{ClientActorId, ClientConnectionSender, Protocol, 
WsVersion}; @@ -1127,14 +1127,6 @@ impl ModuleSubscriptions { .send_client_message_v2(recipient, tx_offset, message) } - pub fn disconnect_client( - &self, - recipient: Arc, - close_frame: CloseFrame, - ) -> Result<(), BroadcastError> { - todo!() - } - pub fn send_one_off_query_message_v2( &self, recipient: Arc,